Feature request
The current Flash Attention 2 integration is suboptimal in performance because it unpads and re-pads the activations in every layer. For example, in the Llama implementation:
transformers/src/transformers/models/llama/modeling_llama.py
Lines 591 to 612 in 769a954
These small unpad/pad kernels keep the GPU waiting for the CPU, which shows up as visible gaps between kernels in the CUDA stream.
I suggest unpadding the activations once at the very beginning (right after the word embeddings) and padding them back at the end (e.g. before lm_head); the gaps should then disappear.
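For illustration, here is a minimal PyTorch sketch of the proposed "unpad once, pad once" flow. The helper names (`unpad`, `pad`) and the decoder layer signature taking `cu_seqlens`/`max_seqlen` are hypothetical, not the transformers API; flash-attn ships similar utilities in `flash_attn.bert_padding` and varlen attention kernels that consume these arguments. This is only the shape of the idea, not a drop-in implementation:

```python
import torch


def unpad(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten padded (batch, seq, hidden) activations into (total_tokens, hidden).

    Also returns the metadata the varlen flash-attn kernels need:
    flat token indices, cumulative sequence lengths, and the max sequence length.
    """
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                   # (batch,)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )                                                                         # (batch + 1,)
    max_seqlen = int(seqlens.max())
    packed = hidden_states.reshape(-1, hidden_states.shape[-1])[indices]      # (total_tokens, hidden)
    return packed, indices, cu_seqlens, max_seqlen


def pad(packed: torch.Tensor, indices: torch.Tensor, batch: int, seq: int):
    """Scatter packed tokens back into a padded (batch, seq, hidden) tensor."""
    out = packed.new_zeros(batch * seq, packed.shape[-1])
    out[indices] = packed
    return out.view(batch, seq, -1)


def forward(model, input_ids, attention_mask):
    """Hypothetical forward pass: unpad once after the embeddings, pad once before lm_head."""
    bsz, seqlen = input_ids.shape
    hidden = model.embed_tokens(input_ids)                                    # (batch, seq, hidden)
    hidden, indices, cu_seqlens, max_seqlen = unpad(hidden, attention_mask)
    for layer in model.layers:
        # Each layer's attention would call a varlen flash-attn kernel with
        # cu_seqlens / max_seqlen instead of unpadding and re-padding itself.
        hidden = layer(hidden, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    hidden = model.norm(hidden)
    return model.lm_head(pad(hidden, indices, bsz, seqlen))                   # (batch, seq, vocab)
```

With this layout the per-layer CPU-side index bookkeeping disappears; the unpad/pad cost is paid exactly once per forward pass regardless of the number of layers.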
Motivation
To eliminate the performance overhead of the Flash Attention 2 integration.
Your contribution
I can write the code when I have some free time, though probably not right away.
Hi @li-plus,
Thanks a lot for the suggestion! @fxmarty tried the approach of padding/unpadding at the beginning of the model forward call here: younesbelkada#5, but the implementation ended up bloating the modeling code, so it was decided not to move forward with that approach. Maybe we can revisit it. cc @ArthurZucker
I think this could be revisited given that we now have more flexibility with the cache and the attention layers. I don't have bandwidth on my side, but I'm ready to review a PR, so I will label this as a good difficult issue!