From 0025ca7662f3a04e9855f965b5061f41089029a5 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 13 Jun 2024 21:43:27 -0700
Subject: [PATCH] [SPMD] Fix reduce_scatter for TPUv2 (#7248)

Summary:
Fix reduce_scatter test for TPUv2

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_spmd_reduce_scatter
---
 test/spmd/test_xla_sharding.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 2d710a7c7c1..40d3304e6f0 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1213,7 +1213,8 @@ def test_manual_sharding_api_e2e(self):
     self.assertEqual(xxx.shape, (8, 8))
     self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))

-  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "Only runs on TPUv4")
   def test_spmd_reduce_scatter(self):
     xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
     x = torch.ones(8, 8).to(xm.xla_device())
@@ -1230,10 +1231,11 @@ def test_spmd_reduce_scatter(self):
         f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
         hlo)

-    expected_x = torch.ones(2, 8) * 4
+    expected_x = torch.ones(8 // self.n_devices, 8) * 4
     self.assertTrue(torch.allclose(x.cpu(), expected_x))

-  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "Only runs on TPUv4")
   def test_spmd_reduce_scatter_canonical_index(self):
     xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
     x = torch.ones(8, 8).to(xm.xla_device())
@@ -1250,7 +1252,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
         f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
         hlo)

-    expected_x = torch.ones(8, 2) * 4
+    expected_x = torch.ones(8, 8 // self.n_devices) * 4
     self.assertTrue(torch.allclose(x.cpu(), expected_x))
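
A minimal sketch of the shape arithmetic behind the expected_x change, in plain torch with no TPU required: a reduce-scatter along one dimension leaves each participant with full_size // n_devices elements on that dimension, so the test's expected tensor has to be derived from self.n_devices rather than hard-coding 2 (which is only correct when n_devices == 4). The device counts below are illustrative, not taken from the patch.

  import torch

  full = 8  # size of the scattered dimension of the test input torch.ones(8, 8)
  for n_devices in (1, 2, 4, 8):  # illustrative device counts, not from the patch
    shard_rows = full // n_devices          # per-device rows after reduce-scatter on dim 0
    expected_x = torch.ones(shard_rows, 8)  # shape the fixed test now constructs
    print(n_devices, tuple(expected_x.shape))  # e.g. 4 -> (2, 8), 2 -> (4, 8)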