From 8d612bfa88a7daa039213df4623debafad2b09af Mon Sep 17 00:00:00 2001 From: Dudu Moshe Date: Wed, 20 Nov 2024 17:23:36 +0200 Subject: [PATCH] wip --- test/test_pallas_spmd.py | 43 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index f470368243b..9ed8674f216 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -22,8 +22,14 @@ class PallasTest(unittest.TestCase): - def _attention(self, q, k, v): + def _attention(self, q, k, v, *, attn_mask=None, ab=None): attn_weight = q @ k.transpose(-2, -1) + if attn_mask is not None: + # Masked out the unrelevant parts. + attn_weight = attn_weight.masked_fill(attn_mask, + torch.finfo(attn_weight.dtype).min) + if ab is not None: + attn_weight = attn_weight + ab attn_weight = nn.functional.softmax(attn_weight, dim=-1) attn_output = attn_weight @ v return attn_output @@ -98,6 +104,41 @@ def test_flash_attention_backward_spmd_data_parallel(self): self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) jax.config.update('jax_default_matmul_precision', "default") + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_spmd_data_parallel_with_segment_ids(self): + jax.config.update('jax_default_matmul_precision', "highest") + n_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1))) + + q = torch.randn(4, 2, 128, 4).to("xla") + k = torch.randn(4, 2, 128, 4).to("xla") + v = torch.randn(4, 2, 128, 4).to("xla") + q_segment_ids = torch.ones(4, 128, device=query.device, dtype=torch.float32) + kv_segment_ids = torch.rand(4, 128) + + o = flash_attention(q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4)) + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(o), + f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}") + + attention_mask = attention_mask.repeat_interleave(32, dim=0) + attention_mask = attention_mask.view(4, 32, -1, 128) + # head_size = self.heads + # current_length: int = attention_mask.shape[-1] + # if current_length != target_length: + # attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + # if attention_mask.shape[0] < 4 * head_size: + # attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + # + # attention_mask = attention_mask.view( + # batch_size, attn.heads, -1, attention_mask.shape[-1] + # ) + + expected_o = self._attention(q, k, v, attn_mask=attention_mask) + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', "default") if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)