From 8e69c26e220733e46feef622ace3b97f4df655cd Mon Sep 17 00:00:00 2001
From: Dudu Moshe
Date: Thu, 21 Nov 2024 16:17:52 +0200
Subject: [PATCH] test_pallas_spmd: compare flash_attention segment-ids output
 against F.scaled_dot_product_attention

---
 test/test_pallas_spmd.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index b97f16f0c0d..411f056dec1 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -112,20 +112,20 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     n_devices = xr.global_runtime_device_count()
     xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
 
-    q = torch.randn(4, 2, 128, 4).to("xla")
-    k = torch.randn(4, 2, 128, 4).to("xla")
-    v = torch.randn(4, 2, 128, 4).to("xla")
-    q_segment_ids = torch.ones(4, 128, device=q.device, dtype=torch.float32).to("xla")
-    kv_segment_ids = torch.rand(4, 128).to("xla")
+    q = torch.randn(16, 32, 2048, 64).to("xla")
+    k = torch.randn(16, 32, 2048, 64).to("xla")
+    v = torch.randn(16, 32, 2048, 64).to("xla")
+    q_segment_ids = torch.ones(16, 2048, device=q.device, dtype=torch.float32).to("xla")
+    kv_segment_ids = torch.rand(16, 1, 128).to("xla")
     o = flash_attention(
         q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
 
-    attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
-    attention_mask = attention_mask.repeat_interleave(2, dim=0)
-    attention_mask = attention_mask.view(4, 2, 128, 128)
+    # attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
+    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
+    attention_mask = attention_mask.view(16, 32, -1, 128)
     # attention_mask = torch.ones(4, 2, 128, 128).to("xla")
     # head_size = self.heads
     # current_length: int = attention_mask.shape[-1]
@@ -139,7 +139,15 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     #     batch_size, attn.heads, -1, attention_mask.shape[-1]
     # )
 
-    expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    # expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    expected_o = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+    )
     diff = (expected_o - o).abs()
 
     # z = torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)
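
Note for reviewers (not part of the patch): below is a minimal CPU sketch of the kind of reference check the test switches to, using F.scaled_dot_product_attention with a mask built from segment ids. Shapes, segment layout, and variable names here are invented for illustration, and it assumes the usual segment-id semantics where a query may only attend to keys carrying the same segment id.

# Illustrative sketch only; small shapes so it runs on CPU.
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 128, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Two segments of seq_len // 2 tokens each; the same layout is used for
# queries and keys, so no query row ends up fully masked.
segment_ids = torch.repeat_interleave(torch.arange(2), seq_len // 2)  # (seq_len,)
q_segment_ids = segment_ids.expand(batch, seq_len)
kv_segment_ids = segment_ids.expand(batch, seq_len)

# Boolean mask of shape (batch, 1, q_len, kv_len): True means "may attend".
# It broadcasts over the head dimension inside scaled_dot_product_attention.
attn_mask = q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]

expected_o = F.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(expected_o.shape)  # torch.Size([2, 4, 128, 16])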