diff --git a/test/test_pallas.py b/test/test_pallas.py
index dce444a4d7c..0cba53ac8b9 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -10,6 +10,8 @@
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
+import numpy as np
+
 if xr.device_type() == 'TPU':
   from torch_xla.experimental.custom_kernel import jax_import_guard
   jax_import_guard()
@@ -52,33 +54,6 @@ def _pagedattention_generate_qkv(
     q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
     return q, k_pages, v_pages, page_indices
 
-  def _pagedattention_reconstruct_kv(self, page_indices, pages):
-    batch_size = page_indices.shape[0]
-    num_heads, _, _, head_dim = pages.shape
-
-    def per_sequence_page_gather(pages, page_indices):
-      page_indices_int64 = page_indices.to(torch.int64)
-      return torch.index_select(pages, dim=1, index=page_indices_int64)
-
-    gathered = torch.vmap(
-        per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices)
-
-    return gathered.reshape(batch_size, num_heads, -1, head_dim)
-
-  def _pagedattention_grouped_query_attention_reference(self, q, k, v, lengths):
-    batch_size, num_heads, head_dim = q.shape
-    _, num_kv_heads, max_seq_len, _ = k.shape
-    assert k.shape == v.shape
-    assert num_heads % num_kv_heads == 0
-    q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim)
-    logits = torch.einsum("bhgd,bhtd->bhgt", q.float(), k.float())
-    mask = torch.arange(max_seq_len)[None, :] < lengths[:, None]
-    mask_value = -0.7 * float(torch.finfo(torch.float32).max)
-    logits = logits.masked_fill(mask[:, None, None, :], mask_value)
-    weights = torch.softmax(logits, dim=-1)
-    o = torch.einsum("bhgt,bhtd->bhgd", weights.to(v.dtype), v)
-    return o.reshape(batch_size, num_heads, head_dim)
-
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tpu_custom_call_pallas_add(self):
     # This payload is generated by the following Pallas code:
@@ -533,19 +508,18 @@ def test_flash_attention_backward(self):
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import paged_attention
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
 
-    batch_size = 4
     max_kv_len = 2048
     block_size = 512
-    page_size = 16
-    num_kv_heads = 1
-    q_kv_head_ratio = 1
-    head_dim = 128
+    page_size = 64
+    num_kv_heads = 8
+    q_kv_head_ratio = 8
+    head_dim = 256
     dtype = torch.float32
-    seq_lens = torch.tensor(
-        [max_kv_len // batch_size * (i + 1) for i in range(batch_size)],
-        dtype=torch.int32)
+    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
 
     q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
         seq_lens,
@@ -554,7 +528,6 @@ def test_paged_attention_wrapper(self):
         num_kv_heads,
         num_kv_heads * q_kv_head_ratio,
         head_dim,
-        dtype=dtype,
     )
 
     q_xla = q.to("xla")
@@ -563,7 +536,7 @@ def test_paged_attention_wrapper(self):
     seq_lens_xla = seq_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
 
-    o = paged_attention(
+    outputs, _, _ = paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -571,14 +544,29 @@ def test_paged_attention_wrapper(self):
         page_indices_xla,
         pages_per_compute_block=block_size // page_size,
     )
-
-    k = self._pagedattention_reconstruct_kv(page_indices, k_pages)
-    v = self._pagedattention_reconstruct_kv(page_indices, v_pages)
-    o_expected = self._pagedattention_grouped_query_attention_reference(
-        q, k, v, seq_lens)
-
-    self.assertEqual(o[0].shape, o_expected.shape)
-    self.assertTrue(torch.allclose(o[0].cpu(), expected_output.cpu()))
+    batch_size, num_heads, head_dim = q_xla.shape
+    outputs = outputs.reshape(batch_size, num_heads, head_dim)
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    expected_outputs = torch.from_numpy(
+        np.array(
+            jax_paged_attention(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                seq_lens_jax,
+                page_indices_jax,
+                pages_per_compute_block=block_size // page_size,
+            )))
+
+    self.assertTrue(
+        torch.allclose(
+            outputs.cpu(), expected_outputs.cpu(), atol=1e+1, rtol=1e+1))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index b94cbb3a4ae..3bde331723e 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -357,18 +357,29 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
       static_argnames=["pages_per_compute_block"],
   )
 
+  batch_size, num_heads, head_dim = q.shape
+  num_kv_heads, _, page_size, head_dim_k = k_pages.shape
+  batch_size_paged_indices, pages_per_sequence = page_indices.shape
+
+  if (num_heads // num_kv_heads) % 8 != 0:
+    q = q.reshape(batch_size, num_heads, 1, head_dim)
+
   page_indices_reshaped = page_indices.reshape(-1)
   buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
   step = torch.zeros((1,), dtype=torch.int32).to("xla")
-  output = torch_xla._XLAC._xla_tpu_custom_call([
-      lengths,
-      page_indices_reshaped,
-      buffer_index,
-      step,
-      q,
-      k_pages,
-      v_pages,
-  ], payload, [q.shape], [q.dtype])
+  output_shape = torch.Size(list(q.shape[:-1]) + [1])
+
+  output = torch_xla._XLAC._xla_tpu_custom_call(
+      [
+          lengths,
+          page_indices_reshaped,
+          buffer_index,
+          step,
+          q,
+          k_pages,
+          v_pages,
+      ], payload, [q.shape, output_shape, output_shape],
+      [q.dtype, torch.float32, torch.float32])
 
   return output
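
For reference, a minimal usage sketch of the wrapper after this change, mirroring the new test above: the shapes, page layout, and random inputs below are illustrative assumptions, not requirements of the kernel.

    # Usage sketch (assumed shapes): exercise the paged_attention wrapper end to end.
    import torch
    import torch_xla  # registers the 'xla' device
    from torch_xla.experimental.custom_kernel import paged_attention

    batch_size, num_kv_heads, q_kv_head_ratio, head_dim = 4, 8, 8, 256
    page_size, pages_per_sequence, block_size = 64, 32, 512
    num_heads = num_kv_heads * q_kv_head_ratio
    max_kv_len = page_size * pages_per_sequence  # 2048

    # Query, paged KV cache, per-sequence lengths, and page table on the XLA device.
    q = torch.randn(batch_size, num_heads, head_dim).to("xla")
    k_pages = torch.randn(num_kv_heads, batch_size * pages_per_sequence, page_size,
                          head_dim).to("xla")
    v_pages = torch.randn_like(k_pages)
    lengths = torch.randint(
        0, max_kv_len + 1, (batch_size,), dtype=torch.int32).to("xla")
    page_indices = torch.randperm(
        batch_size * pages_per_sequence, dtype=torch.int32).reshape(
            batch_size, pages_per_sequence).to("xla")

    # The custom call now returns three buffers: the first is the attention
    # output, the trailing two are float32 auxiliary buffers of shape
    # q.shape[:-1] + [1], so callers unpack and keep only the first.
    output, _, _ = paged_attention(
        q,
        k_pages,
        v_pages,
        lengths,
        page_indices,
        pages_per_compute_block=block_size // page_size,
    )
    output = output.reshape(batch_size, num_heads, head_dim)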