From 1cd9adb948bf7433acaa8b7f7a40a258e09f3405 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 2 May 2024 23:23:37 +0000
Subject: [PATCH 1/2] Update non_xla attention to properly support
 paged_attention dynamo code path

---
 test/test_pallas.py                     | 17 +++++++++--------
 torch_xla/experimental/custom_kernel.py | 10 +++-------
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 7b8755fc71e..8901a84c80a 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -590,18 +590,19 @@ def test_paged_attention_wrapper_with_dynamo(self):

     def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
                                 pages_per_compute_block):
-      return paged_attention(
-          q_xla,
-          k_pages_xla,
-          v_pages_xla,
-          seq_lens_xla,
-          page_indices_xla,
-          pages_per_compute_block=block_size // page_size,
+      return torch.ops.xla.paged_attention(
+          q,
+          k,
+          v,
+          seq_lens,
+          page_indices,
+          pages_per_compute_block=pages_per_compute_block,
       )

     compiled_paged_attention = torch.compile(
         paged_attention_wrapper, backend="openxla")
-    output = paged_attention_wrapper(
+
+    output = compiled_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 9bd050efc29..6bdea95f3c2 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -482,7 +482,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)


-def non_xla_attetion(q, k, v, attention_type):
+def non_xla_attetion(q, k, v):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if k.device != torch.device("meta"):
@@ -490,11 +490,7 @@ def non_xla_attetion(q, k, v):
         f'XLA {attention_type} attention should only be applied to tensors on XLA device'
     )

-  # perform a regular attention if input tensors are not on XLA device.
-  attn_weight = q @ k.transpose(-2, -1)
-  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
-  attn_output = attn_weight @ v
-  return attn_output
+  return torch.empty_like(q)


 XLA_LIB.define(
@@ -537,4 +533,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
                             v_pages: torch.Tensor, lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int):
-  return non_xla_attetion(q, k, v, "paged")
+  return non_xla_attetion(q, k_pages, v_pages)

From 9d75282a43014439410dafd0f29cf0ddd72abfb3 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 2 May 2024 23:27:40 +0000
Subject: [PATCH 2/2] Run linter

---
 torch_xla/experimental/custom_kernel.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 6bdea95f3c2..484002124d5 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -482,7 +482,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)


-def non_xla_attetion(q, k, v):
+def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if k.device != torch.device("meta"):
@@ -490,6 +490,7 @@ def non_xla_attetion(q, k, v):
         f'XLA {attention_type} attention should only be applied to tensors on XLA device'
     )

+  # Return original shape of q.
   return torch.empty_like(q)


@@ -533,4 +534,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
                             v_pages: torch.Tensor, lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int):
-  return non_xla_attetion(q, k_pages, v_pages)
+  return non_xla_attetion(q, k_pages, v_pages, "paged")
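
The sketch below mirrors the updated test_paged_attention_wrapper_with_dynamo test and shows the Dynamo code path these patches target: the registered torch.ops.xla.paged_attention op is compiled with the openxla backend, and while Dynamo traces with fake (meta) tensors the non-XLA registration only has to report an output shaped like q, which the new torch.empty_like(q) return provides. It is a minimal illustration, assuming an attached XLA device, assuming that importing torch_xla.experimental.custom_kernel registers the op, and using illustrative shapes, page counts, and dtypes that are not taken from the patches.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel  # registers torch.ops.xla.paged_attention

device = xm.xla_device()

# Illustrative sizes only (assumed, not from the patches).
batch_size, num_heads, head_dim = 4, 8, 128
num_kv_heads, total_pages, page_size = 8, 32, 16
pages_per_sequence = 8
block_size = 64

q = torch.randn(batch_size, num_heads, head_dim).to(device)
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to(device)
v_pages = torch.randn_like(k_pages)
seq_lens = torch.randint(
    0, pages_per_sequence * page_size, (batch_size,), dtype=torch.int32).to(device)
page_indices = torch.randint(
    0, total_pages, (batch_size, pages_per_sequence), dtype=torch.int32).to(device)


def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
                            pages_per_compute_block):
  # Under torch.compile, Dynamo first evaluates this op on fake/meta tensors,
  # which dispatches to the non-XLA registration; returning torch.empty_like(q)
  # there gives Dynamo the correct output shape without running attention.
  return torch.ops.xla.paged_attention(
      q,
      k,
      v,
      seq_lens,
      page_indices,
      pages_per_compute_block=pages_per_compute_block,
  )


compiled_paged_attention = torch.compile(
    paged_attention_wrapper, backend="openxla")
output = compiled_paged_attention(q, k_pages, v_pages, seq_lens, page_indices,
                                  block_size // page_size)
print(output.shape)  # (batch_size, num_heads, head_dim), matching q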