Update unit tests
wonjoolee95 committed Apr 22, 2024
1 parent d3e2d2e commit 312bef1
Showing 2 changed files with 85 additions and 7 deletions.
88 changes: 83 additions & 5 deletions test/test_pallas.py
@@ -536,23 +536,21 @@ def test_paged_attention_wrapper(self):
     seq_lens_xla = seq_lens.to("xla")
     page_indices_xla = page_indices.to("xla")

-    outputs, _, _ = paged_attention(
+    output = paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
         seq_lens_xla,
         page_indices_xla,
         pages_per_compute_block=block_size // page_size,
     )
-    batch_size, num_heads, head_dim = q_xla.shape
-    outputs = outputs.reshape(batch_size, num_heads, head_dim)

     q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
     k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
     v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
     seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
-    expected_outputs = torch.from_numpy(
+    expected_output = torch.from_numpy(
         np.array(
             jax_paged_attention(
                 q_jax,
@@ -565,7 +563,87 @@ def test_paged_attention_wrapper(self):

     self.assertTrue(
         torch.allclose(
-            outputs.cpu(), expected_outputs.cpu(), atol=1e+1, rtol=1e+1))
+            output.cpu()[seq_lens > 0],
+            expected_output.cpu()[seq_lens > 0],
+            atol=1e-1,
+            rtol=1e-1))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_paged_attention_wrapper_with_dynamo(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import paged_attention
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
+
+    max_kv_len = 2048
+    block_size = 512
+    page_size = 64
+    num_kv_heads = 8
+    q_kv_head_ratio = 8
+    head_dim = 256
+    dtype = torch.float32
+    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
+
+    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
+        seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+    )
+
+    q_xla = q.to("xla")
+    k_pages_xla = k_pages.to("xla")
+    v_pages_xla = v_pages.to("xla")
+    seq_lens_xla = seq_lens.to("xla")
+    page_indices_xla = page_indices.to("xla")
+
+    def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
+                                pages_per_compute_block):
+      return paged_attention(
+          q,
+          k,
+          v,
+          seq_lens,
+          page_indices,
+          pages_per_compute_block=pages_per_compute_block,
+      )
+
+    compiled_paged_attention = torch.compile(
+        paged_attention_wrapper, backend="openxla")
+    output = compiled_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        seq_lens_xla,
+        page_indices_xla,
+        pages_per_compute_block=block_size // page_size,
+    )
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    expected_output = torch.from_numpy(
+        np.array(
+            jax_paged_attention(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                seq_lens_jax,
+                page_indices_jax,
+                pages_per_compute_block=block_size // page_size,
+            )))
+
+    self.assertTrue(
+        torch.allclose(
+            output.cpu()[seq_lens > 0],
+            expected_output.cpu()[seq_lens > 0],
+            atol=1e-1,
+            rtol=1e-1))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)


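A note on the updated assertion: both tests now compare only the batch rows whose sequence length is non-zero, presumably because the kernel's output for an empty sequence is not meaningful. A minimal sketch of that masked comparison, using placeholder tensors in place of the real kernel outputs (shapes follow the test's batch_size x num_heads x head_dim layout):

import torch

# Sizes mirror the test: a batch of 6 sequences, 64 query heads, head_dim 256.
seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
output = torch.randn(seq_lens.shape[0], 64, 256)
expected_output = output + 1e-3 * torch.randn_like(output)

# Boolean mask over the batch dimension: drop rows with an empty sequence,
# then compare the remaining rows with the same tolerances as the test.
mask = seq_lens > 0
assert torch.allclose(
    output[mask], expected_output[mask], atol=1e-1, rtol=1e-1)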
4 changes: 2 additions & 2 deletions torch_xla/experimental/custom_kernel.py
@@ -369,7 +369,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   step = torch.zeros((1,), dtype=torch.int32).to("xla")
   output_shape = torch.Size(list(q.shape[:-1]) + [1])

-  output = torch_xla._XLAC._xla_tpu_custom_call(
+  output, _, _ = torch_xla._XLAC._xla_tpu_custom_call(
       [
           lengths,
           page_indices_reshaped,
@@ -381,7 +381,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
       ], payload, [q.shape, output_shape, output_shape],
       [q.dtype, torch.float32, torch.float32])

-  return output
+  return output.reshape(batch_size, num_heads, head_dim)


 XLA_LIB.define(
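With the change above, the paged_attention wrapper in torch_xla.experimental.custom_kernel now unpacks the three buffers returned by the TPU custom call and hands back a single attention tensor already reshaped to (batch_size, num_heads, head_dim), so callers no longer unpack a tuple or reshape the result themselves. A minimal call-site sketch on a TPU device, mirroring the tests above; the k_pages/v_pages and page_indices shapes are assumptions based on the usual paged-attention layout, not taken from this diff:

import torch
from torch_xla.experimental.custom_kernel import paged_attention

batch_size, num_heads, head_dim = 6, 64, 256   # sizes borrowed from the test
num_kv_heads, total_pages, page_size = 8, 128, 64
pages_per_sequence = 2048 // page_size         # assumed max_kv_len / page_size

q = torch.randn(batch_size, num_heads, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, total_pages, (batch_size, pages_per_sequence), dtype=torch.int32).to("xla")

# A single tensor comes back, already shaped (batch_size, num_heads, head_dim).
output = paged_attention(
    q, k_pages, v_pages, seq_lens, page_indices,
    pages_per_compute_block=512 // 64)
assert output.shape == (batch_size, num_heads, head_dim)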
