Added support for all_gather object #3047

Merged: 6 commits, Aug 31, 2023
Changes from 2 commits
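In short, this PR lets `idist.all_gather` accept arbitrary picklable Python objects in addition to tensors, numbers, and strings. The snippet below is a minimal usage sketch, assuming a distributed configuration has already been set up (for example via `idist.initialize` or an `idist.Parallel` context); it mirrors the new test case at the bottom of this diff.

```python
import torch
import ignite.distributed as idist

# Assumes a distributed configuration is already initialized,
# e.g. via idist.initialize("nccl") or an idist.Parallel context.
rank = idist.get_rank()

# Tensors, numbers and strings keep the existing tensor-based collective path.
gathered_ranks = idist.all_gather(torch.tensor(rank, device=idist.device()))

# Any other picklable object is now gathered via the new object-based path
# instead of raising "Unhandled input type".
payload = {"rank": rank, "data": torch.tensor([rank + 1, rank + 2])}
gathered_payloads = idist.all_gather(payload)  # list with one entry per process
```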
13 changes: 10 additions & 3 deletions ignite/distributed/comp_models/base.py
@@ -181,7 +181,7 @@
return tensor

def _collective_op(
self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, float, List[float], List[str]]:
tensor_to_number = tensor_to_str = False
device = self.device()
@@ -216,10 +216,10 @@
return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))

def all_gather(
self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
) -> Union[torch.Tensor, float, List[float], List[str]]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
raise TypeError(f"Unhandled input type {type(tensor)}")
return self._do_all_gather_object(tensor, group=group)

return self._collective_op(tensor, self._do_all_gather, group=group)

@@ -282,6 +282,10 @@
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
pass

@abstractmethod
def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
pass

@abstractmethod
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass
@@ -373,6 +377,9 @@
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
return tensor

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Any:
return tensor


def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return ranks
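Note on the serial (no distributed configuration) model above: judging from the `return tensor` fallback, a gathered object appears to come back unchanged rather than wrapped in a one-element list. A minimal sketch of that assumption for a single-process run:

```python
import ignite.distributed as idist

# Single process, no distributed backend initialized: the serial model is used
# and the object appears to be returned as-is (illustrative assumption based on
# the `return tensor` fallback above, not a documented guarantee).
obj = {"a": 1}
assert idist.all_gather(obj) == obj
```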

6 changes: 6 additions & 0 deletions ignite/distributed/comp_models/horovod.py
@@ -192,6 +192,12 @@
tensor = tensor.unsqueeze(0)
return hvd.allgather(tensor)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if group is not None:
raise NotImplementedError("all_gather with group for horovod is not implemented")


return hvd.allgather_object(tensor)

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return hvd.ProcessSet(ranks)
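For reference, the Horovod path delegates to `hvd.allgather_object`, which gathers one picklable object per worker and returns them as a list ordered by rank. A standalone sketch outside of ignite, assuming a working Horovod installation, could look like this:

```python
import horovod.torch as hvd

hvd.init()

# Each worker contributes one picklable object; allgather_object returns a
# list with one entry per worker, ordered by rank.
payload = {"rank": hvd.rank(), "msg": f"hello from rank {hvd.rank()}"}
gathered = hvd.allgather_object(payload)
assert len(gathered) == hvd.size()
```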

15 changes: 15 additions & 0 deletions ignite/distributed/comp_models/native.py
@@ -423,6 +423,7 @@
if group is not None and not isinstance(group, dist.ProcessGroup):
raise ValueError("Argument group should be list of int or ProcessGroup")
reduce_op = self._reduce_op_map[op]
# we do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_reduce(tensor, reduce_op, group=group)
else:
@@ -441,12 +442,26 @@
if tensor.ndimension() == 0:
tensor = tensor.unsqueeze(0)
output = [torch.zeros_like(tensor) for _ in range(group_size)]
# we do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_gather(output, tensor, group=group)
else:
dist.all_gather(output, tensor)
return torch.cat(output, dim=0)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if group == dist.GroupMember.NON_GROUP_MEMBER:
return tensor

elif group is None:
group_size = self.get_world_size()
elif isinstance(group, dist.ProcessGroup):
group_size = group.size()

else:
raise ValueError("Argument group should be list of int or ProcessGroup")

output = [None for _ in range(group_size)]
dist.all_gather_object(output, tensor, group=group)
return output

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return dist.new_group(ranks=ranks, **kwargs)
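The native implementation builds on `torch.distributed.all_gather_object`, which fills a pre-allocated list with one entry per rank. A minimal standalone sketch of that pattern, assuming the default process group has already been initialized:

```python
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
world_size = dist.get_world_size()
payload = {"rank": dist.get_rank()}

# all_gather_object pickles `payload` on every rank and fills `output`
# so that output[i] holds the object contributed by rank i.
output = [None for _ in range(world_size)]
dist.all_gather_object(output, payload)
```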

3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/xla.py
@@ -155,6 +155,9 @@
xm.all_reduce("sum", [output], groups=group)
return output.reshape(-1, *output.shape[2:])

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
raise NotImplementedError("all_gather on object is not implemented for xla")


def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return [ranks]

40 changes: 30 additions & 10 deletions tests/ignite/distributed/utils/__init__.py
@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):

def _test_distrib_all_gather(device):
rank = idist.get_rank()
ws = idist.get_world_size()

res = torch.tensor(idist.all_gather(10), device=device)
true_res = torch.tensor([10] * idist.get_world_size(), device=device)
true_res = torch.tensor([10] * ws, device=device)
assert (res == true_res).all()

t = torch.tensor(rank, device=device)
res = idist.all_gather(t)
true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
true_res = torch.tensor([i for i in range(ws)], device=device)
assert (res == true_res).all()

x = "test-test"
if rank == 0:
x = "abc"
res = idist.all_gather(x)
true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
true_res = ["abc"] + ["test-test"] * (ws - 1)
assert res == true_res

base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
@@ -179,22 +180,41 @@ def _test_distrib_all_gather(device):
x = "abc"

res = idist.all_gather(x)
true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
true_res = ["abc"] + [base_x] * (ws - 1)
assert res == true_res

t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
in_dtype = t.dtype
res = idist.all_gather(t)
assert res.shape == (idist.get_world_size() * 4, 25)
assert res.shape == (ws * 4, 25)
assert res.dtype == in_dtype
true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
for i in range(idist.get_world_size()):
true_res = torch.zeros(ws * 4, 25, device=device)
for i in range(ws):
true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
assert (res == true_res).all()

if idist.get_world_size() > 1:
with pytest.raises(TypeError, match=r"Unhandled input type"):
idist.all_reduce([0, 1, 2])
if ws > 1 and idist.backend() != "xla-tpu":
t = {
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
"b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
"c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
}
res = idist.all_gather(t)
assert isinstance(res, list) and len(res) == ws
for i, obj in enumerate(res):
assert isinstance(obj, dict)
assert list(obj.keys()) == ["a", "b", "c"], obj
expected_device = (
device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
)
expected = {
"a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
"b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
"c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
}
assert obj["a"] == expected["a"]
assert (obj["b"] == expected["b"]).all()
assert obj["c"] == expected["c"]


def _test_distrib_all_gather_group(device):