Skip to content

Commit

Permalink
Add Gloo TCP_TLS transport (pytorch#56442)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#56442

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D27896285

Pulled By: pbelevich

fbshipit-source-id: 589af59ca4c7c9bab2329f079382c09b71cfcf9e
  • Loading branch information
pbelevich authored and facebook-github-bot committed May 7, 2021
1 parent 96fce78 commit 96e1a83
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .jenkins/pytorch/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
./xla/scripts/apply_patches.sh
fi

if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-build || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-build ]]; then
export USE_GLOO_WITH_OPENSSL=ON
fi

if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then
set -e

Expand Down
96 changes: 96 additions & 0 deletions .jenkins/pytorch/create_test_cert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from datetime import datetime, timedelta
from tempfile import mkdtemp
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes

temp_dir = mkdtemp()
print(temp_dir)


def genrsa(path):
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
with open(path, "wb") as f:
f.write(key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
))
return key


def create_cert(path, C, ST, L, O, key):
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
])
cert = x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
key.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.utcnow()
).not_valid_after(
# Our certificate will be valid for 10 days
datetime.utcnow() + timedelta(days=10)
).add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True,
).sign(key, hashes.SHA256())
# Write our certificate out to disk.
with open(path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
return cert


def create_req(path, C, ST, L, O, key):
csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([
# Provide various details about who we are.
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
])).sign(key, hashes.SHA256())
with open(path, "wb") as f:
f.write(csr.public_bytes(serialization.Encoding.PEM))
return csr


def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
cert = x509.CertificateBuilder().subject_name(
csr_cert.subject
).issuer_name(
ca_cert.subject
).public_key(
csr_cert.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.utcnow()
).not_valid_after(
# Our certificate will be valid for 10 days
datetime.utcnow() + timedelta(days=10)
# Sign our certificate with our private key
).sign(private_ca_key, hashes.SHA256())
with open(path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
return cert


ca_key = genrsa(temp_dir + "/ca.key")
ca_cert = create_cert(temp_dir + "/ca.pem", u"US", u"New York", u"New York", u"Gloo Certificate Authority", ca_key)

pkey = genrsa(temp_dir + "/pkey.key")
csr = create_req(temp_dir + "/csr.csr", u"US", u"California", u"San Francisco", u"Gloo Testing Company", pkey)

cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key)
18 changes: 18 additions & 0 deletions .jenkins/pytorch/run_glootls_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

CREATE_TEST_CERT="$(dirname "${BASH_SOURCE[0]}")/create_test_cert.py"
TMP_CERT_DIR=$(python "$CREATE_TEST_CERT")

openssl verify -CAfile "${TMP_CERT_DIR}/ca.pem" "${TMP_CERT_DIR}/cert.pem"

export GLOO_DEVICE_TRANSPORT=TCP_TLS
export GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY=${TMP_CERT_DIR}/pkey.key
export GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT=${TMP_CERT_DIR}/cert.pem
export GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE=${TMP_CERT_DIR}/ca.pem

time python test/run_test.py --include distributed/test_c10d_gloo --verbose --determine-from="$DETERMINE_FROM" -- ProcessGroupGlooTest

unset GLOO_DEVICE_TRANSPORT
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE
8 changes: 8 additions & 0 deletions .jenkins/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ test_python() {
assert_git_not_dirty
}

test_python_gloo_with_tls() {
source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh"
assert_git_not_dirty
}


test_aten() {
# Test ATen
Expand Down Expand Up @@ -478,6 +483,9 @@ else
test_distributed
test_benchmarks
test_rpc
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-test || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-test ]]; then
test_python_gloo_with_tls
fi
fi

if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then
Expand Down
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ cmake_dependent_option(
cmake_dependent_option(
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_GLOO_WITH_OPENSSL "Use Gloo with OpenSSL. Only available if USE_GLOO is on." OFF
"USE_GLOO AND LINUX AND NOT INTERN_BUILD_MOBILE" OFF)
cmake_dependent_option(
USE_TENSORPIPE "Use TensorPipe. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
Expand Down Expand Up @@ -327,6 +330,10 @@ if(WIN32)
endif()
endif()

if(USE_GLOO_WITH_OPENSSL)
set(USE_TCP_OPENSSL_LOAD ON CACHE STRING "")
endif()

# Linux distributions do not want too many embedded sources, in that sense we
# need to be able to build pytorch with an (almost) empty third_party
# directory.
Expand Down
4 changes: 3 additions & 1 deletion tools/print_test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,12 +631,14 @@ def append(self, test_case: TestCase) -> None:
self.test_suites[suite_name] = TestSuite(suite_name)
if test_case.name in self.test_suites[suite_name].test_cases:
# We expect duplicate tests for test_cpp_extensions_aot, distributed/test_distributed_fork,
# and distributed/test_distributed_spawn. In these cases, we store the test case that took the longest,
# and distributed/test_distributed_spawn and test_c10d_gloo.
# In these cases, we store the test case that took the longest,
# as in these jobs, the duplicate tests are run in parallel.
# For other unexpected cases, we should raise a warning.
if self.name == 'test_cpp_extensions_aot' or \
self.name == 'distributed/test_distributed_fork' or \
self.name == 'distributed/test_distributed_spawn' or \
self.name == 'distributed/test_c10d_gloo' or \
self.name == 'cpp': # The caffe2 cpp tests spawn duplicate test cases as well.
time_difference = self.test_suites[suite_name].replace(test_case)
self.total_time += time_difference
Expand Down
33 changes: 33 additions & 0 deletions torch/lib/c10d/GlooDeviceFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include <gloo/transport/tcp/device.h>
#endif

#if GLOO_HAVE_TRANSPORT_TCP_TLS
#include <gloo/transport/tcp/tls/device.h>
#endif

#if GLOO_HAVE_TRANSPORT_UV
#include <gloo/transport/uv/device.h>
#endif
Expand Down Expand Up @@ -59,6 +63,35 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
#endif

#if GLOO_HAVE_TRANSPORT_TCP_TLS
static std::string cstr_to_std_string(const char* chars) {
return std::string (chars != nullptr ? chars : "");
}

static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice(
const std::string& interface,
const std::string& hostname) {
TORCH_CHECK(
!interface.empty() || !hostname.empty(),
"GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname "
"can't be empty");

::gloo::transport::tcp::attr attr;
if (!interface.empty()) {
attr.iface = interface;
} else {
attr.hostname = hostname;
}
const auto pkey = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"));
const auto cert = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"));
const auto caFile = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"));
const auto caPath = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"));
return ::gloo::transport::tcp::tls::CreateDevice(attr, pkey, cert, caFile, caPath);
}

C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice);
#endif

#if GLOO_HAVE_TRANSPORT_UV
static std::shared_ptr<::gloo::transport::Device> makeUVDevice(
const std::string& interfaceName,
Expand Down

0 comments on commit 96e1a83

Please sign in to comment.