Indeed, we should not always cast if the dtype is float32.
Flash Attention only supports fp16 / bf16 input dtypes, so we should always cast back to half precision when the input gets silently cast to full precision (e.g. by the layer norm in Llama).
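A minimal sketch of that idea (the helper name and the explicit `target_dtype` argument are illustrative, not the actual transformers code): cast back to half precision only when the states were silently upcast to fp32, and take the half-precision dtype from the training setup instead of hardcoding float16.

```python
import torch


def cast_states_for_flash_attention(query_states, key_states, value_states, target_dtype):
    """Illustrative helper: Flash Attention kernels accept only fp16 / bf16.

    If an upstream module kept in fp32 (e.g. the RMSNorm in Llama) silently
    upcast the hidden states, cast them back to the half-precision dtype the
    run actually uses (`target_dtype`) instead of hardcoding torch.float16.
    """
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states
```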
System Info
transformers version: 4.33.1

Who can help?
@younesbelkada
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
As we discussed in this thread: #25598 (comment)
The hidden states may be cast to float16 even when using bf16 mixed precision training.
transformers/src/transformers/models/llama/modeling_llama.py, lines 485 to 487 at commit 78dd120
It may be difficult to figure out the correct data type if the model is loaded in 4/8-bit mode.
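One possible way to pick the target dtype, sketched here as an assumption rather than the actual transformers logic (the `_pre_quantization_dtype` config attribute and the `attn_module` / `config` parameters are illustrative): prefer the autocast dtype under mixed precision, fall back to a dtype recorded before 4/8-bit quantization, and only then trust the projection weight dtype.

```python
import torch


def infer_flash_attention_dtype(attn_module, config):
    # Illustrative heuristic only, not the actual transformers implementation.
    # 1. Under fp16 / bf16 mixed precision, autocast already knows the compute dtype.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # 2. For 4/8-bit models the quantized weight dtype is not a usable compute
    #    dtype, so fall back to a dtype recorded before quantization
    #    (a hypothetical `_pre_quantization_dtype` attribute on the config).
    if getattr(config, "_pre_quantization_dtype", None) is not None:
        return config._pre_quantization_dtype
    # 3. Otherwise use the dtype of one of the attention projection weights.
    return attn_module.q_proj.weight.dtype
```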
Expected behavior
The hidden states should be cast to bfloat16 when training with bf16.
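For example, with bf16 autocast active the cast target should come out as bfloat16 rather than float16 (a small check assuming the heuristic sketched above):

```python
import torch

# Sketch of the expected behavior: with bf16 mixed precision active, the
# dtype chosen for the Flash Attention cast should be bfloat16, not float16.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    assert torch.get_autocast_gpu_dtype() == torch.bfloat16
```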