Native support of torch.nn.functional.scaled_dot_product_attention
#26557
Comments
@younesbelkada, is there a reason why …
Hi @SimJeg! `pip install transformers optimum`, then once you load the model call `model = model.to_bettertransformer()`. The goal in the future, as mentioned in the issue, is to add native support of SDPA.
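For reference, a minimal sketch of that workflow (the checkpoint name is just an example; it assumes `transformers` and `optimum` are installed):

```python
from transformers import AutoModelForCausalLM

# Example checkpoint; any BetterTransformer-supported model works similarly.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Swap the attention layers for SDPA-backed BetterTransformer ones (requires optimum).
model = model.to_bettertransformer()

# Revert to the original implementation, e.g. before saving the model.
model = model.reverse_bettertransformer()
```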
Here is a WIP PR: #26572

@SimJeg I think it is mostly about Transformers handling padding with a padding mask, which PyTorch SDPA did not support (until recently) for the optimized paths. Having the code offloaded to optimum at first was probably a way to showcase that SDPA indeed works well and that a native integration is worth it!
@younesbelkada @patrickvonplaten - Hi team, I was looking at the attention implementation in transformers for the various LLMs vs. the attention implementation in diffusers, and am a bit confused by the use (or lack of use) of PyTorch SDPA. Is it correct that transformers is not using PyTorch SDPA because it cannot handle padded inputs? If so, how are we able to use PyTorch SDPA in diffusers without running into the same issues? My understanding is that padding isn't necessary for the self-attention layers of common text-to-image models like Stable Diffusion, but it is likely being used in the cross-attention layers, since text prompts are of differing lengths.
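For context, a minimal sketch (shapes and names are illustrative) of passing a padding mask to SDPA's `attn_mask` argument; note that with an arbitrary mask, PyTorch 2.0/2.1 may fall back to the non-flash kernels, which is part of the difficulty discussed above:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 8, 6, 64
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# Example sequence lengths: the second sequence is padded after 4 tokens.
seq_lens = torch.tensor([6, 4])

# Boolean key-padding mask: True = attend, False = masked out.
key_padding = torch.arange(seq)[None, :] < seq_lens[:, None]  # (batch, seq)
attn_mask = key_padding[:, None, None, :]                      # (batch, 1, 1, seq), broadcast over heads/queries

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```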
Is SDPA inference-only, or could it be used during training as an alternative to something like Flash Attention or xformers for the folks who use ROCm? FA2 on ROCm is still a WIP and CDNA2-only.
@xzuyn SDPA is a wrapper around xformers and Flash Attention kernels, so yes, it can be used for training as well (and is probably even more interesting there). Unfortunately, as far as my knowledge goes, FA is not upstreamed in PyTorch on ROCm systems as of PyTorch 2.1. I believe AMD folks are working towards that though; feel free to open an issue in the PyTorch repo to track the progress.
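As a minimal sketch of using SDPA in training (assumes a CUDA device with Flash Attention support; `torch.backends.cuda.sdp_kernel` is the PyTorch 2.0/2.1-era API for restricting kernel dispatch):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq, head_dim)
q = torch.randn(4, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(4, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(4, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Restrict dispatch to the Flash Attention kernel; raises if it is unavailable on this hardware.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# SDPA is differentiable, so it can be used in training loops.
out.sum().backward()
```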
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

not stale
Feature request
PyTorch has provided torch.nn.functional.scaled_dot_product_attention since its 2.0 version, which supports more memory-efficient attention computation. Official documentation here. Currently three implementations are available in that method, making it possible to dispatch the SDPA call to a Flash Attention kernel, a memory-efficient (xFormers-style) kernel, or a C++ math fallback, depending on the inputs and hardware.
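For context, a minimal sketch (plain PyTorch, shapes are illustrative) of the fused call and the eager attention it replaces:

```python
import math
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Fused call: dispatches to a flash / memory-efficient / math kernel under the hood.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# The eager attention it replaces.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
out_eager = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(out_sdpa, out_eager, atol=1e-4)
```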
In addition, in upcoming versions PyTorch will add support for Flash Attention 2 (pytorch/pytorch#105602), which is already available in the PyTorch nightlies.
SDPA makes model inference faster and more memory-efficient, and supports multiple hardware backends (CPU, CUDA GPUs, AMD GPUs, ...).
Users can already benefit from SDPA through the BetterTransformer API of optimum. As SDPA is already quite stable and performant, we should migrate the BetterTransformer API to the native transformers codebase to support out-of-the-box model acceleration and memory efficiency.

cc @LysandreJik @fxmarty
Motivation
Make LLMs faster out of the box, simply by updating the PyTorch version.
Your contribution
Help implement this in the next versions.