The hidden states in LlamaFlashAttention2 are cast in fp16 unexpectedly #26451

Closed
2 of 4 tasks
hiyouga opened this issue Sep 27, 2023 · 2 comments · Fixed by #26846
@hiyouga (Contributor) commented Sep 27, 2023

System Info

  • transformers version: 4.33.1
  • Platform: Linux-5.4.0-147-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.17.1
  • Safetensors version: 0.3.3
  • Accelerate version: 0.23.0
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: A100 40GB
  • Using distributed or parallel set-up in script?: No

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

As we discussed in this thread: #25598 (comment)

The hidden states may be cast to float16 even if we are using bf16 mixed-precision training.

# In LlamaFlashAttention2.forward, the query/key/value states are cast to fp16 unconditionally:
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)

It may be difficult to figure out the correct data type if the model is loaded in 4/8-bit mode.
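
To make the problem concrete, here is a minimal, self-contained sketch (not the actual modeling_llama.py code): in bf16 mixed-precision training the Llama norm layers return float32 activations, and the hardcoded .to(torch.float16) above then silently switches the attention inputs to fp16, whereas deriving the target dtype from the projection weights keeps them in bfloat16.

import torch
import torch.nn as nn

# Minimal sketch, not the actual Transformers code.
# In bf16 mixed-precision training the Llama norm layers return float32
# activations, so the attention block receives fp32 hidden states.
hidden_states = torch.randn(2, 8, 16, dtype=torch.float32)

# The projection weights carry the intended training dtype (bf16 here).
q_proj = nn.Linear(16, 16, bias=False).to(torch.bfloat16)

# Hardcoded cast (what the snippet above does): ends up in fp16.
query_states = hidden_states.to(torch.float16)
print(query_states.dtype)  # torch.float16 -- not the bf16 we trained with

# Deriving the dtype from the module instead keeps bf16.
query_states = hidden_states.to(q_proj.weight.dtype)
print(query_states.dtype)  # torch.bfloat16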

Expected behavior

The hidden states should be cast to bfloat16 in bf16 training.

@ArthurZucker (Collaborator) commented

Indeed we should not always cast if the dtype is float32

FYI @younesbelkada

@younesbelkada (Contributor) commented

Thanks @hiyouga, this makes sense.

> Indeed we should not always cast if the dtype is float32

Flash Attention only supports fp16 / bf16 input dtypes, so we should always cast to half precision if the input gets silently cast to full precision (e.g. by the layer norm in Llama).

I will work on it and let you know!
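
A rough sketch of that direction, assuming a hypothetical helper _target_half_dtype and a config attribute recording the pre-quantization dtype (names chosen for illustration; the actual patch in #26846 may differ):

import torch

def _target_half_dtype(module, config):
    # Hypothetical helper (illustration only, not the #26846 patch):
    # choose the half-precision dtype Flash Attention should receive.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # For 4/8-bit models the projection weights are quantized, so fall back
    # to a dtype recorded on the config (attribute name assumed here).
    pre_quant = getattr(config, "_pre_quantization_dtype", None)
    if pre_quant is not None:
        return pre_quant
    return module.q_proj.weight.dtype

# Inside the attention forward, cast only when the input was silently
# upcast to fp32 (e.g. by the Llama norm layers):
#   if query_states.dtype == torch.float32:
#       target = _target_half_dtype(self, self.config)
#       query_states = query_states.to(target)
#       key_states = key_states.to(target)
#       value_states = value_states.to(target)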
