From f32836ec6fb3863c6d96e8896fd2a3b94982bc51 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Wed, 10 Apr 2024 21:09:43 +0000 Subject: [PATCH] Update unit tests and fix typos --- test/test_pallas.py | 131 +++++++++++++++++++----- torch_xla/experimental/custom_kernel.py | 39 +++++-- 2 files changed, 138 insertions(+), 32 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 1901cb721fb..4d7432b5068 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -27,8 +27,9 @@ def _attention(self, q, k, v): return attn_output # The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests - # Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py + # Reference: https://github.com/google/jax/blob/main/tests/pallas/paged_attention_kernel_test.py def _pagedattention_generate_qkv( + self, seq_lens, page_size, max_seq_len, @@ -37,45 +38,44 @@ def _pagedattention_generate_qkv( head_dim, dtype=torch.float32, ): - assert max_seq_len % page_size == 0 + # assert max_seq_len % page_size == 0 pages_per_sequence = max_seq_len // page_size batch_size = len(seq_lens) total_pages = batch_size * pages_per_sequence - k1, k2, k3, k4 = jax.random.split(prng_key, 4) k_pages = torch.randn( num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) v_pages = torch.randn( num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) page_indices = torch.randperm( - batch_size * pages_per_sequence, dtype=torch.int32) + batch_size * pages_per_sequence, dtype=torch.int64) page_indices = page_indices.reshape(batch_size, pages_per_sequence) q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype) return q, k_pages, v_pages, page_indices - def _pagedattention_reconstruct_kv(page_indices, pages): + def _pagedattention_reconstruct_kv(self, page_indices, pages): batch_size = page_indices.shape[0] num_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): - return torch.gather( - torch_pages, dim=1, index=torch_page_indices.unsqueeze(1)) + return torch.index_select(pages, dim=1, index=page_indices) gathered = torch.vmap( per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices) + return gathered.reshape(batch_size, num_heads, -1, head_dim) - def _pagedattention_grouped_query_attention_reference(q, k, v, lengths): + def _pagedattention_grouped_query_attention_reference(self, q, k, v, lengths): batch_size, num_heads, head_dim = q.shape _, num_kv_heads, max_seq_len, _ = k.shape assert k.shape == v.shape assert num_heads % num_kv_heads == 0 q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) - logits = torch.einsum("bhgd, bhtd -> bhgt", q.float(), k.float()) + logits = torch.einsum("bhgd,bhtd->bhgt", q.float(), k.float()) mask = torch.arange(max_seq_len)[None, :] < lengths[:, None] - mask_value = -0.7 * torch.finfo(torch.float32).max - logits = logits.masked_fill(~mask, mask_value) + mask_value = -0.7 * float(torch.finfo(torch.float32).max) + logits = logits.masked_fill(mask[:, None, None, :], mask_value) weights = torch.softmax(logits, dim=-1) - o = torch.einsum("bhgt, bhtd -> bhgd", weights, v.to(weights.dtype)) + o = torch.einsum("bhgt,bhtd->bhgd", weights.to(v.dtype), v) return o.reshape(batch_size, num_heads, head_dim) @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") @@ -534,19 +534,104 @@ def test_flash_attention_backward(self): def test_tpu_custom_call_pallas_wrap_paged_attention(self): from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention from torch_xla.experimental.custom_kernel import make_kernel_from_pallas - paged_attention_kernel = make_kernel_from_pallas( - paged_attention, lambda q, k, v: [(q.shape, q.dtype)]) - q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13 - k_mini = torch.arange( - 1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13 - q = q_mini.broadcast_to(3, 2, 128, 4).to("xla") - k = k_mini.broadcast_to(3, 2, 128, 4).to("xla") - v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla") + def shape_dtype(q, *args): + return [(q.shape, q.dtype)] - o = paged_attention_kernel(q, k, v) - expected_o = self._attention(q, k, v) - self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu())) + paged_attention_kernel = make_kernel_from_pallas( + paged_attention, + shape_dtype, + static_argnames=['pages_per_compute_block']) + + batch_size = 4 + max_kv_len = 2048 + block_size = 512 + page_size = 16 + num_kv_heads = 1 + q_kv_head_ratio = 1 + head_dim = 128 + dtype = torch.float32 + seq_lens = torch.tensor( + [max_kv_len // batch_size * (i + 1) for i in range(batch_size)]) + + q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + dtype=dtype, + ) + + q_xla = q.to("xla") + k_pages_xla = k_pages.to("xla") + v_pages_xla = v_pages.to("xla") + seq_lens_xla = seq_lens.to("xla") + page_indices_xla = page_indices.to("xla") + o = paged_attention_kernel( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + k = self._pagedattention_reconstruct_kv(page_indices, k_pages) + v = self._pagedattention_reconstruct_kv(page_indices, v_pages) + + o_expected = self._pagedattention_grouped_query_attention_reference( + q, k, v, seq_lens) + + self.assertEqual(o.shape, o_expected.shape) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_paged_attention_wrapper(self): + from torch_xla.experimental.custom_kernel import paged_attention + + batch_size = 4 + max_kv_len = 2048 + block_size = 512 + page_size = 16 + num_kv_heads = 1 + q_kv_head_ratio = 1 + head_dim = 128 + dtype = torch.float32 + seq_lens = torch.tensor( + [max_kv_len // batch_size * (i + 1) for i in range(batch_size)]) + + q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + dtype=dtype, + ) + + q_xla = q.to("xla") + k_pages_xla = k_pages.to("xla") + v_pages_xla = v_pages.to("xla") + seq_lens_xla = seq_lens.to("xla") + page_indices_xla = page_indices.to("xla") + + o = paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + k = self._pagedattention_reconstruct_kv(page_indices, k_pages) + v = self._pagedattention_reconstruct_kv(page_indices, v_pages) + + o_expected = self._pagedattention_grouped_query_attention_reference( + q, k, v, seq_lens) + + self.assertEqual(o.shape, o_expected.shape) if __name__ == '__main__': diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 6769b6401f5..6ae5c28c111 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -129,7 +129,10 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: return payload, tensor_args -def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable): +def make_kernel_from_pallas(kernel: Callable, + output_shape_dtype_fn: Callable, + static_argnums: List[int] = None, + static_argnames: List[str] = None): # TODO: Maybe we can cache the payload for the same input. def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, @@ -156,7 +159,12 @@ def wrapped_kernel(kernel: Callable, return outputs[0] return tuple(outputs) - return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn) + return functools.partial( + wrapped_kernel, + kernel, + output_shape_dtype_fn, + static_argnums=static_argnums, + static_argnames=static_argnames) class FlashAttention(torch.autograd.Function): @@ -339,19 +347,26 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices, from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention # It returns the shape and type of o, l, m. - def shape_dtype(q, *arg): + def shape_dtype(q, *args): return [(q.shape, q.dtype)] - paged_attention_kernel = make_kernel_from_pallas(paged_attention, shape_dtype) - o = paged_attention_kernel(q, k_pages, v_pages, lengths, page_indices, - pages_per_compute_block) + paged_attention_kernel = make_kernel_from_pallas( + paged_attention, shape_dtype, static_argnames=['pages_per_compute_block']) + + o = paged_attention_kernel( + q, + k_pages, + v_pages, + lengths, + page_indices, + pages_per_compute_block=pages_per_compute_block, + ) return o XLA_LIB.define( "flash_attention(Tensor q, Tensor k, Tensor v, bool casual=False) -> Tensor", - "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor[]", ) @@ -381,12 +396,18 @@ def flash_attention_non_xla(q: torch.Tensor, return attn_output +XLA_LIB.define( + "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor", +) + + @impl(XLA_LIB, "paged_attention", "XLA") def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, lengths: torch.Tensor, page_indices: torch.Tensor, pages_per_compute_block: int): - return flash_attention(q, k, v, causal=causal) + return paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block) @impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd") @@ -398,7 +419,7 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor, # We need to make sure output tensor's shape is correct. if k.device != torch.device("meta"): warnings.warn( - 'XLA flash attention should only be applied to tensors on XLA device') + 'XLA paged attention should only be applied to tensors on XLA device') # perform a regular attention if input tensors are not on XLA device. attn_weight = q @ k.transpose(-2, -1)