[SPMD] Fix reduce_scatter for TPUv2 (#7248)
Summary:
Fix the SPMD reduce_scatter tests for TPUv2: skip them on TPU versions below v4, and derive the expected shard shape from self.n_devices instead of hardcoding it for a four-device mesh.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_spmd_reduce_scatter
alanwaketan authored Jun 14, 2024
1 parent c216d26 commit 0025ca7
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions test/spmd/test_xla_sharding.py
@@ -1213,7 +1213,8 @@ def test_manual_sharding_api_e2e(self):
     self.assertEqual(xxx.shape, (8, 8))
     self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))

-  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "Only runs on TPUv4")
   def test_spmd_reduce_scatter(self):
     xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
     x = torch.ones(8, 8).to(xm.xla_device())
@@ -1230,10 +1231,11 @@ def test_spmd_reduce_scatter(self):
         f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
         hlo)

-    expected_x = torch.ones(2, 8) * 4
+    expected_x = torch.ones(8 // self.n_devices, 8) * 4
     self.assertTrue(torch.allclose(x.cpu(), expected_x))

-  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "Only runs on TPUv4")
   def test_spmd_reduce_scatter_canonical_index(self):
     xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
     x = torch.ones(8, 8).to(xm.xla_device())
@@ -1250,7 +1252,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
         f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
         hlo)

-    expected_x = torch.ones(8, 2) * 4
+    expected_x = torch.ones(8, 8 // self.n_devices) * 4
     self.assertTrue(torch.allclose(x.cpu(), expected_x))


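For context outside the diff, here is a minimal, self-contained sketch of the version gate added above, assuming the test module's existing imports (torch_xla.runtime as xr and the internal tpu helper); the class name is hypothetical, not from the commit:

import unittest

import torch_xla.runtime as xr
from torch_xla._internal import tpu  # assumed source of tpu.version()


class GuardSketch(unittest.TestCase):  # hypothetical class name

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "Only runs on TPUv4")
  def test_spmd_reduce_scatter(self):
    # TPUv2/v3 (version < 4) skip here, which is how this commit avoids
    # the reduce_scatter failure on TPUv2.
    pass

Note the guard short-circuits: on non-TPU devices the first condition is already true, so tpu.version() is never called off-TPU.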
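A worked illustration of the second change: reduce-scatter divides the scattered dimension by the number of participating devices, so the expected shard shape must be derived from self.n_devices rather than hardcoded for a four-device mesh. The helper below is a hypothetical illustration, not part of the commit:

def reduce_scatter_shard_shape(global_shape, n_devices, dim):
  # Reduce-scatter shrinks `dim` by a factor of n_devices and leaves
  # every other dimension unchanged.
  shape = list(global_shape)
  shape[dim] //= n_devices
  return tuple(shape)

# With 4 devices the old hardcoded expectations hold ...
assert reduce_scatter_shard_shape((8, 8), 4, dim=0) == (2, 8)
assert reduce_scatter_shard_shape((8, 8), 4, dim=1) == (8, 2)
# ... but with 8 devices they do not, hence 8 // self.n_devices.
assert reduce_scatter_shard_shape((8, 8), 8, dim=0) == (1, 8)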
