diff --git a/test/test_pallas.py b/test/test_pallas.py index f23bf4cc9c2..1901cb721fb 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -26,25 +26,28 @@ def _attention(self, q, k, v): attn_output = attn_weight @ v return attn_output - # The following helper functions prefixed with _pagedattention are used to help test PagedAttention + # The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests # Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py def _pagedattention_generate_qkv( - seq_lens, - page_size, - max_seq_len, - num_kv_heads, - num_heads, - head_dim, - dtype = torch.float32, + seq_lens, + page_size, + max_seq_len, + num_kv_heads, + num_heads, + head_dim, + dtype=torch.float32, ): assert max_seq_len % page_size == 0 pages_per_sequence = max_seq_len // page_size batch_size = len(seq_lens) total_pages = batch_size * pages_per_sequence k1, k2, k3, k4 = jax.random.split(prng_key, 4) - k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) - v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) - page_indices = torch.randperm(batch_size * pages_per_sequence, dtype=torch.int32) + k_pages = torch.randn( + num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) + v_pages = torch.randn( + num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) + page_indices = torch.randperm( + batch_size * pages_per_sequence, dtype=torch.int32) page_indices = page_indices.reshape(batch_size, pages_per_sequence) q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype) return q, k_pages, v_pages, page_indices @@ -54,11 +57,11 @@ def _pagedattention_reconstruct_kv(page_indices, pages): num_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): - return torch.gather(torch_pages, dim=1, index=torch_page_indices.unsqueeze(1)) + return torch.gather( + torch_pages, dim=1, index=torch_page_indices.unsqueeze(1)) - gathered = torch.vmap(per_sequence_page_gather, in_dims=(None, 0))( - pages, page_indices - ) + gathered = torch.vmap( + per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices) return gathered.reshape(batch_size, num_heads, -1, head_dim) def _pagedattention_grouped_query_attention_reference(q, k, v, lengths): @@ -67,10 +70,10 @@ def _pagedattention_grouped_query_attention_reference(q, k, v, lengths): assert k.shape == v.shape assert num_heads % num_kv_heads == 0 q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) - logits = torch.einsum("bhgd, bhtd -> bhgt", q.float(), k.float()) + logits = torch.einsum("bhgd, bhtd -> bhgt", q.float(), k.float()) mask = torch.arange(max_seq_len)[None, :] < lengths[:, None] mask_value = -0.7 * torch.finfo(torch.float32).max - logits = logits.masked_fill(~mask, mask_value) + logits = logits.masked_fill(~mask, mask_value) weights = torch.softmax(logits, dim=-1) o = torch.einsum("bhgt, bhtd -> bhgd", weights, v.to(weights.dtype)) return o.reshape(batch_size, num_heads, head_dim) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 97edd049fbb..6769b6401f5 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -331,20 +331,22 @@ def flash_attention( return FlashAttention.apply(q, k, v, causal) -def paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block): - # Import JAX within the function such that we don't need to call the jax_import_guard() - # in the global scope which could cause problems for xmp.spawn. - jax_import_guard() - from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention +def paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention - # It returns the shape and type of o, l, m. - def shape_dtype(q, *arg): - return [(q.shape, q.dtype)] + # It returns the shape and type of o, l, m. + def shape_dtype(q, *arg): + return [(q.shape, q.dtype)] - paged_attention_kernel = make_kernel_from_pallas(paged_attention, shape_dtype) - o = paged_attention_kernel(q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block) + paged_attention_kernel = make_kernel_from_pallas(paged_attention, shape_dtype) + o = paged_attention_kernel(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block) - return o + return o XLA_LIB.define( @@ -380,12 +382,18 @@ def flash_attention_non_xla(q: torch.Tensor, @impl(XLA_LIB, "paged_attention", "XLA") -def paged_attention_xla(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block): +def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor, + v_pages: torch.Tensor, lengths: torch.Tensor, + page_indices: torch.Tensor, + pages_per_compute_block: int): return flash_attention(q, k, v, causal=causal) @impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd") -def paged_attention_non_xla(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block): +def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor, + v_pages: torch.Tensor, lengths: torch.Tensor, + page_indices: torch.Tensor, + pages_per_compute_block: int): # This will be called when dynamo use fake tensor to construct the fake output. # We need to make sure output tensor's shape is correct. if k.device != torch.device("meta"):