Commit

Add reshape in paged_attention
wonjoolee95 committed Apr 22, 2024
1 parent 9c1aa79 commit d3e2d2e
Showing 2 changed files with 53 additions and 54 deletions.
78 changes: 33 additions & 45 deletions test/test_pallas.py
@@ -10,6 +10,8 @@
from torch_xla import runtime as xr
from torch_xla._internal import tpu

import numpy as np

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
@@ -52,33 +54,6 @@ def _pagedattention_generate_qkv(
q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
return q, k_pages, v_pages, page_indices

def _pagedattention_reconstruct_kv(self, page_indices, pages):
batch_size = page_indices.shape[0]
num_heads, _, _, head_dim = pages.shape

def per_sequence_page_gather(pages, page_indices):
page_indices_int64 = page_indices.to(torch.int64)
return torch.index_select(pages, dim=1, index=page_indices_int64)

gathered = torch.vmap(
per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices)

return gathered.reshape(batch_size, num_heads, -1, head_dim)

def _pagedattention_grouped_query_attention_reference(self, q, k, v, lengths):
batch_size, num_heads, head_dim = q.shape
_, num_kv_heads, max_seq_len, _ = k.shape
assert k.shape == v.shape
assert num_heads % num_kv_heads == 0
q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim)
logits = torch.einsum("bhgd,bhtd->bhgt", q.float(), k.float())
mask = torch.arange(max_seq_len)[None, :] < lengths[:, None]
mask_value = -0.7 * float(torch.finfo(torch.float32).max)
logits = logits.masked_fill(mask[:, None, None, :], mask_value)
weights = torch.softmax(logits, dim=-1)
o = torch.einsum("bhgt,bhtd->bhgd", weights.to(v.dtype), v)
return o.reshape(batch_size, num_heads, head_dim)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add(self):
# This payload is generated by the following Pallas code:
@@ -533,19 +508,18 @@ def test_flash_attention_backward(self):
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import paged_attention
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

batch_size = 4
max_kv_len = 2048
block_size = 512
page_size = 16
num_kv_heads = 1
q_kv_head_ratio = 1
head_dim = 128
page_size = 64
num_kv_heads = 8
q_kv_head_ratio = 8
head_dim = 256
dtype = torch.float32
seq_lens = torch.tensor(
[max_kv_len // batch_size * (i + 1) for i in range(batch_size)],
dtype=torch.int32)
seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
seq_lens,
@@ -554,7 +528,6 @@ def test_paged_attention_wrapper(self):
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
dtype=dtype,
)

q_xla = q.to("xla")
@@ -563,22 +536,37 @@ def test_paged_attention_wrapper(self):
seq_lens_xla = seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")

o = paged_attention(
outputs, _, _ = paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
seq_lens_xla,
page_indices_xla,
pages_per_compute_block=block_size // page_size,
)

k = self._pagedattention_reconstruct_kv(page_indices, k_pages)
v = self._pagedattention_reconstruct_kv(page_indices, v_pages)
o_expected = self._pagedattention_grouped_query_attention_reference(
q, k, v, seq_lens)

self.assertEqual(o[0].shape, o_expected.shape)
self.assertTrue(torch.allclose(o[0].cpu(), expected_output.cpu()))
batch_size, num_heads, head_dim = q_xla.shape
outputs = outputs.reshape(batch_size, num_heads, head_dim)

q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
expected_outputs = torch.from_numpy(
np.array(
jax_paged_attention(
q_jax,
k_pages_jax,
v_pages_jax,
seq_lens_jax,
page_indices_jax,
pages_per_compute_block=block_size // page_size,
)))

self.assertTrue(
torch.allclose(
outputs.cpu(), expected_outputs.cpu(), atol=1e+1, rtol=1e+1))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)


if __name__ == '__main__':
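A shape-only sketch (not the Pallas kernel) of why the updated test flattens outputs back to (batch_size, num_heads, head_dim) before comparing against the JAX reference. The sizes below are hypothetical and chosen so the query/KV head ratio is not a multiple of 8, which is the case where the wrapper's new reshape path in custom_kernel.py (next file) produces a 4D result.

import torch

# Hypothetical sizes; the head ratio 4 // 1 = 4 is not a multiple of 8.
batch_size, num_kv_heads, q_kv_head_ratio, head_dim = 4, 1, 4, 128
num_heads = num_kv_heads * q_kv_head_ratio

q = torch.randn(batch_size, num_heads, head_dim)

# In that case the wrapper hands the kernel a 4D query with a singleton
# query-length axis, so the attention result comes back 4D as well.
kernel_out = q.reshape(batch_size, num_heads, 1, head_dim).clone()  # stand-in result

# Flattening back to (batch_size, num_heads, head_dim), as the test does,
# keeps the comparison against the 3D JAX reference output shape-agnostic.
out_3d = kernel_out.reshape(batch_size, num_heads, head_dim)
assert out_3d.shape == q.shape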
29 changes: 20 additions & 9 deletions torch_xla/experimental/custom_kernel.py
@@ -357,18 +357,29 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
static_argnames=["pages_per_compute_block"],
)

batch_size, num_heads, head_dim = q.shape
num_kv_heads, _, page_size, head_dim_k = k_pages.shape
batch_size_paged_indices, pages_per_sequence = page_indices.shape

if (num_heads // num_kv_heads) % 8 != 0:
q = q.reshape(batch_size, num_heads, 1, head_dim)

page_indices_reshaped = page_indices.reshape(-1)
buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
step = torch.zeros((1,), dtype=torch.int32).to("xla")
output = torch_xla._XLAC._xla_tpu_custom_call([
lengths,
page_indices_reshaped,
buffer_index,
step,
q,
k_pages,
v_pages,
], payload, [q.shape], [q.dtype])
output_shape = torch.Size(list(q.shape[:-1]) + [1])

output = torch_xla._XLAC._xla_tpu_custom_call(
[
lengths,
page_indices_reshaped,
buffer_index,
step,
q,
k_pages,
v_pages,
], payload, [q.shape, output_shape, output_shape],
[q.dtype, torch.float32, torch.float32])

return output

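For reference, a small shape-only sketch of what the updated wrapper declares to the TPU custom call: when the query/KV head ratio is not a multiple of 8 the query gets a singleton query-length axis, and three result buffers are requested, namely the attention output (q-shaped) plus two float32 buffers of shape q.shape[:-1] + [1], presumably the kernel's per-head auxiliary outputs. This is inferred from the diff above; the helper name below is made up for illustration.

import torch

def paged_attention_result_shapes(q_shape, num_kv_heads):
    batch_size, num_heads, head_dim = q_shape
    q_shape = list(q_shape)
    if (num_heads // num_kv_heads) % 8 != 0:
        # The kernel is fed a 4D query with a singleton query-length axis.
        q_shape = [batch_size, num_heads, 1, head_dim]
    aux_shape = q_shape[:-1] + [1]  # matches output_shape in the wrapper
    return [torch.Size(q_shape), torch.Size(aux_shape), torch.Size(aux_shape)]

# Head ratio 64 // 8 = 8 is a multiple of 8: the query keeps its 3D shape.
print(paged_attention_result_shapes((4, 64, 256), num_kv_heads=8))
# [torch.Size([4, 64, 256]), torch.Size([4, 64, 1]), torch.Size([4, 64, 1])]

# Head ratio 4 // 1 = 4 is not: the query (and first result) become 4D.
print(paged_attention_result_shapes((4, 4, 128), num_kv_heads=1))
# [torch.Size([4, 4, 1, 128]), torch.Size([4, 4, 1, 1]), torch.Size([4, 4, 1, 1])]

The test above unpacks the three results as outputs, _, _ = paged_attention(...).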
