From 10b1ea9d555bfb14ad124a1e23647fe8790273aa Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Mon, 29 Jul 2024 11:44:24 -0400
Subject: [PATCH] yo

---
 llmfoundry/callbacks/hf_checkpointer.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index b0286d81d1..303208ed6b 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -218,7 +218,8 @@ def __init__(
         self.mlflow_logging_config = mlflow_logging_config
 
         self.pretrained_model_name = self.mlflow_logging_config['metadata'].get(
-            'pretrained_model_name', None
+            'pretrained_model_name',
+            None,
         )
 
         self.huggingface_folder_name_fstr = os.path.join(
@@ -520,8 +521,6 @@ def tensor_hook(
             new_model_instance.generation_config.update(
                 **original_model.generation_config.to_dict(),
             )
-            if self.pretrained_model_name is not None:
-                new_model_instance.name_or_path = self.pretrained_model_name
 
             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
@@ -534,6 +533,15 @@ def tensor_hook(
                 original_tokenizer,
             )
 
+            # Ensure that the pretrained model name is correctly set on the saved HF checkpoint.
+            if self.pretrained_model_name is not None:
+                new_model_instance.name_or_path = self.pretrained_model_name
+                if self.using_peft:
+                    for k in new_model_instance.peft_config.keys():
+                        new_model_instance.peft_config[
+                            k
+                        ].base_model_name_or_path = self.pretrained_model_name
+
             log.debug('Saving Hugging Face checkpoint to disk')
             # This context manager casts the TE extra state in io.BytesIO format to tensor format
             # Needed for proper hf ckpt saving.
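
For context on the PEFT branch of this change: base_model_name_or_path is the field that PEFT serializes into adapter_config.json when an adapter checkpoint is saved, and it is what adapter-only checkpoints use to resolve their base model at load time. Below is a minimal sketch of that downstream load path, assuming a hypothetical checkpoint directory written by this callback; the directory path is a placeholder and the snippet is illustrative, not part of the patch.

    # Sketch: loading an adapter checkpoint whose adapter_config.json records
    # the original pretrained model name in base_model_name_or_path.
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForCausalLM

    checkpoint_dir = 'path/to/saved/hf/checkpoint'  # hypothetical path

    # Read the adapter config to find the base model this adapter was trained on.
    peft_config = PeftConfig.from_pretrained(checkpoint_dir)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
    )
    # Attach the saved adapter weights to the freshly loaded base model.
    model = PeftModel.from_pretrained(base_model, checkpoint_dir)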