diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index 3dacdb573bc..89a798a9406 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -45,10 +45,10 @@ def test_flash_attention_spmd_data_parallel(self):
     k = torch.randn(4, 2, 128, 4).to("xla")
     v = torch.randn(4, 2, 128, 4).to("xla")

-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
@@ -69,7 +69,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     k.retain_grad()
     v.retain_grad()

-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     loss = o.sum()
     loss.backward()
     xm.mark_step()
@@ -79,13 +79,13 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(q_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(k_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
@@ -122,10 +122,10 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     kv_segment_ids[8:, :, 60:] = -10000.0

     o = flash_attention(
-        q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(n_devices))
+        q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
     attention_mask = attention_mask.view(16, 32, 1, 128)
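
Note on the change (not part of the diff itself): the partition_spec passed to flash_attention describes how each dimension of the rank-4 q/k/v tensors is sharded across the mesh, so it presumably needs four entries regardless of how many devices are present; range(n_devices) only lined up with the tensor rank on a 4-device mesh, which is why it is replaced by range(4). The updated assertions also appear to target a 16-device run, listing every device id in the data-parallel dimension. A minimal sketch of how such an expected sharding string is composed, using a hypothetical helper expected_batch_sharding_spec that is not part of the test file:

    # Hypothetical helper (an assumption, not in test_pallas_spmd.py): rebuild the
    # expected sharding annotation for batch-dimension data parallelism over
    # `n_devices` devices of a rank-4 tensor.
    def expected_batch_sharding_spec(n_devices: int, rank: int = 4) -> str:
      replicated_dims = ",".join(["1"] * (rank - 1))  # dims 1..3 stay unsharded, e.g. "1,1,1"
      device_ids = ",".join(str(i) for i in range(n_devices))  # e.g. "0,1,...,15"
      return f"{{devices=[{n_devices},{replicated_dims}]{device_ids}}}"

    # Matches the literal expected by the updated assertions when n_devices == 16.
    assert expected_batch_sharding_spec(16) == (
        "{devices=[16,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}")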