Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow FSDP to use with torch.autocast for bfloat16 mixed precision #2033

Merged
merged 3 commits into from
Oct 6, 2023

Conversation

brcps12
Copy link
Contributor

@brcps12 brcps12 commented Oct 5, 2023

What does this PR do?

FSDP supports mixed precision using MixedPrecision class, it does not need to wrap forward function with torch.autocast.

The code statement of ignoring this wrapping was added at accelerate v0.22.0, but now removed at v0.23.0

Related PRs are:

I can't find any information about why it is added or removed.

In fact, mixed precision works well even without torch.autocast, and even if it is needed, it does not work properly in the current version.

So, I think it need to apply one of the following two options:

  1. Add self.distributed_type != DistributedType.FSDP in condition not to use torch.autocast
  2. Add DistributedType.FSDP in this file

The reason for 2 is that when FSDP is used, the distributed_type field is replaced with DistribytedType.FSDP in this line, so I think it needs to be added to support FSDP as well.

As a related issue, the MPT posted on Huggingface Hub uses the LPNorm class, but when learning with FSDP + bfloat16, the dtype changes before and after norm. It is occurred in version v0.23.0.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr muellerzr requested a review from pacman100 October 5, 2023 09:47
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@pacman100
Copy link
Contributor

FSDP supports mixed precision using MixedPrecision class, it does not need to wrap forward function with torch.autocast.

It is required else Mixed Precision FP16 fails with error RuntimeError: expected scalar type Half but found Float. See the failing tests because of which this was required: https://github.com/huggingface/accelerate/actions/runs/6079386828/job/16491799916

@pacman100
Copy link
Contributor

How about going with the option 2 you mentioned? Does that solve the issue with MPT?

@brcps12
Copy link
Contributor Author

brcps12 commented Oct 6, 2023

It is required else Mixed Precision FP16 fails with error RuntimeError: expected scalar type Half but found Float. See the failing tests because of which this was required: https://github.com/huggingface/accelerate/actions/runs/6079386828/job/16491799916

Looking at the problem, it seems that FSDP's MixedPrecision only supports between FSDP modules, not torch's operator (softmax -> matmul). Thank you for sharing!

How about going with the option 2 you mentioned? Does that solve the issue with MPT?

Yes. The issue has fixed. So, I'm gonna working with option 2 and changes PR title.

@brcps12 brcps12 changed the title Ignore torch.autocast for mixed precision when FSDP is used Fix mixed precision of bfloat16 when using FSDP Oct 6, 2023
@brcps12 brcps12 changed the title Fix mixed precision of bfloat16 when using FSDP Add FSDP allowed to wrap with torch.autocast for bfloat16 mixed precision Oct 6, 2023
@brcps12 brcps12 changed the title Add FSDP allowed to wrap with torch.autocast for bfloat16 mixed precision Allow FSDP to use with torch.autocast for bfloat16 mixed precision Oct 6, 2023
Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @brcps12 for fixing the bug wrt bf16 autocasting when using FSDP, LGTM! 🤗

@pacman100 pacman100 merged commit 5ae6111 into huggingface:main Oct 6, 2023
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants