From a97cf662e39cf34c5c81d7be8777faf0cd33b893 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Mon, 5 Feb 2024 07:46:17 -0800
Subject: [PATCH] Remove extra call to .to and load_state_dict in hf
 checkpointer (#939)

---
 llmfoundry/callbacks/hf_checkpointer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 531d6942b7..1ece1bff75 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -251,15 +251,13 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 new_model_instance = type(original_model)(
                     new_base_model_instance,
                     original_model.peft_config[active_adapter])
+                new_model_instance.to(dtype=self.dtype)
             else:
                 # First create the model instance on meta device to avoid the
                 # initialization cost.
                 with init_empty_weights():
                     new_model_instance = type(original_model)(copied_config)
 
-            new_model_instance.to(dtype=self.dtype)
-            new_model_instance.load_state_dict(state_dict)
-
             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
             new_model_instance.load_state_dict(state_dict, assign=True)
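
The pattern this patch leaves in place -- construct the model on the meta device, then
materialize its weights in one step with load_state_dict(..., assign=True) -- is what makes
the removed .to(dtype=...) and the earlier load_state_dict call redundant for the non-PEFT
branch. Below is a minimal, self-contained sketch of that pattern, not code from
llm-foundry; build_model and the bfloat16 dtype are illustrative placeholders.

import torch
import torch.nn as nn
from accelerate import init_empty_weights

# Hypothetical stand-in for the Hugging Face model class handled by the callback.
def build_model() -> nn.Module:
    return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# State dict from a "trained" model, already cast to the target dtype.
trained = build_model().to(dtype=torch.bfloat16)
state_dict = trained.state_dict()

# Build the new instance on the meta device: no memory is allocated and no random
# initialization runs, which is the cost the checkpointer avoids.
with init_empty_weights():
    new_model = build_model()

# assign=True (recent PyTorch) swaps the meta tensors for the state-dict tensors
# themselves, so no extra .to(dtype=...) or second load_state_dict call is needed.
new_model.load_state_dict(state_dict, assign=True)
assert next(new_model.parameters()).dtype == torch.bfloat16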