Does this work for training? #1
Comments
Hi, thank you for trying out our package. We have rigorously tested the back-propagation of Flash attention to ensure it produces correct gradients. Some aspects of large-scale training have not yet been extensively tested, and we invite you to experiment with these. To clarify, when you set `attention_type="flash"`, Triton is still utilized internally even if you set `use_triton=False`. This setting specifically disables the other Triton-written kernels, such as those used for RMS normalization, but not Flash attention. For training, we recommend using FP32. Using FP16 might lead to floating-point overflow issues. BF16 is currently unsupported for Flash attention, primarily because the gradient calculations for the attention bias weights do not accommodate BF16. It's likely that Triton disables BF16 for atomic_add to prevent overflow issues as well.
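For concreteness, here is a minimal sketch of the setup described above. The model class, import path, and data pipeline are placeholders (the package's real names may differ); only `attention_type="flash"`, `use_triton=False`, and the FP32 recommendation come from this thread.

```python
import torch

# Placeholder import: substitute the actual model class exposed by this package.
from package_under_discussion import LongContextT5  # hypothetical name

# attention_type="flash" selects the Triton flash-attention kernel.
# use_triton=False only disables the auxiliary Triton kernels (e.g. RMSNorm);
# flash attention itself still runs through Triton.
model = LongContextT5(attention_type="flash", use_triton=False)

# Recommended: train in plain FP32 (no fp16/bf16 autocast) to avoid
# overflow in fp16 and the unsupported bf16 atomic_add path in the backward.
model = model.float().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in train_loader:  # hypothetical dataloader yielding dicts of tensors
    outputs = model(**{k: v.cuda() for k, v in batch.items()})
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```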
Oh ok. BF16 seemed to work for inference in my testing, though. BF16 is supported in flash attention (see the FA README). I think training with flash in FP32 would be slower than training without flash in BF16, depending on sequence length.
Our current implementation of flash attention is written in Triton, which unfortunately has some limitations around bf16. We used Triton to accelerate development and prove the concept. After some testing, we plan to contribute to the original Flash Attention work by adding the attention-bias calculation to its kernel. Yes, training in bf16 should be faster, but our main goal with this project was to eliminate the quadratic memory requirement of the attention mechanism. We also achieved a significant speedup thanks to the IO-aware algorithm, which is what flash attention is.
Please correct me if I am wrong, but does flash attention support FP32? I thought it only supported FP16 (or BF16 in the PyTorch SDP kernel)?
Just trying it out, and I presume `use_triton=False` for training. Have you tried it with training?
With `attention_type="flash"`, `use_triton=False`, and bf16 I get:
`atomic_add does not support bf16 with triton`
With fp16 it hangs after one forward pass (rough sketch of the setup at the end of this comment). I know training can behave differently with flash.
But I think T5 can only be trained with bf16 anyway.
What mode would you recommend for training on long contexts for a novel task, different from language modeling?
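Rough sketch of what I'm running, in case it helps reproduce. The class name, import, and inputs are hypothetical placeholders; only the flags and dtypes match my actual attempt.

```python
import torch
from package_under_discussion import LongContextT5  # hypothetical placeholder name

model = LongContextT5(attention_type="flash", use_triton=False).cuda()
input_ids = torch.randint(0, 32000, (1, 8192), device="cuda")  # placeholder long-context batch

# bf16: the backward pass fails with "atomic_add does not support bf16 with triton"
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(input_ids).loss
loss.backward()

# fp16: the forward pass completes once, then training hangs
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(input_ids).loss
loss.backward()
```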