Commit

Added support for all_gather object
vfdev-5 committed Aug 29, 2023
1 parent 11a1fba commit 053ff42
Showing 5 changed files with 62 additions and 13 deletions.
13 changes: 10 additions & 3 deletions ignite/distributed/comp_models/base.py
@@ -181,7 +181,7 @@ def _apply_op(
return tensor

def _collective_op(
self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, float, List[float], List[str]]:
tensor_to_number = tensor_to_str = False
device = self.device()
@@ -216,10 +216,10 @@ def all_reduce(
return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))

def all_gather(
self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
) -> Union[torch.Tensor, float, List[float], List[str]]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
raise TypeError(f"Unhandled input type {type(tensor)}")
return self._do_all_gather_object(tensor, group=group)

return self._collective_op(tensor, self._do_all_gather, group=group)

@@ -282,6 +282,10 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
pass

@abstractmethod
def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
pass

@abstractmethod
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass
@@ -373,6 +377,9 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
return tensor

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Any:
return tensor

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return ranks

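With this change, all_gather dispatches on the input type: tensors, numbers and strings still go through _collective_op, while any other picklable object is routed to the new _do_all_gather_object hook instead of raising a TypeError. A minimal usage sketch (not part of the diff; assumes at least two processes launched with torchrun and that the gloo backend is available):

import torch

import ignite.distributed as idist


def training(local_rank: int) -> None:
    rank = idist.get_rank()
    # Tensors, numbers and strings take the existing tensor-based collective path.
    ranks = idist.all_gather(torch.tensor(rank, device=idist.device()))
    # Arbitrary picklable objects (here a dict) take the new object path
    # and come back as a list with one entry per process.
    payload = {"rank": rank, "meta": [rank + 1, rank + 2]}
    gathered = idist.all_gather(payload)
    if rank == 0:
        print(ranks, gathered)


if __name__ == "__main__":
    # e.g. torchrun --nproc_per_node=2 this_script.py
    with idist.Parallel(backend="gloo") as parallel:
        parallel.run(training)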
6 changes: 6 additions & 0 deletions ignite/distributed/comp_models/horovod.py
@@ -192,6 +192,12 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
tensor = tensor.unsqueeze(0)
return hvd.allgather(tensor)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if group is not None:
raise NotImplementedError("all_gather with group for horovod is not implemented")

return hvd.allgather_object(tensor)

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return hvd.ProcessSet(ranks)

15 changes: 15 additions & 0 deletions ignite/distributed/comp_models/native.py
@@ -423,6 +423,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
if group is not None and not isinstance(group, dist.ProcessGroup):
raise ValueError("Argument group should be list of int or ProcessGroup")
reduce_op = self._reduce_op_map[op]
# we do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_reduce(tensor, reduce_op, group=group)
else:
@@ -441,12 +442,26 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
if tensor.ndimension() == 0:
tensor = tensor.unsqueeze(0)
output = [torch.zeros_like(tensor) for _ in range(group_size)]
# we do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_gather(output, tensor, group=group)
else:
dist.all_gather(output, tensor)
return torch.cat(output, dim=0)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if group == dist.GroupMember.NON_GROUP_MEMBER:
return tensor

elif group is None:
group_size = self.get_world_size()
elif isinstance(group, dist.ProcessGroup):
group_size = group.size()

else:
raise ValueError("Argument group should be list of int or ProcessGroup")

output = [None for _ in range(group_size)]
dist.all_gather_object(output, tensor, group=group)
return output

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return dist.new_group(ranks=ranks, **kwargs)

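For reference, the native implementation above delegates to torch.distributed.all_gather_object, pre-allocating one output slot per member of the group. A standalone sketch of that primitive (not part of the diff; assumes a default process group has already been initialized, e.g. with the gloo backend):

import torch.distributed as dist


def gather_objects_demo() -> None:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    obj = {"rank": rank, "msg": f"hello from rank {rank}"}
    # Pre-allocate one slot per process; all_gather_object fills them in place.
    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, obj)
    # Every rank now holds [obj_from_rank_0, obj_from_rank_1, ...].
    print(rank, output)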
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/xla.py
@@ -155,6 +155,9 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
xm.all_reduce("sum", [output], groups=group)
return output.reshape(-1, *output.shape[2:])

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
raise NotImplementedError("all_gather on object is not implemented for xla")

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return [ranks]

38 changes: 28 additions & 10 deletions tests/ignite/distributed/utils/__init__.py
@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device)

def _test_distrib_all_gather(device):
rank = idist.get_rank()
ws = idist.get_world_size()

res = torch.tensor(idist.all_gather(10), device=device)
true_res = torch.tensor([10] * idist.get_world_size(), device=device)
true_res = torch.tensor([10] * ws, device=device)
assert (res == true_res).all()

t = torch.tensor(rank, device=device)
res = idist.all_gather(t)
true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
true_res = torch.tensor([i for i in range(ws)], device=device)
assert (res == true_res).all()

x = "test-test"
if rank == 0:
x = "abc"
res = idist.all_gather(x)
true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
true_res = ["abc"] + ["test-test"] * (ws - 1)
assert res == true_res

base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
@@ -179,22 +180,39 @@ def _test_distrib_all_gather(device):
x = "abc"

res = idist.all_gather(x)
true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
true_res = ["abc"] + [base_x] * (ws - 1)
assert res == true_res

t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
in_dtype = t.dtype
res = idist.all_gather(t)
assert res.shape == (idist.get_world_size() * 4, 25)
assert res.shape == (ws * 4, 25)
assert res.dtype == in_dtype
true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
for i in range(idist.get_world_size()):
true_res = torch.zeros(ws * 4, 25, device=device)
for i in range(ws):
true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
assert (res == true_res).all()

if idist.get_world_size() > 1:
with pytest.raises(TypeError, match=r"Unhandled input type"):
idist.all_reduce([0, 1, 2])
if ws > 1 and idist.backend() != "xla-tpu":
t = {
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
"b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
"c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
}
res = idist.all_gather(t)
assert isinstance(res, list) and len(res) == ws
for i, obj in enumerate(res):
assert isinstance(obj, dict)
assert list(obj.keys()) == ["a", "b", "c"], obj
expected_device = device if device.type == "cpu" else torch.device(f"{device.type}:{i}")
expected = {
"a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
"b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
"c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
}
assert obj["a"] == expected["a"]
assert (obj["b"] == expected["b"]).all()
assert obj["c"] == expected["c"]


def _test_distrib_all_gather_group(device):
