Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update non_xla attention to properly support paged_attention dynamo code path #7022

Merged
merged 2 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,18 +590,19 @@ def test_paged_attention_wrapper_with_dynamo(self):

def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
pages_per_compute_block):
return paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
seq_lens_xla,
page_indices_xla,
pages_per_compute_block=block_size // page_size,
return torch.ops.xla.paged_attention(
q,
k,
v,
seq_lens,
page_indices,
pages_per_compute_block=pages_per_compute_block,
)

compiled_paged_attention = torch.compile(
paged_attention_wrapper, backend="openxla")
output = paged_attention_wrapper(

output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
Expand Down
9 changes: 3 additions & 6 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,8 @@ def non_xla_attetion(q, k, v, attention_type):
f'XLA {attention_type} attention should only be applied to tensors on XLA device'
)

# perform a regular attention if input tensors are not on XLA device.
attn_weight = q @ k.transpose(-2, -1)
attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
# Return orignal shape of q.
return torch.empty_like(q)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if this actually initialize anything? I hope not.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea it is worth checking out, in above we have a warning about the q should be on meta device, I think running ops on meta_tensor will not allocate any device memory.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to https://pytorch.org/docs/stable/generated/torch.empty_like.html, it seems like emtpy_like returns an uninitialized tensor. Along with the meta tensor check above, I think we should be good.



XLA_LIB.define(
Expand Down Expand Up @@ -537,4 +534,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
v_pages: torch.Tensor, lengths: torch.Tensor,
page_indices: torch.Tensor,
pages_per_compute_block: int):
return non_xla_attetion(q, k, v, "paged")
return non_xla_attetion(q, k_pages, v_pages, "paged")
Loading