Feature request
Currently I'm using Wav2Vec 2.0 models. Digging into the code, I can see that it computes multi-head attention manually (the implementation is actually copied from Bart). Using `F.multi_head_attention_forward()` instead would bring the benefits of any new improvements PyTorch ships (e.g. Flash Attention) without installing extra libraries to do the hacking (i.e. optimum). The current solution is to use HF optimum to convert the model, which calls a private PyTorch method: https://github.com/huggingface/optimum/blob/05d20df3e6602e26d01cf3994a108de5b097a719/optimum/bettertransformer/models/encoder_models.py#L1415
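For context, the current workaround looks roughly like this (a sketch of optimum's documented `BetterTransformer.transform()` usage; the checkpoint name is just an example):

```python
# Sketch of the current workaround: convert the model with optimum's
# BetterTransformer so attention runs through PyTorch's fused kernels.
from optimum.bettertransformer import BetterTransformer
from transformers import Wav2Vec2Model

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model = BetterTransformer.transform(model)  # swaps in BetterTransformer attention layers
```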
Motivation
To take advantage of Flash Attention, optimum is currently required to convert the model. I was quite surprised that Flash Attention is not the default behavior of HF models. By using `F.multi_head_attention_forward()`, users would get the best attention speedup by default. Advanced users, who would otherwise have to dig into the code to figure out why Flash Attention is not used and then discover optimum, would save debugging time. Beginners would get the best speed without any prior knowledge. It would also remove the need to install an extra library and perform the conversion.

Some considerations:
- In terms of availability, `F.multi_head_attention_forward()` has existed for a long time (it goes back to at least PyTorch 1.8; I haven't checked earlier versions), and I see from the latest main branch that the minimum PyTorch version is 1.9 (transformers/setup.py, line 176 at 4d40109).
- In terms of weight compatibility, `F.multi_head_attention_forward()` supports passing in separate q, k, v projection weights, but the input projection bias must be packed into a single tensor. We can keep the `nn.Linear` modules and pass their parameters directly to `F.multi_head_attention_forward()`; only the q, k, v biases need to be packed together (probably with `torch.cat()`), as sketched below. For more details, check https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/functional.py#L4836
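To make the weight-compatibility point concrete, here is a minimal sketch (not the actual Wav2Vec2 attention code; class and argument names are illustrative) of keeping the existing separate `nn.Linear` projections while delegating the computation to `F.multi_head_attention_forward()`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalMHA(nn.Module):
    """Illustrative self-attention module that keeps separate q/k/v/out
    nn.Linear projections (same parameterization as the current HF code)
    but delegates the computation to F.multi_head_attention_forward()."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states, attn_mask=None, key_padding_mask=None):
        # The functional API expects (seq_len, batch, embed_dim).
        x = hidden_states.transpose(0, 1)
        # Separate projection weights are supported via use_separate_proj_weight,
        # but the bias has to be a single packed tensor.
        in_proj_bias = torch.cat(
            [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
        )
        attn_output, _ = F.multi_head_attention_forward(
            x, x, x,  # self-attention
            self.embed_dim, self.num_heads,
            in_proj_weight=None,  # unused when use_separate_proj_weight=True
            in_proj_bias=in_proj_bias,
            bias_k=None, bias_v=None,
            add_zero_attn=False,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            # need_weights=False so recent PyTorch can take the fused
            # scaled-dot-product attention path.
            need_weights=False,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )
        return attn_output.transpose(0, 1)
```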
Your contribution
This seems like a big change, but I think it should be straightforward. I'm happy to submit a PR if there are people who can help me land this.
Do let me know about any other considerations from the HF side that I'm not aware of. Thank you!