Commit

to remove
dudulightricks committed Nov 20, 2024
1 parent 0c79164 commit 5b29a99
Showing 1 changed file with 6 additions and 6 deletions.
test/test_pallas_spmd.py: 12 changes (6 additions & 6 deletions)
@@ -39,10 +39,10 @@ def test_flash_attention_spmd_data_parallel(self):
     k = torch.randn(4, 2, 128, 4).to("xla")
     v = torch.randn(4, 2, 128, 4).to("xla")

-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
@@ -63,7 +63,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     k.retain_grad()
     v.retain_grad()

-    o = flash_attention(q, k, v, partition_spec=range(n_devices))
+    o = flash_attention(q, k, v, partition_spec=range(4))
     loss = o.sum()
     loss.backward()
     xm.mark_step()
@@ -73,13 +73,13 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(q_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(k_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
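For context, a minimal sketch (not part of this commit) of what these tests exercise: running flash_attention under SPMD with a data-parallel partition_spec and reading back the output's sharding spec. It assumes an XLA device environment with torch_xla installed; the mesh setup mirrors what test_pallas_spmd.py does before the lines shown above, and the printed spec depends on how many devices the runtime exposes.

# Sketch assuming a torch_xla SPMD setup; imports and calls follow the ones
# used by test_pallas_spmd.py itself.
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
# 4-D mesh shaped (n_devices, 1, 1, 1): only the leading (batch) axis spans devices.
xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

q = torch.randn(4, 2, 128, 4).to("xla")
k = torch.randn(4, 2, 128, 4).to("xla")
v = torch.randn(4, 2, 128, 4).to("xla")

# partition_spec has one entry per tensor dimension of q/k/v, so with 4-D
# inputs it is range(4) regardless of device count; on this mesh it shards
# the batch dimension across all devices.
o = flash_attention(q, k, v, partition_spec=range(4))

# On 4 devices this should print "{devices=[4,1,1,1]0,1,2,3}"; on 16 devices
# the device list runs 0..15, which is what the updated assertions expect.
print(torch_xla._XLAC._get_xla_sharding_spec(o))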
