[dtensor] multi-dim mesh redistribute follow up (pytorch#133023)
Follow-up to pytorch#131210; also adds a test case from the user report in pytorch#132751.

Pull Request resolved: pytorch#133023
Approved by: https://github.com/tianyu-l
ghstack dependencies: pytorch#133022
wanchaol authored and pytorchmergebot committed Aug 9, 2024
1 parent 3b7edc1 commit 8875226
Showing 2 changed files with 10 additions and 7 deletions.
2 changes: 2 additions & 0 deletions test/distributed/_tensor/test_redistribute.py
@@ -510,13 +510,15 @@ def test_redistribute_shard_dim_multi_dim_mesh(self):
            ([Shard(0), Shard(1), Shard(2)], [Shard(2), Shard(1), Shard(0)]),
            ([Shard(1), Shard(0), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
            ([Shard(1), Replicate(), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
+            ([Shard(0), Shard(0), Shard(1)], [Shard(0), Shard(1), Shard(2)]),
        ]
        comm_counts_3d = [
            3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1
            3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1, 1: R -> S0, 2: R -> S0
            2,  # 2: S2 -> R, 0: S1 -> S2
            1,  # 0: S1 -> R
            2,  # 2: S0 -> R, 1: R -> S0, 2: R -> S0, 0: S1 -> R
+            2,  # 2: S1 -> S2, 1: S0 -> S1
        ]

        comm_mode = CommDebugMode()
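To see what the new src/dst pair exercises, here is a minimal sketch (not the exact test body) of driving redistribute on a 3-D mesh under CommDebugMode. It assumes an already-initialized 8-rank process group, a 2x2x2 "cpu" mesh, and a hypothetical (8, 8, 8) global tensor; the imports follow the private torch.distributed._tensor API as of this commit.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode

# 3-D mesh over 8 ranks (assumed world size), matching the test's mesh shape.
mesh = init_device_mesh("cpu", (2, 2, 2))

src = [Shard(0), Shard(0), Shard(1)]  # nested sharding on tensor dim 0
dst = [Shard(0), Shard(1), Shard(2)]

x = torch.randn(8, 8, 8)  # hypothetical global shape, evenly divisible
dt = distribute_tensor(x, mesh, src)

comm_mode = CommDebugMode()
with comm_mode:
    out = dt.redistribute(mesh, dst)

# The new comm_counts_3d entry expects 2 collectives for this pair:
# mesh dim 2 (S1 -> S2) and mesh dim 1 (S0 -> S1); mesh dim 0 keeps Shard(0).
assert comm_mode.get_total_counts() == 2
assert out.full_tensor().equal(x)  # redistribution preserves the global tensor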
15 changes: 8 additions & 7 deletions torch/distributed/_tensor/_redistribute.py
Original file line number Diff line number Diff line change
@@ -118,14 +118,15 @@ def _gen_transform_infos(
                    # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
                    target = Replicate()

-            transform_infos.append(
-                _TransformInfo(
-                    mesh_dim=mesh_dim,
-                    src_dst_placements=(current, target),
-                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
+            if current != target:
+                transform_infos.append(
+                    _TransformInfo(
+                        mesh_dim=mesh_dim,
+                        src_dst_placements=(current, target),
+                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
+                    )
                )
-            )
-            current_placements[mesh_dim] = target
+                current_placements[mesh_dim] = target

    # We always traverse from outer placement to inner placement to collect the remaining
    # needed transform infos (i.e. the replication from nested sharding might need to further
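The guard added above means a transform step is only recorded when the placement on a mesh dim actually changes, so a dim that is already at its target (e.g. already Replicate) no longer yields a no-op Replicate -> Replicate step and the redundant collective it would later trigger. A simplified, self-contained sketch of that idea (hypothetical helper name and string placements, not the real _gen_transform_infos):

from dataclasses import dataclass

@dataclass
class TransformStep:
    mesh_dim: int
    src_dst: tuple  # (current placement, target placement)

def plan_replication(current_placements: list) -> list:
    """Replicate each mesh dim to clear nested sharding, skipping no-op dims."""
    steps = []
    target = "R"  # replicate, mirroring the `target = Replicate()` above
    for mesh_dim, current in enumerate(current_placements):
        if current != target:  # the new guard: skip dims already at the target
            steps.append(TransformStep(mesh_dim, (current, target)))
            current_placements[mesh_dim] = target
    return steps

# The already-replicated middle dim contributes no step, hence no collective.
print(plan_replication(["S(0)", "R", "S(1)"]))
# -> [TransformStep(mesh_dim=0, src_dst=('S(0)', 'R')),
#     TransformStep(mesh_dim=2, src_dst=('S(1)', 'R'))]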
