Commit

Update unit tests and fix typos
wonjoolee95 committed Apr 11, 2024
1 parent b6822a3 commit 45d9fe3
Showing 2 changed files with 45 additions and 18 deletions.
58 changes: 42 additions & 16 deletions test/test_pallas.py
@@ -29,6 +29,7 @@ def _attention(self, q, k, v):
   # The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests
   # Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
   def _pagedattention_generate_qkv(
+      self,
       seq_lens,
       page_size,
       max_seq_len,
@@ -37,11 +38,10 @@ def _pagedattention_generate_qkv(
       head_dim,
       dtype=torch.float32,
   ):
-    assert max_seq_len % page_size == 0
+    # assert max_seq_len % page_size == 0
     pages_per_sequence = max_seq_len // page_size
     batch_size = len(seq_lens)
     total_pages = batch_size * pages_per_sequence
-    k1, k2, k3, k4 = jax.random.split(prng_key, 4)
     k_pages = torch.randn(
         num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
     v_pages = torch.randn(
@@ -52,7 +52,7 @@ def _pagedattention_generate_qkv(
     q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
     return q, k_pages, v_pages, page_indices

-  def _pagedattention_reconstruct_kv(page_indices, pages):
+  def _pagedattention_reconstruct_kv(self, page_indices, pages):
     batch_size = page_indices.shape[0]
     num_heads, _, _, head_dim = pages.shape

@@ -64,7 +64,7 @@ def per_sequence_page_gather(pages, page_indices):
         per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices)
     return gathered.reshape(batch_size, num_heads, -1, head_dim)

-  def _pagedattention_grouped_query_attention_reference(q, k, v, lengths):
+  def _pagedattention_grouped_query_attention_reference(self, q, k, v, lengths):
     batch_size, num_heads, head_dim = q.shape
     _, num_kv_heads, max_seq_len, _ = k.shape
     assert k.shape == v.shape
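
The _pagedattention_reconstruct_kv helper in the hunks above reassembles a contiguous [batch, heads, seq, head_dim] K/V tensor from the paged storage by gathering each sequence's pages. A standalone sketch of that gather follows; the torch.index_select body is an assumption, since the diff only shows the torch.vmap call and the final reshape.

# Hypothetical standalone version of the KV-reconstruction helper shown above.
# The torch.index_select body is an assumption; the commit only shows the
# torch.vmap call and the final reshape.
import torch

def reconstruct_kv(page_indices, pages):
  # pages: [num_kv_heads, total_pages, page_size, head_dim]
  # page_indices: [batch_size, pages_per_sequence]
  batch_size = page_indices.shape[0]
  num_heads, _, _, head_dim = pages.shape

  def per_sequence_page_gather(pages, page_indices):
    # Pick out this sequence's pages along the page axis (dim 1).
    return torch.index_select(pages, 1, page_indices)

  gathered = torch.vmap(
      per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices)
  # gathered: [batch_size, num_kv_heads, pages_per_sequence, page_size, head_dim]
  return gathered.reshape(batch_size, num_heads, -1, head_dim)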
@@ -541,18 +541,44 @@ def test_tpu_custom_call_pallas_wrap_paged_attention(self):
     from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
     from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
     paged_attention_kernel = make_kernel_from_pallas(
-        paged_attention, lambda q, k, v: [(q.shape, q.dtype)])
-
-    q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
-    k_mini = torch.arange(
-        1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
-    q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
-    k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
-    v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
-
-    o = paged_attention_kernel(q, k, v)
-    expected_o = self._attention(q, k, v)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
+        paged_attention, lambda q, k_pages, v_pages, lengths, page_indices,
+        pages_per_compute_block: [(q.shape, q.dtype)])
+
+    batch_size = 4
+    max_kv_len = 2048
+    block_size = 512
+    page_size = 16
+    num_kv_heads = 1
+    q_kv_head_ratio = 1
+    head_dim = 128
+    dtype = torch.float32
+    seq_lens = torch.tensor(
+        [max_kv_len // batch_size * (i + 1) for i in range(batch_size)])
+
+    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
+        seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+        dtype=dtype,
+    )
+    o = paged_attention_kernel(
+        q,
+        k_pages,
+        v_pages,
+        seq_lens,
+        page_indices,
+        pages_per_compute_block=block_size // page_size,
+    )
+    k = self._pagedattention_reconstruct_kv(page_indices, k_pages)
+    v = self._pagedattention_reconstruct_kv(page_indices, v_pages)
+
+    o_expected = self._pagedattention_grouped_query_attention_reference(
+        q, k, v, seq_lens)
+
+    self.assertTrue(torch.allclose(o.cpu(), o_expected.cpu()))


 if __name__ == '__main__':
5 changes: 3 additions & 2 deletions torch_xla/experimental/custom_kernel.py
@@ -392,7 +392,8 @@ def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
                         v_pages: torch.Tensor, lengths: torch.Tensor,
                         page_indices: torch.Tensor,
                         pages_per_compute_block: int):
-  return flash_attention(q, k, v, causal=causal)
+  return paged_attention(q, k_pages, v_pages, lengths, page_indices,
+                         pages_per_compute_block)


 @impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
@@ -404,7 +405,7 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
   # We need to make sure output tensor's shape is correct.
   if k.device != torch.device("meta"):
     warnings.warn(
-        'XLA flash attention should only be applied to tensors on XLA device')
+        'XLA paged attention should only be applied to tensors on XLA device')

   # perform a regular attention if input tensors are not on XLA device.
   attn_weight = q @ k.transpose(-2, -1)
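
For context, the updated test wires JAX's Pallas paged_attention kernel into PyTorch/XLA via make_kernel_from_pallas and checks it against a grouped-query attention reference built from the reconstructed K/V. A minimal end-to-end sketch of that call path follows; it assumes a TPU runtime with torch_xla and JAX installed, and the shapes, sequence lengths, and explicit .to("xla") placements are illustrative rather than taken from the commit.

# Sketch only: requires a TPU with torch_xla and JAX (Pallas) available.
import torch
import torch_xla  # makes the "xla" device usable from PyTorch
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

# The second argument maps the kernel's inputs to its output (shape, dtype)
# pairs so the wrapper can build the XLA custom call.
paged_attention_kernel = make_kernel_from_pallas(
    paged_attention, lambda q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block: [(q.shape, q.dtype)])

batch_size, num_heads, head_dim = 4, 1, 128
page_size, pages_per_seq = 16, 128  # 16-token pages, 2048-token max context
total_pages = batch_size * pages_per_seq

q = torch.randn(batch_size, num_heads, head_dim).to("xla")
k_pages = torch.randn(num_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_heads, total_pages, page_size, head_dim).to("xla")
seq_lens = torch.full((batch_size,), 512, dtype=torch.int32).to("xla")
page_indices = torch.randperm(
    total_pages, dtype=torch.int32).reshape(batch_size, pages_per_seq).to("xla")

o = paged_attention_kernel(
    q, k_pages, v_pages, seq_lens, page_indices,
    pages_per_compute_block=512 // page_size)  # block_size // page_size
print(o.shape)  # expected: (batch_size, num_heads, head_dim)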
