
Commit

dudulightricks committed Nov 21, 2024
1 parent 45c5201 commit ac0ff16
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/test_pallas_spmd.py
@@ -45,10 +45,10 @@ def test_flash_attention_spmd_data_parallel(self):
     k = torch.randn(4, 2, 128, 4).to("xla")
     v = torch.randn(4, 2, 128, 4).to("xla")
 
-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
 
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
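Note on the assertion strings: the change pins the expected device list to a 16-device mesh while `partition_spec` becomes `range(4)`, one entry per tensor dimension rather than per device. A minimal sketch, not part of the commit (the helper name `expected_dp_sharding` is hypothetical), of how the expected data-parallel sharding string would be built for an arbitrary device count:

# Hypothetical helper, for illustration only: reconstructs the f-string
# used in the assertions above for any device count.
def expected_dp_sharding(n_devices: int) -> str:
  # Data-parallel sharding over dim 0 of a rank-4 tensor: the mesh shape
  # is [n_devices, 1, 1, 1], followed by every device id in order.
  ids = ",".join(str(i) for i in range(n_devices))
  return f"{{devices=[{n_devices},1,1,1]{ids}}}"

# expected_dp_sharding(4)  == "{devices=[4,1,1,1]0,1,2,3}"
# expected_dp_sharding(16) matches the 16-device string hardcoded above.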
@@ -69,7 +69,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     k.retain_grad()
     v.retain_grad()
 
-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     loss = o.sum()
     loss.backward()
     xm.mark_step()
@@ -79,13 +79,13 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(q_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(k_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
 
     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
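For context, a sketch of how a test like this is typically driven, under stated assumptions: a 16-device host, the public torch_xla SPMD APIs, and mesh setup mirroring the pattern used in torch_xla SPMD tests; the exact setup in this file may differ.

import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()  # 16 on the mesh these asserts assume

# Assumed mesh setup: a 1D data-parallel mesh over all devices.
mesh = xs.Mesh(np.arange(n_devices), (n_devices, 1, 1, 1))
xs.set_global_mesh(mesh)

q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")

# partition_spec has one entry per tensor dimension (rank 4), which is why
# range(4) is correct regardless of how many devices participate.
o = flash_attention(q, k, v, partition_spec=range(4))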
@@ -122,10 +122,10 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     kv_segment_ids[8:, :, 60:] = -10000.0
 
     o = flash_attention(
-        q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(n_devices))
+        q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
 
     attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
     attention_mask = attention_mask.view(16, 32, 1, 128)
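The `-10000.0` entries act as an additive bias that drives the softmax weight of masked positions to roughly zero. A sketch of the masked reference computation such a test compares against (the shapes and the function name `masked_attention` are assumptions based on the visible context, not the file's actual `_attention` helper):

import torch

def masked_attention(q, k, v, additive_mask):
  # q, k, v: [batch, heads, q_len, head_dim]; additive_mask broadcastable
  # to the [batch, heads, q_len, kv_len] attention logits.
  logits = torch.einsum("bhqd,bhkd->bhqk", q, k)
  logits = logits + additive_mask  # -10000.0 suppresses masked positions
  weights = torch.softmax(logits, dim=-1)
  return torch.einsum("bhqk,bhkd->bhqd", weights, v)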
