[dtensor] refactor get_coordinate (pytorch#95457)
This refactors get_coordinate to return an Optional[List[int]] instead of
returning the coordinate on a single dim directly, so that we can easily
check whether the rank is inside the mesh.

Differential Revision: [D43643579](https://our.internmc.facebook.com/intern/diff/D43643579)
Pull Request resolved: pytorch#95457
Approved by: https://github.com/XilunWu
wanchaol authored and pytorchmergebot committed Feb 28, 2023
1 parent bb9a05b commit 261eb46
Showing 6 changed files with 26 additions and 26 deletions.
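Before the per-file diffs, here is a minimal self-contained sketch of the call pattern this refactor introduces. ExampleMesh is a hypothetical stand-in for DeviceMesh used only for illustration; it is not the real class.

from typing import List, Optional


class ExampleMesh:
    """Hypothetical stand-in for DeviceMesh, keeping only the coordinate logic."""

    def __init__(self, coordinate: Optional[List[int]]) -> None:
        # One index per mesh dimension, or None if this rank is not in the mesh.
        self._coordinate_on_dim = coordinate

    def get_coordinate(self) -> Optional[List[int]]:
        # New API: return the full coordinate list (or None), instead of the old
        # get_coordinate_on_dim(dim), which returned a single Optional[int].
        return self._coordinate_on_dim if self._coordinate_on_dim else None


mesh_inside = ExampleMesh([1, 0])   # a rank that belongs to a 2-D mesh
mesh_outside = ExampleMesh(None)    # a rank that is not part of the mesh

for mesh in (mesh_inside, mesh_outside):
    coord = mesh.get_coordinate()
    if coord is None:
        print("rank is not part of the mesh")
    else:
        # the old per-dim lookup becomes plain list indexing
        print("coordinate on mesh dim 0:", coord[0])

The call-site change throughout the diff below is then mechanical: mesh.get_coordinate_on_dim(dim) becomes mesh.get_coordinate()[dim], and a single None check on the returned list covers the rank-not-in-mesh case for all dimensions at once.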
2 changes: 1 addition & 1 deletion test/distributed/_spmd/test_tracing.py
@@ -113,7 +113,7 @@ def fn(to_receive: torch.Tensor, to_scatter: List[torch.Tensor]):
# use a local_tensor + 1 for tracing to make sure that we are not
# simply replaying recorded tensor value
to_receive = torch.empty_like(
- scattered_tensors[mesh.get_coordinate_on_dim(dim)]
+ scattered_tensors[mesh.get_coordinate()[dim]]
)
traced_fn = make_fx(fn)(to_receive, [t + 1 for t in scattered_tensors])

6 changes: 3 additions & 3 deletions test/distributed/_tensor/test_device_mesh.py
@@ -460,7 +460,7 @@ def test_reduce_scatter_nd(self):
contiguous=True,
)
scattered_tensor = torch.empty_like(
- local_rs_list[mesh.get_coordinate_on_dim(dim)],
+ local_rs_list[mesh.get_coordinate()[dim]],
device=self.device_type,
)
global_ranks = [
@@ -523,7 +523,7 @@ def test_scatter_nd(self):
for global_rank in global_ranks
]
received_tensor = torch.empty_like(
- scattered_tensors[mesh.get_coordinate_on_dim(dim)]
+ scattered_tensors[mesh.get_coordinate()[dim]]
)
mesh.scatter(received_tensor, scattered_tensors, mesh_dim=dim)
self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)
@@ -563,7 +563,7 @@ def test_all_to_all_nd(self):
# check all dim groups
dim_to_subgroups = mesh.get_dim_groups()
for dim, dim_group in enumerate(dim_to_subgroups):
- my_coordinate = mesh.get_coordinate_on_dim(dim)
+ my_coordinate = mesh.get_coordinate()[dim]
dim_group_size = get_world_size(dim_group)
global_ranks = [
get_global_rank(dim_group, i) for i in range(dim_group_size)
8 changes: 4 additions & 4 deletions torch/distributed/_tensor/device_mesh.py
@@ -291,12 +291,12 @@ def backend(self) -> str:
def get_rank(self) -> int:
return get_rank()

- def get_coordinate_on_dim(self, dim: int) -> Optional[int]:
+ def get_coordinate(self) -> Optional[List[int]]:
"""
Return the relative index of this rank relative to a given
dimension of the mesh. If this rank is not part of the mesh, return None.
"""
- return self._coordinate_on_dim[dim] if self._coordinate_on_dim else None
+ return self._coordinate_on_dim if self._coordinate_on_dim else None

def scatter(
self,
@@ -473,7 +473,7 @@ def reduce_scatter(
warnings.warn(
"ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!"
)
- my_coordinate = self.get_coordinate_on_dim(mesh_dim)
+ my_coordinate = self.get_coordinate()
# TODO: what should happen if rank is not in the mesh?
# see issue https://github.com/pytorch/tau/pull/492
assert (
@@ -497,7 +497,7 @@ def reduce_scatter(
flat_tensor, op=op, mesh_dim=mesh_dim, async_op=async_op
)
# scatter the tensor
- output_offset = offset_list[my_coordinate]
+ output_offset = offset_list[my_coordinate[mesh_dim]]
output.copy_(
flat_tensor[output_offset : output_offset + output.numel()].view(
output.shape
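In the Gloo fallback above, the coordinate list now selects this rank's slice out of the flattened, reduced buffer. Below is a single-process sketch of just that slicing step; the offset_list construction here is a simplifying assumption (the real code builds offsets while padding and splitting), and the all_reduce across the mesh dimension is omitted.

from typing import List

import torch

# Pretend these are the already padded, equally shaped per-rank chunks.
chunks: List[torch.Tensor] = [torch.full((2, 2), float(i)) for i in range(4)]

# Cumulative offset of each chunk inside the flattened buffer (assumed layout).
offset_list: List[int] = []
running = 0
for c in chunks:
    offset_list.append(running)
    running += c.numel()

flat_tensor = torch.cat([c.flatten() for c in chunks])
# The real fallback all-reduces flat_tensor across the mesh dimension here;
# locally we just keep it as-is.

my_coordinate = [0, 2]   # assumed full coordinate, one entry per mesh dim
mesh_dim = 1
output = torch.empty_like(chunks[my_coordinate[mesh_dim]])

output_offset = offset_list[my_coordinate[mesh_dim]]
output.copy_(
    flat_tensor[output_offset : output_offset + output.numel()].view(output.shape)
)
print(output)  # the 2x2 chunk belonging to coordinate 2 on mesh dim 1 (all 2.0)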
24 changes: 12 additions & 12 deletions torch/distributed/_tensor/placement_types.py
@@ -111,7 +111,7 @@ def _shard_tensor(
shard and scatter a tensor on a mesh dimension (use coordinate
0 on the mesh dimension as source of truth)
"""
- my_coordinate = mesh.get_coordinate_on_dim(mesh_dim)
+ my_coordinate = mesh.get_coordinate()
num_chunks = mesh.size(dim=mesh_dim)
# TODO: what should happen if rank is not in the mesh?
# see issue https://github.com/pytorch/tau/pull/492
@@ -121,10 +121,10 @@
scatter_list, pad_idx = self._split_tensor(
tensor, num_chunks, with_padding=True, contiguous=True
)
- output = torch.empty_like(scatter_list[my_coordinate])
+ output = torch.empty_like(scatter_list[my_coordinate[mesh_dim]])
mesh.scatter(output, scatter_list, mesh_dim=mesh_dim)

- if pad_idx != 0 and my_coordinate >= pad_idx:
+ if pad_idx != 0 and my_coordinate[mesh_dim] >= pad_idx:
output = self._unpad_tensor(output)
return output

@@ -138,7 +138,7 @@ def _reduce_shard_tensor(
"""
reduce and scatter a tensor on a mesh dimension
"""
- my_coordinate = mesh.get_coordinate_on_dim(mesh_dim)
+ my_coordinate = mesh.get_coordinate()
num_chunks = mesh.size(dim=mesh_dim)
# TODO: what should happen if rank is not in the mesh?
# see issue https://github.com/pytorch/tau/pull/492
@@ -150,14 +150,14 @@
)
# wrap with comm tensor
scattered_list = [CommTensor(t) for t in scattered_list]
- output = torch.empty_like(scattered_list[my_coordinate])
+ output = torch.empty_like(scattered_list[my_coordinate[mesh_dim]])
mesh.reduce_scatter(
CommTensor(output),
scattered_list, # pyre-ignore[6]
op=reduce_op,
mesh_dim=mesh_dim,
)
- if pad_idx != 0 and my_coordinate >= pad_idx:
+ if pad_idx != 0 and my_coordinate[mesh_dim] >= pad_idx:
output = self._unpad_tensor(output)
return output

@@ -172,7 +172,7 @@ def _to_replicate_tensor(
This function all_gather all shards and return a tensor that
is replicated on the previously sharded mesh dimension
"""
- my_coordinate = mesh.get_coordinate_on_dim(mesh_dim)
+ my_coordinate = mesh.get_coordinate()
num_chunks = mesh.size(dim=mesh_dim)
# TODO: what should happen if rank is not in the mesh?
# see issue https://github.com/pytorch/tau/pull/492
@@ -181,7 +181,7 @@
), "Rank if not part of mesh" # TODO: figure out behavior here
# check if it needs to pad input tensor before all_gather
pad_idx = size[self.dim] % num_chunks
- if pad_idx != 0 and my_coordinate >= pad_idx:
+ if pad_idx != 0 and my_coordinate[mesh_dim] >= pad_idx:
local_tensor = self._pad_tensor(local_tensor).contiguous()

gathered_list = []
@@ -377,15 +377,15 @@ def _local_shape_from_global_shape(
ndim = len(global_shape)
for idx, placement in enumerate(self.placements):
mesh_dim_size = self.mesh.size(idx)
- my_coordinate = self.mesh.get_coordinate_on_dim(idx)
+ my_coordinate = self.mesh.get_coordinate()
assert my_coordinate is not None, "Rank not part of mesh!"
if isinstance(placement, Shard):
shard_dim = placement.dim
assert (
shard_dim < ndim
), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}"
local_shard_size, _ = placement._local_shard_size_on_dim(
- local_shape[shard_dim], mesh_dim_size, my_coordinate
+ local_shape[shard_dim], mesh_dim_size, my_coordinate[idx]
)
assert isinstance(local_shard_size, int)
local_shape[shard_dim] = local_shard_size
@@ -414,7 +414,7 @@ def local_offsets(self) -> Tuple[int, ...]:

for idx, placement in enumerate(self.placements):
mesh_dim_size = self.mesh.size(idx)
- my_coordinate = self.mesh.get_coordinate_on_dim(idx)
+ my_coordinate = self.mesh.get_coordinate()
assert my_coordinate is not None, "Rank not part of mesh!"
if isinstance(placement, Shard):
shard_dim = placement.dim
@@ -424,7 +424,7 @@ def local_offsets(self) -> Tuple[int, ...]:
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
- my_coordinate,
+ my_coordinate[idx],
return_offset=True,
)
local_shape[shard_dim] = shard_size
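The recurring check pad_idx != 0 and my_coordinate[mesh_dim] >= pad_idx above decides whether this rank's chunk carries padding that must be stripped afterwards. The sketch below is a simplified, self-contained model of one padding scheme consistent with that check (the first size % num_chunks chunks are one element longer and the rest are padded up by one); it is not the library's _split_tensor or _unpad_tensor.

from typing import List, Tuple

import torch


def split_with_padding(
    tensor: torch.Tensor, num_chunks: int, dim: int = 0
) -> Tuple[List[torch.Tensor], int]:
    """Split tensor along dim so the first (size % num_chunks) chunks are one
    element longer; pad the remaining chunks by one element so all sizes match.
    Returns the padded chunks and pad_idx, the index of the first padded chunk."""
    size = tensor.size(dim)
    pad_idx = size % num_chunks
    base = size // num_chunks
    sizes = [base + 1 if i < pad_idx else base for i in range(num_chunks)]
    chunks = list(torch.split(tensor, sizes, dim=dim))
    if pad_idx != 0:
        pad_shape = list(tensor.shape)
        pad_shape[dim] = 1
        pad = tensor.new_zeros(pad_shape)
        # pad the short chunks so every chunk has the same (padded) size
        chunks = [
            c if i < pad_idx else torch.cat([c, pad], dim=dim)
            for i, c in enumerate(chunks)
        ]
    return chunks, pad_idx


chunks, pad_idx = split_with_padding(torch.arange(10), num_chunks=4)
my_coordinate = [3]       # pretend this rank sits at index 3 on mesh dim 0
mesh_dim = 0
output = chunks[my_coordinate[mesh_dim]].clone()
if pad_idx != 0 and my_coordinate[mesh_dim] >= pad_idx:
    output = output[:-1]  # unpad: drop the trailing padded element
print(output)             # tensor([8, 9]) for this configuration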
6 changes: 3 additions & 3 deletions torch/distributed/_tensor/redistribute.py
@@ -86,7 +86,7 @@ def _redistribute_with_local_tensor(
sorted_placements.sort(key=_replicate_then_shard)

for i, (current, target) in sorted_placements:
- my_coordinate = device_mesh.get_coordinate_on_dim(i)
+ my_coordinate = device_mesh.get_coordinate()
num_chunks = device_mesh.size(dim=i)
# TODO: what should happen if rank is not in the mesh?
# see issue https://github.com/pytorch/tau/pull/492
@@ -131,7 +131,7 @@ def _redistribute_with_local_tensor(
with_padding=False,
contiguous=False,
)
- new_local_tensor = shards[my_coordinate].clone()
+ new_local_tensor = shards[my_coordinate[i]].clone()
else:
# NOTE: this case shouldn't hit _decompose_sharding, decompose sharding should
# decompose Shard(0) -> Shard(1) into Shard(0) -> Replicate -> Shard(1)
@@ -149,7 +149,7 @@
if current.is_replicate():
# For replicate -> partial, we zero out all other ranks of the current mesh dim
# and leave only 1 rank have the data, to perform a "zero cost" reshard.
- if my_coordinate is not None and my_coordinate != 0:
+ if my_coordinate[i] != 0:
new_local_tensor = local_tensor.zero_()
else:
new_local_tensor = local_tensor
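The last hunk above implements replicate -> partial by keeping data only on coordinate 0 of the mesh dimension and zeroing it elsewhere, so the elementwise sum of the partials across that dimension still equals the replicated value. A tiny local illustration, with made-up coordinate and mesh-dim values:

import torch

my_coordinate = [1, 0]    # assumed full coordinate for this rank
i = 0                     # mesh dim being turned from replicate into partial
local_tensor = torch.ones(3, 3)

if my_coordinate[i] != 0:
    # every rank except coordinate 0 zeroes its replica, so the elementwise sum
    # across the mesh dimension still equals the original replicated tensor
    new_local_tensor = local_tensor.zero_()
else:
    new_local_tensor = local_tensor

print(new_local_tensor)   # all zeros here, since my_coordinate[0] == 1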
6 changes: 3 additions & 3 deletions torch/distributed/tensor/parallel/fsdp.py
@@ -122,9 +122,9 @@ def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch

def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
- dim_0_coord = device_mesh.get_coordinate_on_dim(0)
- assert dim_0_coord is not None
- return _get_box_for(tensor, dim_0_coord)
+ coord = device_mesh.get_coordinate()
+ assert coord is not None
+ return _get_box_for(tensor, coord[0])


def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardMetadata:
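In the FSDP helper above, coord[0] now picks this rank's index on mesh dim 0 before looking up its shard box. The helper below is a rough, illustrative reconstruction of what a dim-0 box (offset and size) could look like for an evenly sharded tensor; it is an assumption, not the actual _get_box_for.

from typing import Tuple

import torch


def box_for_dim0_shard(
    global_shape: torch.Size, num_shards: int, idx: int
) -> Tuple[torch.Size, torch.Size]:
    """Return (offset, size) of shard idx for an even split along dim 0."""
    shard_len = global_shape[0] // num_shards
    offset = torch.Size([idx * shard_len] + [0] * (len(global_shape) - 1))
    size = torch.Size([shard_len] + list(global_shape[1:]))
    return offset, size


coord = [2]  # full mesh coordinate; the dim-0 entry selects this rank's shard
print(box_for_dim0_shard(torch.Size([8, 4]), num_shards=4, idx=coord[0]))
# prints (torch.Size([4, 0]), torch.Size([2, 4]))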
