Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FA-2] Final fix for FA2 dtype #26846

Merged
merged 6 commits into from
Oct 18, 2023

Conversation

younesbelkada
Copy link
Contributor

@younesbelkada younesbelkada commented Oct 16, 2023

What does this PR do?

Replaces #26560
Fixes #26451

Proposes a simpler fix for dealing with FA-2 + PEFT + quantization fine-tuning where users usually cast all other modules (e.g. LayerNorms) in fp32 for training stability.

With #26761 being introduced, it is now much simpler to retrieve model's original dtype, note also that self.config._pre_quantization_dtype remains the single source of truth as to is not supported for quantized models

cc @ArthurZucker @pacman100

Added also a nice test

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, think we can simplify a bit and remove the warning ?

Comment on lines 417 to 421
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove this now no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think we need to keep it to inform users about that

src/transformers/models/falcon/modeling_falcon.py Outdated Show resolved Hide resolved
@younesbelkada younesbelkada merged commit 5a73316 into huggingface:main Oct 18, 2023
@younesbelkada younesbelkada deleted the fa-2-final-fix branch October 18, 2023 21:13
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* final fix for FA2 dtype

* try

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <[email protected]>

* apply fix everywhere

---------

Co-authored-by: Arthur <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

The hidden states in LlamaFlashAttention2 are cast in fp16 unexpectedly
3 participants