[16/N] Add _allgather_base custom op with CPU/CUDA implementation (pytorch#88889)

Differential Revision: [D41227739](https://our.internmc.facebook.com/intern/diff/D41227739)
Pull Request resolved: pytorch#88889
Approved by: https://github.com/kwen2501
H-Huang authored and pytorchmergebot committed Nov 12, 2022
1 parent 3765621 commit df1df9d
Showing 5 changed files with 77 additions and 1 deletion.
17 changes: 17 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -2958,6 +2958,23 @@ def test_collectives(self):
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda"
        tensor = torch.ones(10, 10, device=torch.device(device))
        output_tensor = torch.zeros(10, 10, device=torch.device(device))
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor, tensor)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
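For context, a minimal sketch of the contract the new test exercises: all_gather_into_tensor concatenates every rank's input along dim 0, so in general the output must be world_size input-sized blocks (the test above uses matching shapes, consistent with a world size of 1). Everything outside the collective call here (env:// rendezvous, device selection) is an illustrative assumption, not part of the commit:

# Hedged sketch, not part of the commit: the shape contract of
# all_gather_into_tensor. Assumes env:// rendezvous variables
# (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are set externally.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank)

input_tensor = torch.full((10, 10), float(rank), device=device)
# The output must hold world_size input-sized blocks along dim 0.
output_tensor = torch.zeros(world_size * 10, 10, device=device)
dist.all_gather_into_tensor(output_tensor, input_tensor)
# Block r of the output now holds rank r's input; rank 0 sent zeros.
assert torch.equal(output_tensor[:10], torch.zeros(10, 10, device=device))
dist.destroy_process_group()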
25 changes: 25 additions & 0 deletions torch/csrc/distributed/c10d/Ops.cpp
@@ -88,6 +88,13 @@ allgather_(
      output_tensors, work);
}

c10::intrusive_ptr<Work> _allgather_base_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  return process_group->_allgather_base(output_tensor, input_tensor);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> reduce_scatter_(
    const std::vector<at::Tensor>& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
@@ -197,6 +204,9 @@ TORCH_LIBRARY(c10d, m) {
  m.def(
      "allgather_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_));
  m.def(
      "_allgather_base_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_));
  m.def(
      "reduce_scatter_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_));
@@ -303,6 +313,21 @@ c10::intrusive_ptr<Work> allgather(
      output_tensors, input_tensors, process_group, opts.timeout.count()));
}

c10::intrusive_ptr<Work> _allgather_base(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const AllgatherOptions& opts) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("c10d::_allgather_base_", "")
                       .typed<c10::intrusive_ptr<Work>(
                           at::Tensor&,
                           at::Tensor&,
                           const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

  return op.call(output_tensor, input_tensor, process_group);
}

c10::intrusive_ptr<Work> reduce_scatter(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<at::Tensor>& output_tensors,
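Because TORCH_LIBRARY(c10d, m) registers a schema for _allgather_base_, the op can be resolved by name through the dispatcher, which is exactly what the findSchemaOrThrow call above does from C++. A hedged Python sketch of the same resolution; whether a bound ProcessGroup can be passed straight through from Python depends on how the class is registered, so treat the call itself as an assumption:

# Hedged sketch: resolving and invoking the newly registered op.
import torch
import torch.distributed as dist

# Resolving the packet mirrors findSchemaOrThrow("c10d::_allgather_base_", "").
print(torch.ops.c10d._allgather_base_)

# Assumption: an initialized default group, and that the ProcessGroup
# binding is accepted as an op argument. Argument order matches the
# registered schema: (output, input, process_group) -> Work.
pg = dist.distributed_c10d._get_default_group()
inp = torch.ones(4, device="cuda")
out = torch.empty(pg.size() * 4, device="cuda")
work = torch.ops.c10d._allgather_base_(out, inp, pg)
work.wait()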
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/Ops.hpp
@@ -32,6 +32,12 @@ TORCH_API c10::intrusive_ptr<Work> allgather(
    const std::vector<at::Tensor>& input_tensors,
    const AllgatherOptions& opts = {});

TORCH_API c10::intrusive_ptr<Work> _allgather_base(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts = {});

TORCH_API c10::intrusive_ptr<Work> reduce_scatter(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<at::Tensor>& output_tensors,
22 changes: 22 additions & 0 deletions torch/csrc/distributed/c10d/OpsImpl.cpp
@@ -211,6 +211,20 @@ allgather_cuda_(
      output_tensors, work);
}

c10::intrusive_ptr<Work> _allgather_base_cpu_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  return process_group->_allgather_base(output_tensor, input_tensor);
}

c10::intrusive_ptr<Work> _allgather_base_cuda_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  return process_group->_allgather_base(output_tensor, input_tensor);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
reduce_scatter_cpu_(
    const std::vector<at::Tensor>& output_tensors,
@@ -409,6 +423,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("allgather_", allgather_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("_allgather_base_", _allgather_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("_allgather_base_", _allgather_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("reduce_scatter_", reduce_scatter_cpu_);
}
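One point worth spelling out: with TORCH_LIBRARY_IMPL the dispatcher selects the backend kernel from the device of the tensor arguments, not from the process group's backend, so CPU tensors route to _allgather_base_cpu_ and CUDA tensors to _allgather_base_cuda_ (both simply forward to the process group today). A hedged sketch of the CPU path, assuming a gloo group, env:// rendezvous, and that the backend in use implements _allgather_base for CPU tensors (gloo support may vary by version):

# Hedged sketch: CPU tensors select the CPU kernel registered above.
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # assumes env:// variables are set
world_size = dist.get_world_size()

cpu_input = torch.ones(8)
cpu_output = torch.empty(world_size * 8)
# CPU tensors -> dispatcher lands in _allgather_base_cpu_ (assuming the
# backend implements _allgather_base for CPU tensors).
dist.all_gather_into_tensor(cpu_output, cpu_input)
assert torch.equal(cpu_output, torch.ones(world_size * 8))
dist.destroy_process_group()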
8 changes: 7 additions & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -1187,7 +1187,13 @@ that adds a prefix to each key inserted to the store.

         .def(
             "_allgather_base",
-            &::c10d::ProcessGroup::_allgather_base,
+            [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
+               at::Tensor& output_tensor,
+               at::Tensor& input_tensor,
+               const ::c10d::AllgatherOptions& opts) {
+              return ::c10d::ops::_allgather_base(
+                  self, output_tensor, input_tensor, opts);
+            },
             py::arg("output"),
             py::arg("input"),
             py::arg("opts") = ::c10d::AllgatherOptions(),
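After this change the pybind _allgather_base binding routes through c10d::ops::_allgather_base, i.e. through the dispatcher op registered above, rather than calling the ProcessGroup method directly; callers see the same signature. A hedged sketch of driving the binding directly (a private API; assumes the nccl setup from the test above is already in place):

# Hedged sketch: the pybind entry point, now dispatched via
# c10d::_allgather_base_ instead of calling the method directly.
import torch
import torch.distributed as dist

# Assumes init_process_group("nccl", ...) has already run on each rank.
pg = dist.distributed_c10d._get_default_group()
inp = torch.ones(16, device="cuda")
out = torch.empty(pg.size() * 16, device="cuda")
work = pg._allgather_base(out, inp)
work.wait()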
