From 8875226d62a7f1164bc0f1d3e08df84e99014f45 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Thu, 8 Aug 2024 13:26:30 -0700
Subject: [PATCH] [dtensor] multi-dim mesh redistribute follow up (#133023)

Follow-up to https://github.com/pytorch/pytorch/pull/131210: only record a
transform step when the source and target placements on a mesh dim actually
differ. Also adds a test case from the user report in
https://github.com/pytorch/pytorch/issues/132751.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133023
Approved by: https://github.com/tianyu-l
ghstack dependencies: #133022
---
 test/distributed/_tensor/test_redistribute.py |  2 ++
 torch/distributed/_tensor/_redistribute.py    | 15 ++++++++-------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py
index 2648c400c97b5..5634869155052 100644
--- a/test/distributed/_tensor/test_redistribute.py
+++ b/test/distributed/_tensor/test_redistribute.py
@@ -510,6 +510,7 @@ def test_redistribute_shard_dim_multi_dim_mesh(self):
             ([Shard(0), Shard(1), Shard(2)], [Shard(2), Shard(1), Shard(0)]),
             ([Shard(1), Shard(0), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
             ([Shard(1), Replicate(), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
+            ([Shard(0), Shard(0), Shard(1)], [Shard(0), Shard(1), Shard(2)]),
         ]
         comm_counts_3d = [
             3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1
@@ -517,6 +518,7 @@ def test_redistribute_shard_dim_multi_dim_mesh(self):
             2,  # 2: S2 -> R, 0: S1 -> S2
             1,  # 0: S1 -> R
             2,  # 2: S0 -> R, 1: R -> S0, 2: R -> S0, 0: S1 -> R
+            2,  # 2: S1 -> S2, 1: S0 -> S1
         ]
 
         comm_mode = CommDebugMode()
diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/_tensor/_redistribute.py
index 15d51dbe82928..127bf3e9857d7 100644
--- a/torch/distributed/_tensor/_redistribute.py
+++ b/torch/distributed/_tensor/_redistribute.py
@@ -118,14 +118,15 @@ def _gen_transform_infos(
                     # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
                     target = Replicate()
 
-            transform_infos.append(
-                _TransformInfo(
-                    mesh_dim=mesh_dim,
-                    src_dst_placements=(current, target),
-                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
+            if current != target:
+                transform_infos.append(
+                    _TransformInfo(
+                        mesh_dim=mesh_dim,
+                        src_dst_placements=(current, target),
+                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
+                    )
                 )
-            )
-            current_placements[mesh_dim] = target
+                current_placements[mesh_dim] = target
 
     # We always traverse from outer placement to inner placement to collect the remaining
     # needed transform infos (i.e. the replication from nested sharding might need to further
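
A minimal standalone sketch (not part of the patch) of the newly added test case, for checking the communication count outside the test suite. The launch command (torchrun --nproc-per-node=8 on an 8-GPU host), the 2x2x2 mesh shape, tensor size, device type, and file name are illustrative assumptions; Shard, Replicate, distribute_tensor, DTensor.redistribute, and CommDebugMode are the same DTensor APIs the test file exercises. With the "if current != target" guard above, mesh dim 0 (Shard(0) -> Shard(0)) no longer produces a transform step, so only two collectives are expected, matching the new comm_counts_3d entry.

# standalone_redistribute_repro.py (hypothetical file name)
# Launch (assumption): torchrun --nproc-per-node=8 standalone_redistribute_repro.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh


def main() -> None:
    # 2 x 2 x 2 device mesh over 8 ranks; "cuda" is an assumption, swap in
    # "cpu" (with a gloo process group) if no GPUs are available.
    mesh = init_device_mesh("cuda", (2, 2, 2))

    # Global tensor sized so every sharding below is even.
    global_tensor = torch.randn(16, 16, 16)

    # The new test pair: tensor dim 0 is nested-sharded over mesh dims 0 and 1.
    src_placements = [Shard(0), Shard(0), Shard(1)]
    dst_placements = [Shard(0), Shard(1), Shard(2)]

    dtensor = distribute_tensor(global_tensor, mesh, src_placements)

    # Count collectives issued by the redistribute only.
    comm_mode = CommDebugMode()
    with comm_mode:
        out = dtensor.redistribute(mesh, dst_placements)

    # Expected: 2 collectives (mesh dim 2: S1 -> S2, mesh dim 1: S0 -> S1);
    # mesh dim 0 stays Shard(0) and is skipped as a no-op after this patch.
    if dist.get_rank() == 0:
        print("total comm count:", comm_mode.get_total_counts())

    assert out.placements == tuple(dst_placements)


if __name__ == "__main__":
    main()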