From 880507166819e1a8a4bb81a127e18c3586af08b1 Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Mon, 29 Jul 2024 12:45:37 -0400
Subject: [PATCH] Set pretrained model name on PEFT base model and configs; remove debug prints

---
 llmfoundry/callbacks/hf_checkpointer.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index e369eaa55f..79dc73de98 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -218,10 +218,11 @@ def __init__(
         self.mlflow_logging_config = mlflow_logging_config
 
         if 'metadata' in self.mlflow_logging_config:
-            self.pretrained_model_name = self.mlflow_logging_config['metadata'].get(
-                'pretrained_model_name',
-                None,
-            )
+            self.pretrained_model_name = self.mlflow_logging_config[
+                'metadata'].get(
+                    'pretrained_model_name',
+                    None,
+                )
         else:
             self.pretrained_model_name = None
 
@@ -540,14 +541,11 @@ def tensor_hook(
             if self.pretrained_model_name is not None:
                 new_model_instance.name_or_path = self.pretrained_model_name
                 if self.using_peft:
+                    new_model_instance.base_model.name_or_path = self.pretrained_model_name
                     for k in new_model_instance.peft_config.keys():
-                        new_model_instance.base_model.name_or_path = self.pretrained_model_name
                         new_model_instance.peft_config[
                             k
                         ].base_model_name_or_path = self.pretrained_model_name
-            print("PEFT CONFIG IS:")
-            for k,v in new_model_instance.peft_config.items():
-                print("key:", k, "value:", v)
 
             log.debug('Saving Hugging Face checkpoint to disk')
             # This context manager casts the TE extra state in io.BytesIO format to tensor format
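
---

For review context, here is a minimal, runnable sketch of the renaming logic the
patched block ends up performing. It is illustrative only: the SimpleNamespace
objects stand in for the real transformers/peft model classes, and the
'meta-llama/Llama-2-7b-hf' value is a made-up example of a pretrained_model_name
read from mlflow metadata.

    from types import SimpleNamespace

    # Hypothetical stand-in for a PEFT-wrapped model with two adapters.
    new_model_instance = SimpleNamespace(
        name_or_path='local-run',
        base_model=SimpleNamespace(name_or_path='local-run'),
        peft_config={
            'default': SimpleNamespace(base_model_name_or_path='local-run'),
            'extra_adapter': SimpleNamespace(base_model_name_or_path='local-run'),
        },
    )

    pretrained_model_name = 'meta-llama/Llama-2-7b-hf'  # example metadata value

    # Mirrors the patched logic: the wrapper and the single base model are
    # renamed once, while every registered adapter config gets its own
    # base_model_name_or_path update inside the loop.
    new_model_instance.name_or_path = pretrained_model_name
    new_model_instance.base_model.name_or_path = pretrained_model_name
    for k in new_model_instance.peft_config.keys():
        new_model_instance.peft_config[k].base_model_name_or_path = pretrained_model_name

    assert new_model_instance.base_model.name_or_path == pretrained_model_name
    assert all(
        cfg.base_model_name_or_path == pretrained_model_name
        for cfg in new_model_instance.peft_config.values()
    )

Hoisting the base_model assignment out of the loop matters because there is one
base model but potentially several adapter configs: the old placement re-set the
same attribute once per adapter and never set it at all when peft_config was
empty.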