diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index b3790b082ed57..c514ea4ab31fd 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -2958,6 +2958,23 @@ def test_collectives(self):
     def test_allreduce_coalesced(self):
         self._test_allreduce_coalesced(backend="nccl")
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(1)
+    def test_allgather_base(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        device = "cuda"
+        tensor = torch.ones(10, 10, device=torch.device(device))
+        output_tensor = torch.zeros(10, 10, device=torch.device(device))
+        dist.all_gather_into_tensor(output_tensor, tensor)
+        self.assertEqual(output_tensor, tensor)
+
+
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized
diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp
index 15e186fe3d22d..f825afca2a1d9 100644
--- a/torch/csrc/distributed/c10d/Ops.cpp
+++ b/torch/csrc/distributed/c10d/Ops.cpp
@@ -88,6 +88,13 @@ allgather_(
       output_tensors, work);
 }
 
+c10::intrusive_ptr<ProcessGroup::Work> _allgather_base_(
+    at::Tensor& output_tensor,
+    at::Tensor& input_tensor,
+    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+  return process_group->_allgather_base(output_tensor, input_tensor);
+}
+
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<ProcessGroup::Work>> reduce_scatter_(
     const std::vector<at::Tensor>& output_tensors,
     const std::vector<std::vector<at::Tensor>>& input_tensors,
@@ -197,6 +204,9 @@ TORCH_LIBRARY(c10d, m) {
   m.def(
       "allgather_",
       dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_));
+  m.def(
+      "_allgather_base_",
+      dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_));
   m.def(
       "reduce_scatter_",
       dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_));
@@ -303,6 +313,21 @@ c10::intrusive_ptr<ProcessGroup::Work> allgather(
       output_tensors, input_tensors, process_group, opts.timeout.count()));
 }
 
+c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    at::Tensor& output_tensor,
+    at::Tensor& input_tensor,
+    const AllgatherOptions& opts) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("c10d::_allgather_base_", "")
+                       .typed<c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
+                           at::Tensor&,
+                           at::Tensor&,
+                           const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+
+  return op.call(output_tensor, input_tensor, process_group);
+}
+
 c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<at::Tensor>& output_tensors,
diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp
index 8ef78126e5b9e..72f09e341d7df 100644
--- a/torch/csrc/distributed/c10d/Ops.hpp
+++ b/torch/csrc/distributed/c10d/Ops.hpp
@@ -32,6 +32,12 @@ TORCH_API c10::intrusive_ptr<ProcessGroup::Work> allgather(
     const std::vector<at::Tensor>& input_tensors,
     const AllgatherOptions& opts = {});
 
+TORCH_API c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    at::Tensor& outputTensor,
+    at::Tensor& inputTensor,
+    const AllgatherOptions& opts = {});
+
 TORCH_API c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<at::Tensor>& output_tensors,
diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp
index 94f5febec14d0..78e26c9656d8d 100644
--- a/torch/csrc/distributed/c10d/OpsImpl.cpp
+++ b/torch/csrc/distributed/c10d/OpsImpl.cpp
@@ -211,6 +211,20 @@ allgather_cuda_(
      output_tensors, work);
 }
 
+c10::intrusive_ptr<ProcessGroup::Work> _allgather_base_cpu_(
+    at::Tensor& output_tensor,
+    at::Tensor& input_tensor,
+    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+  return process_group->_allgather_base(output_tensor, input_tensor);
+}
+
+c10::intrusive_ptr<ProcessGroup::Work> _allgather_base_cuda_(
+    at::Tensor& output_tensor,
+    at::Tensor& input_tensor,
+    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+  return process_group->_allgather_base(output_tensor, input_tensor);
+}
+
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<ProcessGroup::Work>>
 reduce_scatter_cpu_(
     const std::vector<at::Tensor>& output_tensors,
@@ -409,6 +423,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("allgather_", allgather_cuda_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, CPU, m) {
+  m.impl("_allgather_base_", _allgather_base_cpu_);
+}
+
+TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
+  m.impl("_allgather_base_", _allgather_base_cuda_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("reduce_scatter_", reduce_scatter_cpu_);
 }
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 673f481d60251..2424506eef0ff 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1187,7 +1187,13 @@ that adds a prefix to each key inserted to the store.
           .def(
               "_allgather_base",
-              &::c10d::ProcessGroup::_allgather_base,
+              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
+                 at::Tensor& output_tensor,
+                 at::Tensor& input_tensor,
+                 const ::c10d::AllgatherOptions& opts) {
+                return ::c10d::ops::_allgather_base(
+                    self, output_tensor, input_tensor, opts);
+              },
               py::arg("output"),
               py::arg("input"),
               py::arg("opts") = ::c10d::AllgatherOptions(),
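For context, the Python entry point exercised by the new test, `dist.all_gather_into_tensor`, routes through the `_allgather_base` binding patched above and expects the output tensor to hold `world_size` times the input along the first dimension. Below is a minimal multi-rank usage sketch, not part of the patch: it assumes a machine with at least two CUDA devices, the NCCL backend, and an illustrative env:// rendezvous via MASTER_ADDR/MASTER_PORT.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # Illustrative rendezvous settings; any valid init method works.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Each rank contributes a (10, 10) tensor filled with its rank id.
    inp = torch.full((10, 10), float(rank), device="cuda")
    # The output must hold world_size * 10 rows, one input-sized slab per rank.
    out = torch.empty(world_size * 10, 10, device="cuda")

    # Dispatches to ProcessGroup::_allgather_base (the op registered in the diff).
    dist.all_gather_into_tensor(out, inp)

    # Rank r's contribution lands in rows [r * 10, (r + 1) * 10).
    assert torch.equal(out[rank * 10 : (rank + 1) * 10], inp)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```

This mirrors the single-GPU unit test added in test_c10d_nccl.py, where `world_size == 1` makes the gathered output equal to the input; with more ranks the output is the concatenation of every rank's tensor in rank order.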