Commit 10b1ea9

yo
snarayan21 committed Jul 29, 2024
1 parent c61a14e commit 10b1ea9
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -218,7 +218,8 @@ def __init__(

         self.mlflow_logging_config = mlflow_logging_config
         self.pretrained_model_name = self.mlflow_logging_config['metadata'].get(
-            'pretrained_model_name', None
+            'pretrained_model_name',
+            None,
         )

         self.huggingface_folder_name_fstr = os.path.join(
@@ -520,8 +521,6 @@ def tensor_hook(
             new_model_instance.generation_config.update(
                 **original_model.generation_config.to_dict(),
             )
-            if self.pretrained_model_name is not None:
-                new_model_instance.name_or_path = self.pretrained_model_name

         # Then load the state dict in with "assign" so that the state dict
         # is loaded properly even though the model is initially on meta device.
@@ -534,6 +533,15 @@
             original_tokenizer,
         )

+        # Ensure that the pretrained model name is correctly set on the saved HF checkpoint.
+        if self.pretrained_model_name is not None:
+            new_model_instance.name_or_path = self.pretrained_model_name
+            if self.using_peft:
+                for k in new_model_instance.peft_config.keys():
+                    new_model_instance.peft_config[
+                        k
+                    ].base_model_name_or_path = self.pretrained_model_name
+
         log.debug('Saving Hugging Face checkpoint to disk')
         # This context manager casts the TE extra state in io.BytesIO format to tensor format
         # Needed for proper hf ckpt saving.
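The substance of the commit is the relocated block in the last hunk: the pretrained model name is now stamped onto the rebuilt model after the state dict and tokenizer handling, and the fix is extended to PEFT checkpoints, whose per-adapter configs record the base model they were trained on. A minimal standalone sketch of the equivalent logic (the function set_pretrained_name and its parameters are illustrative, not llm-foundry's API):

from typing import Optional

from transformers import PreTrainedModel


def set_pretrained_name(
    model: PreTrainedModel,
    pretrained_name: Optional[str],
    using_peft: bool = False,
) -> None:
    """Stamp the upstream pretrained model name onto a rebuilt checkpoint."""
    if pretrained_name is None:
        # No name in the checkpoint metadata; keep whatever the model reports.
        return
    # Plain HF models record their origin in `name_or_path`.
    model.name_or_path = pretrained_name
    if using_peft:
        # PEFT models keep one config per adapter; `base_model_name_or_path`
        # is what gets serialized into each adapter_config.json on save.
        for adapter_name in model.peft_config.keys():
            model.peft_config[
                adapter_name
            ].base_model_name_or_path = pretrained_name

Called with the name of the upstream model, this makes tooling that inspects model.name_or_path (and, for adapters, the serialized base_model_name_or_path) see the original pretrained model rather than a local save path.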
