
[Flash Attention 2] Performance improvement #28160

Open
li-plus opened this issue Dec 20, 2023 · 3 comments

@li-plus
Contributor

li-plus commented Dec 20, 2023

Feature request

The current Flash Attention 2 integration is sub-optimal in performance because it unpads and re-pads the activations in every layer. For example, in the Llama implementation:

batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
    query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
    query_states,
    key_states,
    value_states,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_in_batch_q,
    max_seqlen_k=max_seqlen_in_batch_k,
    dropout_p=dropout,
    softmax_scale=softmax_scale,
    causal=causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

These small unpad/pad kernels keep the GPU waiting for the CPU, as shown by the visible gaps between kernels in the CUDA stream below.

[profiler trace: visible gaps between kernels in the CUDA stream]

I suggest unpadding the activations once at the very beginning (right after the word embeddings) and padding them back at the end (perhaps just before lm_head); the gaps should then disappear. A rough sketch of this flow is given below.
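For illustration only, here is a minimal sketch of the proposed flow. It assumes flash_attn's bert_padding helpers (unpad_input / pad_input, the same helpers used per-layer today; their return signature varies across flash-attn versions), and it uses a hypothetical decoder-layer interface that accepts cu_seqlens / max_seqlen directly. The names model.embed_tokens, model.layers, model.norm, and lm_head stand in for the usual Llama components; this is not the actual transformers API.

from flash_attn.bert_padding import unpad_input, pad_input


def forward_unpadded(model, lm_head, input_ids, attention_mask):
    # Hypothetical top-level forward: unpad once, run all layers packed, pad once.
    batch_size, seq_len = input_ids.shape

    hidden_states = model.embed_tokens(input_ids)  # (batch, seq, hidden)

    # Unpad a single time right after the embeddings:
    # (batch, seq, hidden) -> (total_tokens, hidden), plus the metadata that
    # flash_attn_varlen_func needs. [:4] covers flash-attn versions that
    # return extra values.
    hidden_states, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)[:4]

    # Every decoder layer works on the packed tokens and passes cu_seqlens /
    # max_seqlen straight through to flash_attn_varlen_func, so no per-layer
    # pad/unpad kernels are launched (this layer signature is hypothetical).
    for layer in model.layers:
        hidden_states = layer(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    hidden_states = model.norm(hidden_states)

    # Pad back a single time before the LM head:
    # (total_tokens, hidden) -> (batch, seq, hidden).
    hidden_states = pad_input(hidden_states, indices, batch_size, seq_len)
    return lm_head(hidden_states)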

Motivation

To eliminate performance overhead of flash attention 2.

Your contribution

I can write the code when I'm less busy, but probably not right now.

@amyeroberts
Collaborator

cc @ArthurZucker @younesbelkada

@younesbelkada
Contributor

Hi @li-plus
Thanks a lot for the suggestion!
@fxmarty tried the approach of padding / unpadding at the beginning of the model's forward call here: younesbelkada#5, but the implementation ended up bloating the modeling code, so it was decided not to move forward with that approach. Maybe we can revisit it. cc @ArthurZucker

@ArthurZucker
Collaborator

I think this could be revisited, given that we now have more flexibility with the cache and the attention layer. I don't have bandwidth on my side, but I'm ready to review a PR, so I'll label it as a good difficult issue!
