
Use F.multi_head_attention_forward() to take advantage of PyTorch's Flash attention #25704

Closed
gau-nernst opened this issue Aug 24, 2023 · 2 comments

Comments

@gau-nernst
Contributor

Feature request

Currently I'm using Wav2Vec 2.0 models. Digging into the code, I can see that multi-head attention is computed manually (the implementation is actually copied from Bart). Switching to F.multi_head_attention_forward() would bring the benefits of any new improvements PyTorch ships (e.g. Flash attention) without installing extra libraries to do the patching (i.e. optimum). The current solution is to use HF optimum to convert the model, which calls a private PyTorch method.

https://github.com/huggingface/optimum/blob/05d20df3e6602e26d01cf3994a108de5b097a719/optimum/bettertransformer/models/encoder_models.py#L1415
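For concreteness, here is a minimal sketch of the functional API (an illustration only, not the HF or optimum implementation, with toy dimensions): the whole QKV-projection / softmax / output-projection pipeline collapses into a single call.

```python
import torch
import torch.nn.functional as F

# Toy dimensions for illustration only.
seq_len, batch, embed_dim, num_heads = 16, 2, 64, 4

# F.multi_head_attention_forward expects (seq_len, batch, embed_dim) inputs.
x = torch.randn(seq_len, batch, embed_dim)

# Packed in-projection weight/bias (3 * embed_dim rows for q, k, v),
# plus the output projection, the same layout nn.MultiheadAttention uses.
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)
out_proj_weight = torch.randn(embed_dim, embed_dim)
out_proj_bias = torch.zeros(embed_dim)

attn_output, _ = F.multi_head_attention_forward(
    x, x, x,                      # query, key, value (self-attention)
    embed_dim, num_heads,
    in_proj_weight, in_proj_bias,
    None, None, False,            # bias_k, bias_v, add_zero_attn
    0.0,                          # dropout_p
    out_proj_weight, out_proj_bias,
    training=False,
    need_weights=False,           # needed to hit the fused/SDPA fast path in recent PyTorch
)
print(attn_output.shape)  # torch.Size([16, 2, 64])
```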

Motivation

To take advantage of Flash attention, optimum is currently required to convert the model. I was quite surprised that using Flash attention is not the default behavior of HF models. By using F.multi_head_attention_forward(), users would get the best attention speedup by default. Advanced users, who would otherwise have to dig into the code to figure out why Flash attention is not used and then discover optimum, would save debugging time. Beginners would get the best speed without any prior knowledge. It would also save the trouble of installing an extra library and performing the conversion.

Some considerations:

  • In terms of availability, F.multi_head_attention_forward() has existed for a long time (at least since PyTorch 1.8; I haven't checked earlier releases). The latest main branch pins the minimum PyTorch version with "torch>=1.9,!=1.12.0", so this function is definitely available.
  • In terms of weight compatibility, F.multi_head_attention_forward() supports passing separate q, k, v projection weights (via use_separate_proj_weight=True), but the input projection bias must be a single packed tensor. We can keep the existing nn.Linear modules and pass their parameters directly to F.multi_head_attention_forward(); the q, k, v biases just need to be concatenated into one tensor (e.g. with torch.cat()). A minimal sketch follows below.

For more details, check https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/functional.py#L4836
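As a rough illustration of that weight-compatibility point, here is a sketch assuming a generic attention module with separate q_proj/k_proj/v_proj/out_proj nn.Linear layers (module and parameter names are mine, not the actual HF code; attention masks are omitted for brevity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalMHA(nn.Module):
    """Sketch: keep separate nn.Linear projections (so existing checkpoints still load),
    but delegate the computation to F.multi_head_attention_forward()."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, embed_dim); the functional API wants (seq_len, batch, embed_dim).
        x = hidden_states.transpose(0, 1)

        # Biases must be packed into one tensor; weights can stay separate
        # thanks to use_separate_proj_weight=True.
        in_proj_bias = torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))

        attn_output, _ = F.multi_head_attention_forward(
            x, x, x,
            self.embed_dim, self.num_heads,
            None,                       # in_proj_weight: unused when use_separate_proj_weight=True
            in_proj_bias,
            None, None, False,          # bias_k, bias_v, add_zero_attn
            self.dropout,
            self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            need_weights=False,         # needed for the fused/SDPA fast path in recent PyTorch
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )
        return attn_output.transpose(0, 1)


# Usage sketch
mha = FunctionalMHA(embed_dim=64, num_heads=4).eval()
out = mha(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])
```

Whether this actually dispatches to the Flash attention kernel also depends on dtype, device, and mask layout, so this is only a starting point, not a guarantee of the speedup.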

Your contribution

This seems like a big change, but I think it should be straightforward. I'm happy to submit a PR if there are people who can help me land this.

Do let me know about other considerations on the HF side that I'm not aware of. Thank you!

@ArthurZucker
Collaborator

Something is being cooked up in #25598 😉

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gau-nernst closed this as not planned on Sep 30, 2023