diff --git a/test/test_pallas.py b/test/test_pallas.py
index 0cba53ac8b9..38ceeaf520e 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -536,7 +536,7 @@ def test_paged_attention_wrapper(self):
     seq_lens_xla = seq_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
 
-    outputs, _, _ = paged_attention(
+    output = paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -544,15 +544,13 @@ def test_paged_attention_wrapper(self):
         page_indices_xla,
         pages_per_compute_block=block_size // page_size,
     )
-    batch_size, num_heads, head_dim = q_xla.shape
-    outputs = outputs.reshape(batch_size, num_heads, head_dim)
 
     q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
     k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
     v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
     seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
-    expected_outputs = torch.from_numpy(
+    expected_output = torch.from_numpy(
         np.array(
             jax_paged_attention(
                 q_jax,
@@ -565,7 +563,87 @@ def test_paged_attention_wrapper(self):
 
     self.assertTrue(
         torch.allclose(
-            outputs.cpu(), expected_outputs.cpu(), atol=1e+1, rtol=1e+1))
+            output.cpu()[seq_lens > 0],
+            expected_output.cpu()[seq_lens > 0],
+            atol=1e-1,
+            rtol=1e-1))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_paged_attention_wrapper_with_dynamo(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import paged_attention
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
+
+    max_kv_len = 2048
+    block_size = 512
+    page_size = 64
+    num_kv_heads = 8
+    q_kv_head_ratio = 8
+    head_dim = 256
+    dtype = torch.float32
+    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
+
+    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
+        seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+    )
+
+    q_xla = q.to("xla")
+    k_pages_xla = k_pages.to("xla")
+    v_pages_xla = v_pages.to("xla")
+    seq_lens_xla = seq_lens.to("xla")
+    page_indices_xla = page_indices.to("xla")
+
+    # The wrapper forwards its arguments so that dynamo traces the inputs
+    # rather than closed-over tensors.
+    def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
+                                pages_per_compute_block):
+      return paged_attention(
+          q,
+          k,
+          v,
+          seq_lens,
+          page_indices,
+          pages_per_compute_block=pages_per_compute_block,
+      )
+
+    compiled_paged_attention = torch.compile(
+        paged_attention_wrapper, backend="openxla")
+    output = compiled_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        seq_lens_xla,
+        page_indices_xla,
+        pages_per_compute_block=block_size // page_size,
+    )
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    expected_output = torch.from_numpy(
+        np.array(
+            jax_paged_attention(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                seq_lens_jax,
+                page_indices_jax,
+                pages_per_compute_block=block_size // page_size,
+            )))
+
+    self.assertTrue(
+        torch.allclose(
+            output.cpu()[seq_lens > 0],
+            expected_output.cpu()[seq_lens > 0],
+            atol=1e-1,
+            rtol=1e-1))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 3bde331723e..b4bce9b76d2 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -369,7 +369,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   step = torch.zeros((1,), dtype=torch.int32).to("xla")
   output_shape = torch.Size(list(q.shape[:-1]) + [1])
 
-  output = torch_xla._XLAC._xla_tpu_custom_call(
+  output, _, _ = torch_xla._XLAC._xla_tpu_custom_call(
       [
           lengths,
           page_indices_reshaped,
@@ -381,7 +381,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
       ], payload, [q.shape, output_shape, output_shape],
       [q.dtype, torch.float32, torch.float32])
-  return output
+  return output.reshape(batch_size, num_heads, head_dim)


 XLA_LIB.define(