Add flash attention for gpt_bigcode
#26479
Conversation
All tests are passing except `test_flash_attn_2_generate_use_cache`. Error traceback:

```
=========================================================== FAILURES ===========================================================
__________________________________ GPTBigCodeModelTest.test_flash_attn_2_generate_use_cache ___________________________________
self = <tests.models.gpt_bigcode.test_modeling_gpt_bigcode.GPTBigCodeModelTest testMethod=test_flash_attn_2_generate_use_cache>
tests/test_modeling_common.py:2936:
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
ctx = <torch.autograd.function.IndexPutFirstAxisBackward object at 0x7fea8a486400>
E   IndexError: tensors used as indices must be long, byte or bool tensors
../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:51: IndexError

self = <tests.models.gpt_bigcode.test_modeling_gpt_bigcode.GPTBigCodeMHAModelTest testMethod=test_flash_attn_2_generate_use_cache>
tests/test_modeling_common.py:2936:
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
ctx = <torch.autograd.function.IndexPutFirstAxisBackward object at 0x7fea872f76d0>
E   IndexError: tensors used as indices must be long, byte or bool tensors
../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:51: IndexError
```

At this point, I think the failing tests might be caused by some unknown issue on my machine rather than by the code, since I see this same failure across the board. If you don't mind @younesbelkada, could you please check out this branch, run the tests on your end, and let me know the results?
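(A hedged aside for readers hitting the same traceback: the snippet below is a minimal, hypothetical reproduction of the dtype requirement PyTorch enforces for advanced indexing, not the PR's actual code; the fix direction is simply casting the index tensor to `torch.long`.)

```python
import torch

# Hypothetical illustration: PyTorch advanced indexing only accepts
# long/byte/bool index tensors, so an index tensor built with the wrong
# dtype triggers the same IndexError surfaced by flash_attn's padding helpers.
hidden_states = torch.randn(8, 16)
bad_indices = torch.tensor([0.0, 2.0, 5.0])   # float indices are never valid
good_indices = bad_indices.to(torch.long)     # fix: cast to int64

try:
    _ = hidden_states[bad_indices]
except IndexError as err:
    print(f"reproduced: {err}")

print(hidden_states[good_indices].shape)      # torch.Size([3, 16])
```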
Again, thanks a lot for your amazing contribution @susnato!
Thanks a lot @younesbelkada!
LGTM! I can confirm the FA tests pass, and I have also tried a generation locally and it seems to work great! Thanks for your huge work! 🙏
docs/source/en/perf_infer_gpu_one.md (Outdated)

```diff
@@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:
 - Llama
 - Falcon
+- GPTBigCode
```
Suggested change:

```diff
-- GPTBigCode
+- GPTBigCode (Starcoder)
```
done
```python
input_dtype = query.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )
    query = query.to(torch.float16)
    key = key.to(torch.float16)
    value = value.to(torch.float16)
```
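(Editorial aside: a minimal sketch of the situation this guard is defending against, under the assumption that some layers were kept in float32 for numerical stability; the flash-attention kernels only accept fp16/bf16 inputs.)

```python
import torch

# An upcasted LayerNorm produces float32 activations even if the rest of
# the model runs in float16...
hidden = torch.randn(2, 4, 8, dtype=torch.float16)
layer_norm = torch.nn.LayerNorm(8).float()
normed = layer_norm(hidden.float())
print(normed.dtype)  # torch.float32

# ...so the attention layer casts q/k/v back to half precision before
# calling the flash attention kernel, as in the block quoted above.
query = normed.to(torch.float16)
print(query.dtype)   # torch.float16
```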
This can be fixed in the global fix I want to apply in #26451, as a follow-up PR that I will take care of.
Should I then remove this block, or are we keeping it for now?
I would say we can keep it for now
ok
Let's wait until your PR is merged @younesbelkada 😉
Hey! Great work here. Going forward it would be great to have an estimate of what kind of speedup we are expecting (and maybe add it to the README here), or at least a benchmark confirming that we indeed have a speedup!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
I would love to do it, but I don't really have access to a high-quality GPU to show the full performance.
@susnato I will run benchmarks for you :)
Thanks @younesbelkada!
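(For anyone who wants to reproduce such a benchmark themselves, here is a rough sketch; it assumes a CUDA GPU with `flash-attn` installed and uses the `use_flash_attention_2` flag documented for this transformers release. Treat it as an illustration, not the maintainers' actual benchmark harness.)

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoderbase-1b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("def quicksort(arr):", return_tensors="pt").to("cuda")


def time_generation(use_fa2: bool) -> float:
    # Load the model in fp16 with or without the Flash Attention 2 code path.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, use_flash_attention_2=use_fa2
    ).to("cuda")
    model.generate(**inputs, max_new_tokens=8)  # warmup
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128)
    torch.cuda.synchronize()
    return time.perf_counter() - start


print(f"eager   : {time_generation(False):.2f} s")
print(f"flash-2 : {time_generation(True):.2f} s")
```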
Thanks! If you can update the README following this, that would be great!
docs/source/en/perf_infer_gpu_one.md (Outdated)

```diff
@@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:
 - Llama
 - Falcon
+- GPTBigCode (Starcoder)
```
Same comment as for Llama: let's add a section in the StarCoder readme about the expected gains!
Thanks for updating the StarCoder readme!
Hi @susnato, the changes look great to me. I will let @ArthurZucker give a final pass and merge!
```python
input_dtype = query.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )
    query = query.to(torch.float16)
    key = key.to(torch.float16)
    value = value.to(torch.float16)
```
Let's wait until your PR is merged @younesbelkada 😉
Co-authored-by: Arthur <[email protected]>
Sorry to block again!
```python
padding_mask: Optional[torch.LongTensor] = None,
encoder_padding_mask: Optional[torch.LongTensor] = None,
```
I think we should wait again for the padding mask refactor and not pass padding mask! #26792
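(Background for readers following along: what the padding-mask plumbing ultimately feeds is a variable-length attention call. The sketch below is a generic illustration of the unpad-and-cumulative-lengths bookkeeping, not the refactored transformers code from #26792.)

```python
import torch

# Two sequences of real lengths 3 and 5, right-padded to length 5.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
hidden_states = torch.randn(2, 5, 8)

# Keep only the non-padded positions, flattening batch and sequence dims.
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
unpadded = hidden_states.reshape(-1, 8)[indices]
print(unpadded.shape)  # torch.Size([8, 8]) -> 3 + 5 real tokens

# Cumulative sequence lengths, the format variable-length FA kernels expect.
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32)
```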
@younesbelkada @ArthurZucker I was trying to run the blog tutorial code here: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/personal_copilot/training/train.py. I just pulled in the latest main and I'm getting errors related to the Flash Attention 2.0 error message above in this issue. Why was this working back in August but broken now? Is there a pinned version of transformers I should be using?
```python
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
```
There should be no more `padding_mask` now, since #26792 has been merged. Let me know if you want to handle this, otherwise I can have a look!
Hey, I am currently working on fixing this.
Hello @younesbelkada, I have pushed the commit to remove `padding_mask`. Please let me know if any more changes are needed.
Thanks very much for your great work! LGTM with only 1 nit!
```python
input_dtype = query.dtype
if input_dtype == torch.float32:
    # Handle the case where the model is quantized
    target_dtype = getattr(self.config, "_pre_quantization_dtype", self.c_attn.weight.dtype)
```
Can you replace that with a `hasattr` logic? The reasoning behind it is presented here: #26916
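(A hedged sketch of the requested change, using stand-in objects rather than the real attention module, since the attribute names on that module are assumed from the quoted snippet above.)

```python
from types import SimpleNamespace

import torch

# Stand-ins for the attention module's config and projection weight.
config = SimpleNamespace()                      # not quantized: no _pre_quantization_dtype
weight = torch.empty(4, 4, dtype=torch.float16)

# hasattr makes the quantized vs. non-quantized branches explicit instead of
# hiding the fallback inside getattr's default argument.
if hasattr(config, "_pre_quantization_dtype"):
    target_dtype = config._pre_quantization_dtype
else:
    target_dtype = weight.dtype

print(target_dtype)  # torch.float16
```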
Hi @younesbelkada, done!
@susnato @younesbelkada, sorry to keep bugging you guys here in this PR, but can you please tell me how folks were able to run with Flash Attention back in August using this tutorial from @pacman100? When I run with the flash attention flag, I keep getting:
Hi @cmosguy, could you please try checking out this branch and running your script from it? Right now, StarCoder does not support Flash Attention in the released version of transformers.
I just checked out your branch and ran:
Hi @cmosguy, could you please run `pip freeze | grep flash`, or check the flash attention version on your system?
Hey @susnato - I had no idea that I had to install that additional package; shouldn't that have been installed along with the rest? Anyways, I repeated the commands you mentioned:

```
pip freeze | grep flash
flash-attn==2.3.3
```

So it looks like it is loading. I am assuming at this point that flash attention is being picked up, but then it does this:
Hey @cmosguy, I will check out the script on Wednesday and get back to you if I find a solution :)
Thanks for all of the work adding this!
Just a few small nits
```diff
@@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models:
 - Llama
 - Mistral
 - Falcon
+- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode#)
```
Why do we link to the model page here but not for e.g. falcon or llama?
Hey, if I remember correctly, it was asked for by @ArthurZucker, here.
OK - it doesn't really matter, so I'm happy to leave it and find out what the reasoning is when Arthur's back :)
Ok then leaving it as it is.
Hi @amyeroberts, I have pushed the changes.
Hello @cmosguy, the dropout error is fixed now on the main branch. Could you please install again from the main branch and re-run your script? It should work now.
Hey @susnato, thanks for your efforts here. Yes, I was able to see the training start with the lines:

Thanks for your help here. If you don't mind me asking, what happened back in August? The tutorial script I mentioned had things working before, but you made a lot of edits that indicate this is only now being added. Was it there before and then removed? I am trying to understand, because I may be interested in swapping in other models and I cannot fully work out when flash attention can or cannot be used with which model.
Hello @cmosguy, I am glad that you were able to start training :).
As far as I know, Flash Attention support for this model was only just added in this PR; back then, the tutorial brought in Flash Attention through its own code rather than through transformers. Also please note that the flash attention code from the tutorial didn't handle everything this integration does. I hope this explanation helps, otherwise feel free to tag me if you need any help. :)
For every model, the `_supports_flash_attn_2` attribute tells you whether it supports Flash Attention 2:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bigcode/starcoderbase-1b")
model._supports_flash_attn_2  # True, since it has support for Flash Attention

model = AutoModel.from_pretrained("gpt2")
model._supports_flash_attn_2  # False, since it does not have support for Flash Attention (for now)
```
@susnato wow, I am learning so much hanging out here. Thank you for walking me through what happened; this really explains a lot. I appreciate you taking the time to investigate the issue and write a coherent explanation. I totally did not see the commit from before, thank you for bringing this to my attention (pun intended). OK, so I am off to the races here with the training in the meantime, cheers!
Hey @cmosguy, it was fun for me too! Cheers!
* added flash attention of gpt_bigcode
* changed docs
* Update src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
* add FA-2 docs
* oops
* Update docs/source/en/perf_infer_gpu_one.md

  Last Nit

  Co-authored-by: Arthur <[email protected]>
* fix
* oops
* remove padding_mask
* change getattr->hasattr logic
* changed .md file

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Arthur <[email protected]>
What does this PR do?
Adds Flash Attention 2 for GPTBigCode (Starcoder), as discussed in this issue: #26350.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
cc: @younesbelkada