
[misc] use out argument for flash attention #10822

Merged
merged 7 commits into vllm-project:main from the use_out branch on Dec 2, 2024

Conversation

youkaichao (Member) commented Dec 2, 2024:

Replaces #9740, since that one is very old.

To make mypy happy, all attention forward signatures need to match, so I added output: Optional[torch.Tensor] = None to all attention backends.
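
For illustration, here is a minimal sketch of what the unified signature looks like with the new keyword argument; the class and argument names below are schematic placeholders, not the exact vLLM definitions:

```python
from typing import Any, Optional

import torch


class AttentionImplSketch:
    """Schematic attention backend; not the exact vLLM class."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: Any,
        output: Optional[torch.Tensor] = None,  # new optional preallocated buffer
    ) -> torch.Tensor:
        # If `output` is provided, the kernel writes into it in place instead of
        # allocating a fresh tensor; if it is None, behavior is unchanged.
        if output is None:
            output = torch.empty_like(query)
        ...  # run the flash-attention kernel, writing into `output`
        return output
```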


github-actions bot commented Dec 2, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

WoosukKwon (Collaborator):

Can we keep the current implementation? Basically, I'd like to keep some ops (e.g., query.view) inside the CUDA graph region, since keeping them outside causes high CPU overhead.

You can test the perf impact of this PR with VLLM_USE_V1=1 python benchmarks/benchmark_latency.py.

youkaichao (Member, Author):

Fixed in 0899730.

WoosukKwon (Collaborator) commented Dec 2, 2024:

Thanks for the fix. However, I think we need a more general solution than just moving the view ops out. Fundamentally, the boundary of the attention forward method does not necessarily align with the boundary of the CUDA graphs. For example, we might want to additionally move the reshape_and_cache_flash op into the CUDA graph region for better performance. I believe we need to preserve the flexibility of the current main branch, where I can freely customize the boundary of CUDA graphs within the attention layer. WDYT?

youkaichao (Member, Author):

I don't think you can move reshape_and_cache_flash into the CUDA graph; it needs the slot mapping as input, which is hidden from torch.compile.

youkaichao (Member, Author):

The boundary of a piecewise CUDA graph needs to be aligned with a PyTorch custom op; it is not about the boundary of the attention layer's forward method.
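
For illustration, a minimal sketch of this mechanism using the PyTorch 2.4+ torch.library.custom_op API; the op name and body below are schematic placeholders, not vLLM's actual registration code:

```python
import torch
from torch.library import custom_op


@custom_op("mylib::unified_attention", mutates_args=("output",))
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
) -> None:
    # The body is opaque to torch.compile: the compiler sees a single custom-op
    # node here, so piecewise CUDA graphs can be split exactly at this call.
    output.copy_(
        torch.nn.functional.scaled_dot_product_attention(query, key, value)
    )


@unified_attention.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
      output: torch.Tensor) -> None:
    # Fake (meta) implementation for tracing; the real kernel never runs here.
    return None
```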

WoosukKwon (Collaborator):

> The boundary of a piecewise CUDA graph needs to be aligned with a PyTorch custom op; it is not about the boundary of the attention layer's forward method.

Yes, I wanted to say that the two boundaries should be decoupled. IIRC, the PR essentially enforces that the attention forward method be the PyTorch custom op.

youkaichao (Member, Author):

The attention metadata object needs to be hidden from torch.compile; this is a basic requirement of the torch.compile integration.

The fewer arguments the attention op takes, the lower the overhead; that is the optimization, and in this PR I already use the minimum set of arguments for the attention op.

If you have any operation that does not depend on attention metadata but is specific to v1 attention, please let me know and I can make a separate custom op for it. If not, I'd like to unify them to reduce the maintenance cost.
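
As an illustration of keeping the op arguments minimal, here is a schematic pattern (not vLLM's actual code; the names are hypothetical) in which only plain tensors cross the custom-op boundary and the attention metadata is read from ambient per-forward state inside the op body:

```python
from typing import Any

import torch
from torch.library import custom_op

# Hypothetical per-forward state, set by the model runner before each forward
# pass; the real mechanism in vLLM may differ.
_forward_context: dict = {}


def set_forward_context(attn_metadata: Any) -> None:
    _forward_context["attn_metadata"] = attn_metadata


@custom_op("mylib::attention_minimal_args", mutates_args=("output",))
def attention_minimal_args(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
) -> None:
    # The metadata is read inside the op body, never from the op signature, so
    # torch.compile never sees it. This sketch fetches it but does not use it;
    # a real kernel would consume it here.
    _ = _forward_context.get("attn_metadata")
    output.copy_(
        torch.nn.functional.scaled_dot_product_attention(query, key, value)
    )
```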

youkaichao (Member, Author):

Actually, the only requirement is that the caller of the custom op should not see anything from the attention metadata.

Even if you have some operations specific to v1 attention, we can add them to the attention layer, e.g. if self.use_v1: do_something.

I want to limit the custom op call to inside the attention layer, so that the code in the attention backends becomes more intuitive. I previously got several complaints that the attention implementation calls a torch.ops.vllm op, which then redirects to another function below; that becomes very confusing for people who don't understand torch.compile.
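
For illustration, a schematic layer-side sketch of that structure, reusing the hypothetical mylib::attention_minimal_args op from the previous sketch; this is not vLLM's actual code. The custom-op call stays inside the attention layer, the caller preallocates the output buffer, and any v1-specific work is gated in plain Python:

```python
import torch
from torch import nn


class AttentionLayerSketch(nn.Module):
    """Schematic attention layer; not the exact vLLM class."""

    def __init__(self, use_v1: bool) -> None:
        super().__init__()
        self.use_v1 = use_v1

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        if self.use_v1:
            pass  # any v1-specific preparation stays here, in plain Python

        # The output buffer is preallocated by the caller and passed via the
        # `output` argument; the custom op writes into it in place. Backend
        # code never calls torch.ops.* itself; only this layer does.
        output = torch.empty_like(query)
        torch.ops.mylib.attention_minimal_args(query, key, value, output)
        return output
```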

WoosukKwon (Collaborator):

OK. Makes sense. Thanks for the explanation.

youkaichao enabled auto-merge (squash) on December 2, 2024 at 08:57.
github-actions bot added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Dec 2, 2024.
youkaichao merged commit a4c4daf into vllm-project:main on Dec 2, 2024.
47 checks passed
afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
youkaichao deleted the use_out branch on December 2, 2024 at 17:07.