diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index ffd413b0..bc1a697a 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -81,7 +81,7 @@ def _convert_to_fp16(self, input_: Any): def _convert_to_fp32(self, input_: Any): """Converts the input to fp32 if it is a Tensor of dtype float16.""" - if isinstance(input_, Tensor) and input_.dtype in (torch.float16, torch.bfloat16): + if isinstance(input_, Tensor) and input_.dtype in (torch.float16,): input_ = input_.float() return input_