-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix FA2 tests #29909
Fix FA2 tests #29909
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AH. That's a great catch. Thanks for it!
model = model_class.from_pretrained( | ||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" | ||
) | ||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's update the name to test_flash_attn_2_inference_equivalence
or something like that!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do!
On a side note, how to make sure that every model using FA2 still passes ? The tests are slow, so I'm not actually sure the CI is totally green ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll need to run the tests manually. You can select just the flash attention tests by doing something like:
RUN_SLOW=1 pytest tests/models -k "flash_attn"
on a GPU setup
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good spot - thanks for fixing!
model = model_class.from_pretrained( | ||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" | ||
) | ||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll need to run the tests manually. You can select just the flash attention tests by doing something like:
RUN_SLOW=1 pytest tests/models -k "flash_attn"
on a GPU setup
I've ran I'll open an issue to keep trace of the different failures. Should I still merge the PR in the meantime?
|
@ylacombe Thanks for running and sharing the results! Merging depends on whether the same tests are failing on main, if they are, then merging is fine; if not, the tests will need to be fixed :) |
Testing this right now then ! |
Well, the same tests fail except qwen2 and stablelm that are introduced by this PR, but this makes sense since the FA2 tests were'nt actually testing FA2 |
Feel free to mege! |
😨😨😨😨😨 |
Thanks a lot ❤️ for the fix and great catch! One nit: It would be really nice 🙏 if you can mention, in the PR description, a bit why the previous testing is done improperly. Something as simple as
This way, it's super clear what the PR is doing even before diving into the changes. |
afaik many FA2 tests were already failing (they are not in the CI) due to diffs in logits |
@fxmarty I think we or you (?) have run those tests before merging. Do you know why we have many failing FA2 tests? Or those many failing tests are only for newly added (many) models ..? |
Oh, they are not run on T4 GPUs. |
@ydshieh When I used to run these tests locally (some months ago), it was because the diff tolerance was too low between eager/fa2. Some models (as whisper) somehow require a large diff tolerance |
* fix FA2 tests * refactor inference test name
* fix FA2 tests * refactor inference test name
What does this PR do?
#26572 introduced an artifact that avoid properly testing inference with Flash Attention 2, the model supposed to be loaded without Flash Attention 2 (as a reference to compare) was in fact using Flash Attention 2!
cc @fxmarty @ArthurZucker @amyeroberts