diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index 93186e24633..b97f16f0c0d 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -10,6 +10,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
+import torch.nn.functional as F
 
 if xr.device_type() == 'TPU':
   from torch_xla.experimental.custom_kernel import flash_attention
@@ -122,8 +123,10 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
 
-    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
-    attention_mask = attention_mask.view(4, 32, -1, 128)
+    attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
+    attention_mask = attention_mask.repeat_interleave(2, dim=0)
+    attention_mask = attention_mask.view(4, 2, 128, 128)
+    # attention_mask = torch.ones(4, 2, 128, 128).to("xla")
     # head_size = self.heads
     # current_length: int = attention_mask.shape[-1]
     # if current_length != target_length:
@@ -137,6 +140,17 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     # )
 
     expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    diff = (expected_o - o).abs()
+    # z = torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)
+
+    true_count = torch.sum(diff < 0.00001).item()
+    false_count = diff.numel() - true_count
+
+    print(f"Number of True: {true_count}")
+    print(f"Number of False: {false_count}")
+
+    print(diff.max().cpu())
+    print(diff.min().cpu())
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', "default")
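
Note (not part of the patch): the padding and reshaping above is building a per-head mask out of kv_segment_ids so the pallas flash_attention output can be compared against the plain-attention baseline. Below is a minimal sketch of the usual way a segment-id mask is constructed; the helper name segment_ids_to_mask is hypothetical, and the shapes (batch=4, heads=2, q_len=kv_len=128) are assumptions read off the view(4, 2, 128, 128) call in the patch.

import torch

def segment_ids_to_mask(q_segment_ids, kv_segment_ids, num_heads):
  # Hypothetical helper, not from the patch.
  # q_segment_ids: [batch, q_len], kv_segment_ids: [batch, kv_len].
  # True marks (query, key) pairs that belong to different segments and
  # should therefore be masked out of the attention weights.
  mask = q_segment_ids[:, :, None] != kv_segment_ids[:, None, :]
  # Broadcast the same mask to every attention head: [batch, heads, q_len, kv_len].
  return mask[:, None, :, :].expand(-1, num_heads, -1, -1)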