Fix fp16 by converting outputs back to FP32 #134

Merged

merged 3 commits into main from fix_fp16 on Aug 5, 2021

Conversation

@sgugger (Collaborator) commented on Aug 5, 2021

As was pointed out in #101, there is currently a bug in mixed precision training: the outputs are properly computed in mixed precision but are returned in FP16, and the loss computation is not inside a torch.cuda.amp.autocast context manager, so it is executed entirely in FP16, which is generally unstable (especially for the softmax).
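
For illustration, here is a minimal sketch of the affected pattern, with a hypothetical toy model and data (not code from the PR): the forward pass runs under autocast, but the loss is then computed on FP16 outputs outside any autocast context.

```python
import torch
from torch import nn
from accelerate import Accelerator

# Hypothetical toy setup; fp16=True was the flag at the time of this PR
# (newer Accelerate versions use mixed_precision="fp16") and needs a GPU.
accelerator = Accelerator(fp16=True)
model = nn.Linear(10, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

batch = torch.randn(8, 10, device=accelerator.device)
labels = torch.randint(0, 5, (8,), device=accelerator.device)

outputs = model(batch)  # runs under autocast; before this PR the logits came back in FP16
loss = nn.functional.cross_entropy(outputs, labels)
# ^ computed outside any autocast context, so the log-softmax inside
#   cross_entropy ran entirely in FP16, which can overflow or lose precision
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```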

This was not discovered with Transformers models as the loss is computed inside the model, which is generally a better idea if one wants to use mixed precision with Accelerate.

To fix the problem, this PR:

  • automatically converts the outputs of the model to FP32 so that the loss is executed in full precision (generally a better idea)
  • adds a context manager, accelerator.autocast, for more complex loss functions that should be executed in mixed precision (as with all things Accelerate, this context manager always works; it simply does nothing if FP16 is not activated). See the sketch after this list.
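
Continuing the sketch above, a rough illustration of the new accelerator.autocast context manager for a custom loss that should still run in mixed precision (my_complex_loss_fn is a hypothetical stand-in):

```python
# With the fix, outputs come back in FP32; a custom loss that benefits from
# mixed precision can opt back in explicitly. This is a no-op when FP16 is off.
def my_complex_loss_fn(logits, targets):  # hypothetical custom loss
    return nn.functional.cross_entropy(logits, targets)

with accelerator.autocast():
    loss = my_complex_loss_fn(outputs, labels)
accelerator.backward(loss)
```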

@sgugger requested a review from @LysandreJik on August 5, 2021 at 08:17
@LysandreJik (Member) left a comment

Looks great! Cool workaround on autocast. LGTM!

Review comment on src/accelerate/accelerator.py (outdated, resolved)
@sgugger merged commit c8c9314 into main on Aug 5, 2021
@sgugger deleted the fix_fp16 branch on August 5, 2021 at 16:41