
Does this work for training? #1

Open
jamesharrisivi opened this issue Apr 28, 2024 · 4 comments

Comments

@jamesharrisivi

Just trying it out, and I presume use_triton=False for training. Have you tried it with training?

With attention_type="flash" and use_triton=False in bf16, I get "atomic_add does not support bf16" from Triton.

With fp16 it hangs after one forward pass. I know training can be different for flash attention, but I think T5 can only be trained in bf16 anyway.

What mode would you recommend for training long contexts on a novel task, different from language modeling?
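For reference, the atomic_add error can be reproduced outside the package. Below is a minimal standalone Triton sketch (not code from this repo); it compiles and runs with fp16/fp32 inputs but fails on Triton builds whose atomic_add has no bf16 support:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_add_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program atomically accumulates one block of x into out.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.atomic_add(out_ptr + offsets, x, mask=mask)

n = 1024
# fp16/fp32 inputs work; with torch.bfloat16 the kernel fails to compile
# on Triton versions that lack bf16 atomics.
x = torch.randn(n, device="cuda", dtype=torch.bfloat16)
out = torch.zeros(n, device="cuda", dtype=torch.bfloat16)
grid = (triton.cdiv(n, 256),)
_atomic_add_kernel[grid](x, out, n, BLOCK=256)
```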

@Ingvarstep
Contributor

Hi, thank you for trying out our package. We have rigorously tested the back-propagation of Flash attention to ensure it produces correct gradients. Some aspects of large-scale training have not yet been extensively tested, and we invite you to experiment with these.

To clarify, when you set attention_type="flash", Triton is still utilized internally even if you set use_triton=False. This setting specifically disables other Triton-written kernels, such as those used for RMS normalization, but not Flash attention.

For training, we recommend using FP32. Using FP16 might lead to floating-point overflow issues. BF16 is currently unsupported for Flash attention, primarily because the gradient calculations for the attention bias weights do not accommodate BF16. It's likely that Triton disables BF16 for atomic_add to prevent overflow issues as well.
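A minimal sketch of the recommended setup, assuming a Hugging Face-style config where attention_type and use_triton are plain config attributes; the checkpoint name and the exact way the flags are passed are placeholders, not the package's confirmed API:

```python
import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM

# Placeholder checkpoint name; substitute the model you are training.
checkpoint = "your-t5-checkpoint"

config = AutoConfig.from_pretrained(checkpoint)
config.attention_type = "flash"  # flash attention still runs through Triton internally
config.use_triton = False        # only disables the other Triton kernels (e.g. RMSNorm)

model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    config=config,
    torch_dtype=torch.float32,   # train in FP32: FP16 may overflow, BF16 is unsupported here
)
model.cuda().train()
```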

@jamesharrisivi
Author

jamesharrisivi commented Apr 29, 2024

Oh, OK. BF16 seemed to work for inference in my testing, though.

BF16 is supported in flash attention (see the FA README).

I think training with flash attention in fp32 would be slower than training without it in bf16, depending on sequence lengths.

@Ingvarstep
Contributor

Our current implementation of flash attention is written in Triton, which unfortunately has some limitations around bf16. We used Triton to accelerate development and prove the concept. After further testing, we plan to contribute to the original Flash Attention project by adding the attention bias calculation to its kernel.

Yes, training in bf16 should be faster, but our main intent with this project was to eliminate the quadratic memory requirement of the attention mechanism. We also achieved a significant speedup thanks to the IO-aware algorithm that flash attention provides.
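For readers following along, the "attention bias" mentioned above is the additive term (e.g. T5's relative position bias) applied to the attention scores before the softmax. A plain PyTorch reference of that computation, not the Triton kernel itself:

```python
import math
import torch

def attention_with_bias(q, k, v, bias):
    # q, k, v: (batch, heads, seq_len, head_dim); bias: broadcastable to
    # (batch, heads, seq_len, seq_len). A fused flash-attention kernel computes
    # the same result (and the gradient w.r.t. bias) without materializing the
    # full seq_len x seq_len score matrix.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    # Note: T5 itself omits the 1/sqrt(head_dim) scaling; kept here for generality.
    scores = scores + bias
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
```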

@karioth

karioth commented Jan 10, 2025

> For training, we recommend using FP32. Using FP16 might lead to floating-point overflow issues. BF16 is currently unsupported for Flash attention, primarily because the gradient calculations for the attention bias weights do not accommodate BF16. It's likely that Triton disables BF16 for atomic_add to prevent overflow issues as well.

Please correct me if I'm wrong, but does flash attention support FP32? I thought it only supported FP16 (or BF16 in the PyTorch SDPA kernel)?
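For context on that question, here is a small PyTorch check (PyTorch >= 2.3, unrelated to this repo's Triton kernels) that restricts scaled_dot_product_attention to the flash backend and probes which dtypes it accepts:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 1024, 64, device="cuda")

def probe(dtype):
    x = q.to(dtype)
    try:
        # Allow only the flash-attention backend, no math/mem-efficient fallback.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            F.scaled_dot_product_attention(x, x, x)
        print(f"{dtype}: flash backend ran")
    except RuntimeError as err:
        print(f"{dtype}: flash backend unavailable ({err})")

probe(torch.float16)   # runs on supported GPUs
probe(torch.bfloat16)  # runs on Ampere or newer
probe(torch.float32)   # fails: the flash kernel only accepts fp16/bf16
```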
