diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index 58f620070a9..f705fa976e4 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -115,7 +115,7 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self): k = torch.randn(4, 2, 128, 4).to("xla") v = torch.randn(4, 2, 128, 4).to("xla") q_segment_ids = torch.ones(4, 128).to("xla") - kv_segment_ids = mask = torch.randn(4, 128).to("xla") + kv_segment_ids = torch.randn(4, 2, 128, 128).to("xla") o = flash_attention(q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4)) self.assertEqual(