diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 531d6942b7..1ece1bff75 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -251,15 +251,13 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 new_model_instance = type(original_model)(
                     new_base_model_instance,
                     original_model.peft_config[active_adapter])
+                new_model_instance.to(dtype=self.dtype)
             else:
                 # First create the model instance on meta device to avoid the
                 # initialization cost.
                 with init_empty_weights():
                     new_model_instance = type(original_model)(copied_config)
 
-            new_model_instance.to(dtype=self.dtype)
-            new_model_instance.load_state_dict(state_dict)
-
             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
             new_model_instance.load_state_dict(state_dict, assign=True)
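
For reference, a minimal standalone sketch of the meta-device + `assign=True` loading pattern that the surviving `else` branch relies on. The `gpt2` config and the locally built `state_dict` are illustrative stand-ins, not taken from this diff:

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative config; in the checkpointer this role is played by `copied_config`.
config = AutoConfig.from_pretrained('gpt2')

# Building the module tree on the meta device allocates no real storage and
# skips weight initialization, so construction is effectively free.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Stand-in for the state dict gathered from the training run: materialize a
# second copy of the model just to get concrete tensors to load.
state_dict = AutoModelForCausalLM.from_config(config).state_dict()

# A plain load_state_dict() copies *into* the existing parameters and fails
# on meta tensors ("Cannot copy out of meta tensor"). With assign=True
# (PyTorch >= 2.1), the meta parameters are replaced by the state dict's
# tensors instead of being copied into.
model.load_state_dict(state_dict, assign=True)
assert not any(p.is_meta for p in model.parameters())
```

This sketch also suggests why the `.to(dtype=self.dtype)` cast moves into the PEFT branch: on the meta path the parameters end up with whatever dtype the state dict's tensors carry, so casting the empty shell beforehand accomplishes nothing, while the PEFT branch materializes real weights that still need an explicit cast. The deleted top-level `load_state_dict(state_dict)` (without `assign`) was redundant with the `assign=True` call that follows and would attempt a copy into meta tensors.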