[21/N] Add alltoall_base custom op with CPU/CUDA implementations (pytorch#89813)

Differential Revision: [D41812670](https://our.internmc.facebook.com/intern/diff/D41812670)
Pull Request resolved: pytorch#89813
Approved by: https://github.com/kwen2501
H-Huang authored and pytorchmergebot committed Dec 8, 2022
1 parent e65ee39 commit 8015078
Showing 7 changed files with 119 additions and 15 deletions.
14 changes: 14 additions & 0 deletions test/distributed/test_c10d_common.py
@@ -1522,6 +1522,20 @@ def _test_allreduce_coalesced(self, backend):
        for tensor in tensors:
            self.assertEqual(tensor, torch.ones(10, 10) * self.world_size)

    def _test_all_to_all_single(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda" if backend == "nccl" else "cpu"
        # test alltoall_base
        input_tensor = torch.ones(2, 2, device=torch.device(device))
        output_tensor = torch.zeros(2, 2, device=torch.device(device))
        dist.all_to_all_single(output_tensor, input_tensor)

class CompilerTest(MultiProcessTestCase):
    def setUp(self):
        super(CompilerTest, self).setUp()
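The new `_test_all_to_all_single` helper only exercises the even-split path (no split sizes passed, so the input is divided equally across ranks). For reference, here is a minimal sketch of the same collective with explicit, uneven split sizes. This is not part of the diff and assumes a two-rank process group is already initialized:

```python
import torch
import torch.distributed as dist

def uneven_all_to_all(rank):
    # Rank 0 sends 1 row to itself and 3 rows to rank 1;
    # rank 1 sends 2 rows to rank 0 and keeps 2 rows.
    input_split_sizes = [1, 3] if rank == 0 else [2, 2]
    # Receive counts mirror the peers' send counts:
    # rank 0 receives [1 (from r0), 2 (from r1)]; rank 1 receives [3, 2].
    output_split_sizes = [1, 2] if rank == 0 else [3, 2]
    input_tensor = torch.full((4, 2), float(rank))
    output_tensor = torch.zeros(sum(output_split_sizes), 2)
    dist.all_to_all_single(
        output_tensor, input_tensor, output_split_sizes, input_split_sizes
    )
    return output_tensor
```

Both the even and uneven cases route through the `alltoall_base_` op added below.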
4 changes: 4 additions & 0 deletions test/distributed/test_c10d_gloo.py
@@ -2405,6 +2405,10 @@ def test_collectives(self):
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="gloo")

    @requires_gloo()
    def test_all_to_all_single(self):
        self._test_all_to_all_single(backend="gloo")

    @requires_gloo()
    def test_allgather_coalesced(self):
        store = dist.FileStore(self.file_name, self.world_size)
5 changes: 5 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -2948,6 +2948,11 @@ def test_collectives(self):
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_all_to_all_single(self):
        self._test_all_to_all_single(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
43 changes: 43 additions & 0 deletions torch/csrc/distributed/c10d/Ops.cpp
@@ -173,6 +173,21 @@ c10::intrusive_ptr<Work> alltoall_(
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->alltoall_base(
      output,
      input,
      output_split_sizes,
      input_split_sizes,
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
@@ -271,6 +286,9 @@ TORCH_LIBRARY(c10d, m) {
  m.def(
      "alltoall_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_));
  m.def(
      "alltoall_base_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_base_));
  m.def(
      "barrier",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier));
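Registering the schema under `TORCH_LIBRARY(c10d, m)` also makes the op resolvable from Python through the dispatcher. A small sketch of that visibility check; the assumption here is that importing `torch.distributed` is enough to load the c10d registrations:

```python
import torch
import torch.distributed  # loads the c10d op registrations

# Resolves lazily against the dispatcher and raises if the schema is absent.
op = torch.ops.c10d.alltoall_base_
print(op)  # OpOverloadPacket for c10d::alltoall_base_
```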
@@ -523,6 +541,31 @@ c10::intrusive_ptr<Work> alltoall(
      output_tensors, input_tensors, process_group, opts.timeout.count());
}

c10::intrusive_ptr<Work> alltoall_base(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::Tensor& output,
    at::Tensor& input,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    const AllToAllOptions& opts) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("c10d::alltoall_base_", "")
                       .typed<c10::intrusive_ptr<::c10d::Work>(
                           at::Tensor&,
                           at::Tensor&,
                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                           std::vector<int64_t>,
                           std::vector<int64_t>,
                           int64_t)>();
  return op.call(
      output,
      input,
      process_group,
      output_split_sizes,
      input_split_sizes,
      opts.timeout.count());
}

void monitored_barrier(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const BarrierOptions& opts,
8 changes: 8 additions & 0 deletions torch/csrc/distributed/c10d/Ops.hpp
@@ -73,6 +73,14 @@ TORCH_API c10::intrusive_ptr<Work> scatter(
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const ScatterOptions& opts = {});

TORCH_API c10::intrusive_ptr<Work> alltoall_base(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::Tensor& output,
    at::Tensor& input,
    const std::vector<int64_t> outputSplitSizes,
    const std::vector<int64_t> inputSplitSizes,
    const AllToAllOptions& opts = {});

TORCH_API c10::intrusive_ptr<Work> alltoall(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::TensorList output_tensors,
38 changes: 38 additions & 0 deletions torch/csrc/distributed/c10d/OpsImpl.cpp
@@ -399,6 +399,36 @@ c10::intrusive_ptr<Work> alltoall_cuda_(
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_cpu_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->alltoall_base(
      output,
      input,
      output_split_sizes,
      input_split_sizes,
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_cuda_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->alltoall_base(
      output,
      input,
      output_split_sizes,
      input_split_sizes,
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier_cpu(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
@@ -558,6 +588,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("alltoall_", alltoall_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("alltoall_base_", alltoall_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("alltoall_base_", alltoall_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("barrier", barrier_cpu);
}
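With a `TORCH_LIBRARY_IMPL` block per backend key, the dispatcher selects the kernel from the tensors' device: CPU tensors (e.g. under gloo) reach `alltoall_base_cpu_`, and CUDA tensors (under nccl) reach `alltoall_base_cuda_`. Both kernels currently forward to the same `ProcessGroup::alltoall_base`, so the split registration mainly leaves room for device-specific handling later. A hedged sketch of the two paths, mirroring the new test helper and assuming `MASTER_ADDR`/`MASTER_PORT` are set for rendezvous:

```python
import torch
import torch.distributed as dist

def run(backend, rank, world_size):
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    device = "cuda" if backend == "nccl" else "cpu"
    inp = torch.ones(world_size, 4, device=device)
    out = torch.zeros(world_size, 4, device=device)
    # CPU tensors dispatch to alltoall_base_cpu_; CUDA tensors to alltoall_base_cuda_.
    dist.all_to_all_single(out, inp)
    dist.destroy_process_group()
```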
22 changes: 7 additions & 15 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1440,34 +1440,26 @@ that adds a prefix to each key inserted to the store.

          .def(
              "alltoall_base",
              &::c10d::ProcessGroup::alltoall_base,
              py::arg("output_tensor"),
              py::arg("input_tensor"),
              py::arg("output_split_sizes"),
              py::arg("input_split_sizes"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall_base",
              [](::c10d::ProcessGroup& self,
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& output,
                 at::Tensor& input,
                 std::vector<int64_t> outputSplitSizes,
                 std::vector<int64_t> inputSplitSizes) {
                return self.alltoall_base(
                 std::vector<int64_t> inputSplitSizes,
                 const ::c10d::AllToAllOptions& opts) {
                return ::c10d::ops::alltoall_base(
                    self,
                    output,
                    input,
                    outputSplitSizes,
                    inputSplitSizes,
                    ::c10d::AllToAllOptions());
                    opts);
              },
              py::arg("output"),
              py::arg("input"),
              py::arg("output_split_sizes"),
              py::arg("input_split_sizes"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
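This hunk replaces the direct binding to `ProcessGroup::alltoall_base` with a lambda that forwards to `::c10d::ops::alltoall_base` (note `return self.alltoall_base(` becoming `return ::c10d::ops::alltoall_base(`), so Python calls now go through the dispatcher like the other ops in this stack. A rough sketch of driving the binding directly; `_get_default_group` is a private helper, and an initialized default group is assumed:

```python
import torch
import torch.distributed as dist

pg = dist.distributed_c10d._get_default_group()
output = torch.zeros(2, 2)
inp = torch.ones(2, 2)
# Empty split lists mean an even split across ranks, as in the new test.
work = pg.alltoall_base(output, inp, [], [])
work.wait()
```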
