diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index e88b8b2caff..f470368243b 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -39,10 +39,10 @@ def test_flash_attention_spmd_data_parallel(self):
     k = torch.randn(4, 2, 128, 4).to("xla")
     v = torch.randn(4, 2, 128, 4).to("xla")

-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
@@ -63,7 +63,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     k.retain_grad()
     v.retain_grad()

-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     loss = o.sum()
     loss.backward()
     xm.mark_step()
@@ -73,13 +73,13 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(q_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(k_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")