Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSDP grad accum fix #34645

Merged
merged 2 commits into from
Nov 15, 2024
Merged

Conversation

winglian
Copy link
Contributor

@winglian winglian commented Nov 7, 2024

What does this PR do?

8c62a92#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deL2473-R2480 seems to have broken FSDP training as we need to wrap the train step in no_sync on every step except the last step in the gradient accumulation batch. Without this, during the optimizer.step(), I get the error:

RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding

this wasn't caught by the fsdp tests since gradient accumulation wasn't part of those tests, so I also added tests for that as well.

see https://muellerzr.github.io/blog/gradient_accumulation.html#what-is-the-right-way-then

Screenshot 2024-11-07 at 12 17 36 PM

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@winglian
Copy link
Contributor Author

winglian commented Nov 7, 2024

@ArthurZucker @muellerzr for review

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this!

@winglian
Copy link
Contributor Author

winglian commented Nov 7, 2024

is it possible to also trigger the slow tests for this PR too to make sure the new test passes?

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense ! Thanks for fixing !

@SunMarc SunMarc requested a review from ydshieh November 15, 2024 16:04
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ydshieh ydshieh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you.

Newly added tests are passing (in this PR) and failed on main. So it does its job.

The change in trainer seems reasonable to me too.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for being late on this, thanks 🤗 will patch

@ArthurZucker ArthurZucker merged commit b0c0ba7 into huggingface:main Nov 15, 2024
26 checks passed
2015aroras pushed a commit to 2015aroras/transformers that referenced this pull request Nov 15, 2024
* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
ArthurZucker pushed a commit that referenced this pull request Nov 18, 2024
* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants