diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index be312190df..bf14dafc90 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -430,7 +430,7 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if cpu_offload: tensor = tensor.cpu() - state_dict[fqn] = tensor + state_dict[fqn] = tensor.to(dtype=self.dtype) if dist.get_global_rank() != 0: for fqn in dtensor_fqns: del state_dict[fqn] @@ -463,10 +463,10 @@ def dtensor_to_tensor_hook( with state_dict_context: state_dict = state_dict_model.state_dict() - # Convert the state dict to the requested precis - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) + # # Convert the state dict to the requested precis + # for k, v in state_dict.items(): + # if isinstance(v, torch.Tensor): + # state_dict[k] = v.to(dtype=self.dtype) new_model_instance = None # Need this for pyright because variable could be unbound