Skip to content

Commit

Permalink
fix BasicShardingTest.test_2d_tensor_3d_mesh on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Feb 2, 2024
1 parent 43ef193 commit 960d609
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,10 +728,14 @@ def test_2d_tensor_3d_mesh(self):
mesh = self._get_mesh((2, self.n_devices // 2, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
sharding_annotation = 'sharding={devices=[1,%d,2]' % (self.n_devices // 2)
else:
elif self.n_devices == 2:
mesh = self._get_mesh((2, 1, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
sharding_annotation = "sharding={replicated}"
else:
mesh = self._get_mesh((1, 1, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
sharding_annotation = "sharding={replicated}"
self.assertIn(sharding_annotation,
torch_xla._XLAC._get_xla_tensors_hlo([t1]))
actual = (t1 + t2).cpu()
Expand Down

0 comments on commit 960d609

Please sign in to comment.