Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update non_xla attention to properly support paged_attention dynamo code path #7022

Merged
merged 2 commits into from
May 3, 2024

Conversation

wonjoolee95
Copy link
Collaborator

@wonjoolee95 wonjoolee95 commented May 2, 2024

  • Update non_xla attention to properly support paged_attention dynamo code path
  • Fix the original broken dynamo unit tests with paged_attention

Test plan:

root@1fdc3324aeef:/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_dynamo
.
----------------------------------------------------------------------
Ran 1 test in 1.798s

OK

+ TPU CI

Copy link
Collaborator

@alanwaketan alanwaketan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

attn_output = attn_weight @ v
return attn_output
# Return orignal shape of q.
return torch.empty_like(q)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if this actually initialize anything? I hope not.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea it is worth checking out, in above we have a warning about the q should be on meta device, I think running ops on meta_tensor will not allocate any device memory.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to https://pytorch.org/docs/stable/generated/torch.empty_like.html, it seems like emtpy_like returns an uninitialized tensor. Along with the meta tensor check above, I think we should be good.

@wonjoolee95
Copy link
Collaborator Author

wonjoolee95 commented May 3, 2024

Thanks for the reviews, I'll go ahead and merge this as the CIs (including TPU CI) are all green.

@wonjoolee95 wonjoolee95 merged commit 2bce3f8 into master May 3, 2024
21 of 22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants