Update unit tests and fix typos
wonjoolee95 committed Apr 10, 2024
1 parent b6822a3 commit 72fdd57
Showing 2 changed files with 34 additions and 13 deletions.
44 changes: 32 additions & 12 deletions test/test_pallas.py
@@ -541,18 +541,38 @@ def test_tpu_custom_call_pallas_wrap_paged_attention(self):
    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
    from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
    paged_attention_kernel = make_kernel_from_pallas(
-        paged_attention, lambda q, k, v: [(q.shape, q.dtype)])
-
-    q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
-    k_mini = torch.arange(
-        1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
-    q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
-    k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
-    v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
-
-    o = paged_attention_kernel(q, k, v)
-    expected_o = self._attention(q, k, v)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
+        paged_attention, lambda q, k_pages, v_pages, lengths, page_indices,
+        pages_per_compute_block: [(q.shape, q.dtype)])
+
+    batch_size = 4
+    max_kv_len = 2048
+    block_size = 512
+    seq_lens = torch.tensor(
+        [max_kv_len // batch_size * (i + 1) for i in range(batch_size)])
+    q, k_pages, v_pages, page_indices = _pagedattention_generate_qkv(
+        seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+        dtype,
+    )
+    o = paged_attention_kernel(
+        q,
+        k_pages,
+        v_pages,
+        seq_lens,
+        page_indices,
+        pages_per_compute_block=block_size // page_size,
+    )
+    k = _pagedattention_reconstruct_kv(page_indices, k_pages)
+    v = _pagedattention_reconstruct_kv(page_indices, v_pages)
+
+    o_expected = _pagedattention_grouped_query_attention_reference(
+        q, k, v, seq_lens)
+
+    self.assertTrue(torch.allclose(o.cpu(), o_expected.cpu()))


if __name__ == '__main__':
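Note on the wrapper used in this test: as the diff suggests, make_kernel_from_pallas takes a Pallas kernel plus a callback that receives the same arguments as the kernel and returns the output (shape, dtype) pairs, which is why the lambda now has to spell out all five paged-attention inputs instead of just q, k, v. Below is a minimal sketch of the same wrapping pattern applied to the simpler flash-attention kernel; the import paths and shapes are assumptions mirroring the flash-attention wrap test earlier in test_pallas.py, and running it requires a TPU-backed XLA device.

import torch
import torch_xla  # needed so tensors can be moved to the "xla" device
# Assumed import paths, mirroring the flash-attention wrap test in this file.
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

# The shape callback takes the same positional arguments as the kernel and
# returns one (shape, dtype) pair per kernel output.
flash_attention_kernel = make_kernel_from_pallas(
    flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

q = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
k = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
v = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
o = flash_attention_kernel(q, k, v)  # executes the Pallas kernel via an XLA custom call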
3 changes: 2 additions & 1 deletion torch_xla/experimental/custom_kernel.py
@@ -392,7 +392,8 @@ def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
                        v_pages: torch.Tensor, lengths: torch.Tensor,
                        page_indices: torch.Tensor,
                        pages_per_compute_block: int):
-  return flash_attention(q, k, v, causal=causal)
+  return paged_attention(q, k_pages, v_pages, lengths, page_indices,
+                         pages_per_compute_block)


@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
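The corrected line is the XLA-backend implementation behind the paged_attention op registered on XLA_LIB (the old body mistakenly forwarded to flash_attention), so user code normally reaches it through the torch custom-op registry rather than by calling paged_attention_xla directly. A minimal sketch of that call path, assuming XLA_LIB registers the op under the xla namespace and that the caller has already laid out the paged key/value caches, lengths, and page indices as XLA tensors:

import torch
import torch_xla
import torch_xla.experimental.custom_kernel  # registers the paged_attention op

def run_paged_attention(q, k_pages, v_pages, lengths, page_indices,
                        pages_per_compute_block):
  # Dispatches through the custom-op registry; on an XLA device this resolves
  # to the paged_attention_xla implementation fixed in this commit.
  # pages_per_compute_block is block_size // page_size, as in the updated test.
  return torch.ops.xla.paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, pages_per_compute_block)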
