Commit b231efa

ytyt
dudulightricks committed Nov 20, 2024
1 parent 5c13ea2 commit b231efa
Showing 1 changed file with 16 additions and 2 deletions.
test/test_pallas_spmd.py (16 additions, 2 deletions)
@@ -10,6 +10,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
+import torch.nn.functional as F

 if xr.device_type() == 'TPU':
   from torch_xla.experimental.custom_kernel import flash_attention
@@ -122,8 +123,10 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

-    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
-    attention_mask = attention_mask.view(4, 32, -1, 128)
+    attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
+    attention_mask = attention_mask.repeat_interleave(2, dim=0)
+    attention_mask = attention_mask.view(4, 2, 128, 128)
+    attention_mask = torch.ones(4, 2, 128, 128).to("xla")
     # head_size = self.heads
     # current_length: int = attention_mask.shape[-1]
     # if current_length != target_length:
@@ -137,6 +140,17 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     # )

     expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    diff = (expected_o - o).abs()
+    # z = torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)
+
+    true_count = torch.sum(diff < 0.00001).item()
+    false_count = diff.numel() - true_count
+
+    print(f"Number of True: {true_count}")
+    print(f"Number of False: {false_count}")
+
+    print(diff.max().cpu())
+    print(diff.min().cpu())
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', "default")
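
The diagnostic block added in the last hunk can be reproduced standalone. The sketch below is plain CPU PyTorch and is illustrative only: `o` and `expected_o` are random stand-ins for the flash_attention output and the self._attention reference that the real test computes on TPU, and the shapes are assumptions. It shows the element-wise view of the same atol=1e-05 tolerance that the final assertTrue(torch.allclose(...)) enforces.

# Standalone sketch (assumed shapes; not part of the commit).
import torch

torch.manual_seed(0)

# Stand-ins for the test's tensors: in the real test, `o` is the
# flash_attention output and `expected_o` comes from self._attention,
# both computed on TPU. Here they are random CPU tensors.
o = torch.randn(4, 2, 128, 128)
expected_o = o + 1e-6 * torch.randn_like(o)

diff = (expected_o - o).abs()

# Element-wise count of entries inside the absolute tolerance used by
# the test's final allclose check.
atol = 1e-5
true_count = torch.sum(diff < atol).item()
false_count = diff.numel() - true_count

print(f"Number of True: {true_count}")
print(f"Number of False: {false_count}")
print(diff.max(), diff.min())

# Note: allclose also applies a relative term (rtol, default 1e-5),
# so false_count == 0 is slightly stricter than allclose itself.
print(torch.allclose(o, expected_o, atol=atol))

Counting passing elements this way localizes a tolerance failure to a fraction of entries, rather than the single boolean that allclose returns.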

