feat(naive_amp.py): not use fp32 output
huangting4201 committed Dec 25, 2023
1 parent ec262ce commit 053a13c
Showing 1 changed file with 1 addition and 1 deletion: internlm/core/naive_amp.py
```diff
@@ -81,7 +81,7 @@ def _convert_to_fp16(self, input_: Any):

     def _convert_to_fp32(self, input_: Any):
         """Converts the input to fp32 if it is a Tensor of dtype float16."""
-        if isinstance(input_, Tensor) and input_.dtype in (torch.float16, torch.bfloat16):
+        if isinstance(input_, Tensor) and input_.dtype in (torch.float16,):
             input_ = input_.float()
         return input_
```
