From 78c6840c1c06039f5f18860e8d639ead5d3e086f Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:20:57 -0700
Subject: [PATCH] Support unordered sharding spec for partial replication
 (#5316)

* Support unordered sharding spec for partial replication

* add 4d test

* handle 2d tensor with 2d mesh case

* refactoring
---
 test/spmd/test_xla_sharding.py         | 89 ++++++++++++++++++++++++++
 torch_xla/experimental/xla_sharding.py | 26 ++++----
 2 files changed, 101 insertions(+), 14 deletions(-)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 290c396f62a..0d1ae1c8046 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -301,6 +301,95 @@ def test_mark_sharding_partial(self):
     actual = (xt1 @ t2).cpu()
     self.assertTrue(torch.allclose(expected, actual))
 
+  def test_mark_sharding_not_ordered_partial_3d(self):
+    device = xm.xla_device()
+    t1 = torch.randn(8, 16, 32).to(device)
+    t2 = torch.randn(8, 16, 32).to(device)
+    # Somehow the eager cpu result is different from the xla result.
+    expected = t1 + t2
+    # To re-materialize t1 and t2.
+    xm.mark_step()
+    xm.wait_device_ops()
+    expected = expected.cpu()
+
+    # Shard along two axes if four or more devices are available
+    z_dim = 2 if self.n_devices >= 4 else 1
+    mesh = self._get_mesh((z_dim, 1, self.n_devices // z_dim))
+
+    # Expect local shard size to be [8, 16 / z_dim, 32]
+    xt1 = xs.mark_sharding(t1, mesh, (1, 0, None))
+
+    for local_shard in xt1.local_shards:
+      self.assertEqual(local_shard.data.size()[0], 8)
+      self.assertEqual(local_shard.data.size()[1], 16 / z_dim)
+      self.assertEqual(local_shard.data.size()[2], 32)
+
+    # partial replication requires >1 devices; otherwise, it's replicated.
+    if self.n_devices > 1:
+      # xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
+      self.assertTrue('last_tile_dim_replicate' in
+                      torch_xla._XLAC._get_xla_sharding_spec(t1))
+      self.assertTrue('[%d,%d,1,%d]' %
+                      (1, z_dim, self.n_devices //
+                       z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
+    actual = (xt1 + t2).cpu()
+    self.assertTrue(torch.allclose(expected, actual))
+
+  def test_mark_sharding_not_ordered_partial_4d(self):
+    device = xm.xla_device()
+    t1 = torch.randn(8, 16, 32, 64).to(device)
+    t2 = torch.randn(8, 16, 32, 64).to(device)
+    # Somehow the eager cpu result is different from the xla result.
+    expected = t1 + t2
+    # To re-materialize t1 and t2.
+    xm.mark_step()
+    xm.wait_device_ops()
+    expected = expected.cpu()
+
+    # Shard along two axes if four or more devices are available
+    z_dim = 2 if self.n_devices >= 4 else 1
+    mesh = self._get_mesh((z_dim, 1, 1, self.n_devices // z_dim))
+
+    # Expect local shard size to be [8, 16, 32 / z_dim, 64]
+    xt1 = xs.mark_sharding(t1, mesh, (2, None, 0, None))
+
+    for local_shard in xt1.local_shards:
+      self.assertEqual(local_shard.data.size()[0], 8)
+      self.assertEqual(local_shard.data.size()[1], 16)
+      self.assertEqual(local_shard.data.size()[2], 32 / z_dim)
+      self.assertEqual(local_shard.data.size()[3], 64)
+
+    # partial replication requires >1 devices; otherwise, it's replicated.
+    if self.n_devices > 1:
+      # xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
+      self.assertTrue('last_tile_dim_replicate' in
+                      torch_xla._XLAC._get_xla_sharding_spec(t1))
+      self.assertTrue('[1,1,%d,1,%d]' %
+                      (z_dim,
+                       (self.n_devices //
+                        z_dim)) in torch_xla._XLAC._get_xla_sharding_spec(t1))
+    actual = (xt1 + t2).cpu()
+    self.assertTrue(torch.allclose(expected, actual))
+
+  def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
+    ct1 = torch.randn(16, 16, device='cpu')
+    ct2 = torch.randn(16, 16, device='cpu')
+    expected = ct1 + ct2
+
+    t1 = ct1.to(xm.xla_device())
+    t2 = ct2.to(xm.xla_device())
+    mesh = self._get_mesh((1, self.n_devices, 1))
+    # sharding spec here is not ordered.
+    xt1 = xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
+    if self.n_devices > 1:
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt1.global_tensor])
+      sharding_annotation = 'sharding={devices=[1,1,%d]%s}' % (
+          self.n_devices, ','.join(
+              [str(d) for d in mesh.get_logical_mesh().flatten()]))
+      self.assertIn(sharding_annotation, hlo)
+    actual = (xt1 + t2).cpu()
+    self.assertTrue(torch.allclose(expected, actual))
+
   def test_partial_replication_addmm(self):
     device = xm.xla_device()
     z_dim = 2 if self.n_devices >= 4 else 1
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index ee7e32117b4..6d841a33b8d 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -323,10 +323,6 @@ def _get_tile_assignment(mesh: Mesh,
                          partition_spec: Tuple[Union[int, None]]) -> List[int]:
   # Use Torch.tensor here to make use of the torch.transpose_
   mesh_list_tensor = torch.tensor(mesh.get_logical_mesh().tolist())
-  # This is partial sharding case, tile_assigniment will be ignore in favor of
-  # group_assignment and replication_groups.
-  if (mesh_list_tensor.dim() != len(partition_spec)):
-    return mesh_list_tensor.tolist()
   partition_spec_list = list(partition_spec)
   for i in range(len(partition_spec_list)):
     if partition_spec_list[i] == None:
@@ -339,26 +335,28 @@ def _get_tile_assignment(mesh: Mesh,
   return mesh_list_tensor.permute(partition_spec_list).tolist()
 
 
-def _get_group_assignment(
-    sharding_type: ShardingType, mesh: Mesh,
-    partition_spec: Tuple[Union[int, None]]) -> Tuple[List, List]:
+def _get_group_assignment(sharding_type: ShardingType, mesh: Mesh,
+                          partition_spec: Tuple[Union[int, None]],
+                          tile_assignment: List) -> Tuple[List, List]:
   group_assignment = list()
   replication_groups = list()
+  # TODO(JackCaoG): 3d mesh on 2d tensor
+  mesh_shape_list = list(torch.tensor(tile_assignment).size())
   if sharding_type is ShardingType.PARTIAL:
     # Shard across groups and replicate within subgroups; replicated dims
     # will be used to group replication devices.
     tile_dims = [d for d in partition_spec if d is not None]
-    replicated_dims = set(range(len(mesh.mesh_shape))) - set(tile_dims)
+    replicated_dims = set(range(len(mesh_shape_list))) - set(tile_dims)
 
-    group_list = [np.array(mesh.get_logical_mesh().tolist())]
+    group_list = [np.array(tile_assignment)]
     for d in tile_dims:
       _group_list = list()
       for group_members in group_list:
-        _group_list += np.split(group_members, mesh.mesh_shape[d], d)
+        _group_list += np.split(group_members, mesh_shape_list[d], d)
       group_list = _group_list
     replication_groups = [group.flatten().tolist() for group in group_list]
 
-    group_tile_shape = list(mesh.mesh_shape)
+    group_tile_shape = mesh_shape_list
     for d in replicated_dims:
       group_tile_shape[d] = 1
     group_assignment = np.arange(len(replication_groups)).reshape(
@@ -415,7 +413,6 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   assert len(specs) == len(np.unique(specs)), \
     f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
 
-  tile_assignment = _get_tile_assignment(mesh, partition_spec)
   # check for sharding 2D tensor on a 3D mesh
   original_shape = tuple(t.shape)
   # number of dims to expand on tensor
@@ -426,9 +423,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     shape = (1,) * tensor_expand + (*original_shape,)
     t = t.expand(shape)
 
+  tile_assignment = _get_tile_assignment(mesh, partition_spec)
   sharding_type = _get_sharding_type(partition_spec, num_devices)
   group_assignment, replication_groups = _get_group_assignment(
-      sharding_type, mesh, partition_spec)
+      sharding_type, mesh, partition_spec, tile_assignment)
 
   def tensor_squeeze(t, tensor_expand):
     if tensor_expand:
@@ -484,7 +482,7 @@ def __post_init__(self):
     self._sharding_type = _get_sharding_type(partition_spec,
                                              xr.global_device_count())
     self._group_assignment, self._replication_groups = _get_group_assignment(
-        self._sharding_type, mesh, partition_spec)
+        self._sharding_type, mesh, partition_spec, self._tile_assignment)
 
   def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
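
Usage note (not part of the patch): a minimal sketch of the behavior these
changes enable, adapted from test_mark_sharding_not_ordered_partial_3d above.
The device ids, mesh shape, and tensor size are illustrative assumptions; the
API calls (xs.Mesh, xs.mark_sharding, torch_xla._XLAC._get_xla_sharding_spec)
are the ones exercised by the new tests.

    # Sketch only: assumes a multi-device host; values below are illustrative.
    import numpy as np
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs

    num_devices = xr.global_device_count()
    z_dim = 2 if num_devices >= 4 else 1

    # 3D logical mesh; the mesh axes named by the partition spec below are not
    # in ascending order, and the unused axis becomes the replication dimension.
    mesh = xs.Mesh(np.arange(num_devices), (z_dim, 1, num_devices // z_dim))

    t = torch.randn(8, 16, 32).to(xm.xla_device())

    # Unordered partition spec: tensor dim 0 -> mesh axis 1, dim 1 -> mesh
    # axis 0, dim 2 replicated.
    xt = xs.mark_sharding(t, mesh, (1, 0, None))

    # With four or more devices this is a PARTIAL sharding; the spec string
    # contains a tile shape like [1,z_dim,1,num_devices//z_dim] together with
    # 'last_tile_dim_replicate', matching the assertions in the test above.
    print(torch_xla._XLAC._get_xla_sharding_spec(t))

The underlying change is that mark_sharding now computes the tile assignment by
permuting the logical mesh with the (possibly unordered) partition spec and
passes it to _get_group_assignment, which derives the replication groups from
that permuted assignment instead of from the raw mesh shape.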