diff --git a/test/test_pallas.py b/test/test_pallas.py index fd8dcffe6ac..6ae9adec021 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -541,18 +541,38 @@ def test_tpu_custom_call_pallas_wrap_paged_attention(self): from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention from torch_xla.experimental.custom_kernel import make_kernel_from_pallas paged_attention_kernel = make_kernel_from_pallas( - paged_attention, lambda q, k, v: [(q.shape, q.dtype)]) - - q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13 - k_mini = torch.arange( - 1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13 - q = q_mini.broadcast_to(3, 2, 128, 4).to("xla") - k = k_mini.broadcast_to(3, 2, 128, 4).to("xla") - v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla") - - o = paged_attention_kernel(q, k, v) - expected_o = self._attention(q, k, v) - self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu())) + paged_attention, lambda q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block: [(q.shape, q.dtype)]) + + batch_size = 4 + max_kv_len = 2048 + block_size = 512 + seq_lens = torch.tensor( + [max_kv_len // batch_size * (i + 1) for i in range(batch_size)]) + q, k_pages, v_pages, page_indices = _pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + dtype, + ) + o = paged_attention_kernel( + q, + k_pages, + v_pages, + seq_lens, + page_indices, + pages_per_compute_block=block_size // page_size, + ) + k = _pagedattention_reconstruct_kv(page_indices, k_pages) + v = _pagedattention_reconstruct_kv(page_indices, v_pages) + + o_expected = _pagedattention_grouped_query_attention_reference( + q, k, v, seq_lens) + + self.assertTrue(torch.allclose(o.cpu(), o_ref.cpu())) if __name__ == '__main__': diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index e236f48b8bb..7a783acf30a 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -392,7 +392,8 @@ def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, lengths: torch.Tensor, page_indices: torch.Tensor, pages_per_compute_block: int): - return flash_attention(q, k, v, causal=causal) + return paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block) @impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")