
Commit

Allow FSDP to use with torch.autocast for bfloat16 mixed precision (#2033)

* Ignore native_amp when FSDP is used

* Rollback condition

* Fix mixed precision of bfloat16 for FSDP
brcps12 authored Oct 6, 2023
1 parent 230a5f5 commit 5ae6111
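
For context (not part of the commit), a hedged sketch of how bfloat16 mixed precision and FSDP are typically requested together through the accelerate API. The toy model, optimizer, and default plugin settings below are illustrative assumptions, and the script is assumed to run under a distributed launcher such as accelerate launch.

# Illustrative usage sketch (assumptions: launched with `accelerate launch`,
# default FullyShardedDataParallelPlugin settings are acceptable for the model).
import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin()
accelerator = Accelerator(mixed_precision="bf16", fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(16, 16)  # toy model standing in for a real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)
# Forward passes prepared this way can run under bf16 autocast once FSDP is
# included in the autocast branch, which is what this commit changes.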
Showing 1 changed file with 1 addition and 0 deletions: src/accelerate/utils/modeling.py
@@ -1458,6 +1458,7 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
         DistributedType.MULTI_GPU,
         DistributedType.MULTI_NPU,
         DistributedType.MULTI_XPU,
+        DistributedType.FSDP,
     ]:
         return torch.autocast(device_type=state.device.type, dtype=torch.bfloat16, **autocast_kwargs)
     else:
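
The effect of the added line: once the distributed type is FSDP, the function returns a bfloat16 torch.autocast context instead of falling through to the else branch. Below is a minimal, self-contained sketch of that branching, using a hypothetical get_bf16_context stand-in rather than the library's actual function, state object, or DistributedType enum.

# Minimal sketch of the patched branch (hypothetical stand-in, not the
# library's exact code): FSDP now joins MULTI_GPU/MULTI_NPU/MULTI_XPU in
# getting a bfloat16 autocast context instead of a no-op.
import contextlib
import torch

def get_bf16_context(distributed_type: str, device_type: str = "cuda"):
    # Distributed modes whose computation gets wrapped in bf16 autocast;
    # "FSDP" is the entry this commit adds.
    bf16_autocast_types = {"MULTI_GPU", "MULTI_NPU", "MULTI_XPU", "FSDP"}
    if distributed_type in bf16_autocast_types:
        return torch.autocast(device_type=device_type, dtype=torch.bfloat16)
    return contextlib.nullcontext()

In the real function the decision also involves native_amp and the running state (as visible in the hunk header and the commit message); the sketch only models the membership test shown in the diff.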
