Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dudulightricks committed Nov 20, 2024
1 parent 5b29a99 commit 8d612bf
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@

class PallasTest(unittest.TestCase):

def _attention(self, q, k, v):
def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Masked out the unrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
Expand Down Expand Up @@ -98,6 +104,41 @@ def test_flash_attention_backward_spmd_data_parallel(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
jax.config.update('jax_default_matmul_precision', "highest")
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

q = torch.randn(4, 2, 128, 4).to("xla")
k = torch.randn(4, 2, 128, 4).to("xla")
v = torch.randn(4, 2, 128, 4).to("xla")
q_segment_ids = torch.ones(4, 128, device=query.device, dtype=torch.float32)
kv_segment_ids = torch.rand(4, 128)

o = flash_attention(q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

attention_mask = attention_mask.repeat_interleave(32, dim=0)
attention_mask = attention_mask.view(4, 32, -1, 128)
# head_size = self.heads
# current_length: int = attention_mask.shape[-1]
# if current_length != target_length:
# attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

# if attention_mask.shape[0] < 4 * head_size:
# attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
#
# attention_mask = attention_mask.view(
# batch_size, attn.heads, -1, attention_mask.shape[-1]
# )

expected_o = self._attention(q, k, v, attn_mask=attention_mask)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit 8d612bf

Please sign in to comment.