Refactor hf checkpointer (#1318)
irenedea authored Jun 29, 2024
1 parent 0ebd7c9 commit 88511f7
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -30,7 +30,11 @@
 from composer.utils.misc import create_interval_scheduler
 from mlflow.transformers import _fetch_model_card, _write_license_information
 from packaging import version
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import (
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.models.utils import init_empty_weights
@@ -338,6 +342,26 @@ def transform_model_and_tokenizer(
         """
         return model, tokenizer
 
+    def transform_config(
+        self,
+        original_config: PretrainedConfig,
+    ) -> PretrainedConfig:
+        """Transform the model config before saving.
+
+        Args:
+            original_config (PretrainedConfig): The original model config.
+
+        Returns:
+            The transformed model config.
+        """
+        copied_config = copy.deepcopy(original_config)
+        if copied_config.model_type == 'mpt':
+            copied_config.attn_config['attn_impl'] = 'torch'
+            copied_config.init_device = 'cpu'
+            if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
+                copied_config.ffn_config['moe_world_size'] = 1
+        return copied_config
+
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
 
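With the config edits factored into transform_config, a subclass of HuggingFaceCheckpointer can override that one method to further adjust the config that gets saved. A minimal sketch of such an override; the subclass name and the vocab-padding step are hypothetical and not part of this change:

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from transformers import PretrainedConfig


class PaddedVocabCheckpointer(HuggingFaceCheckpointer):
    """Hypothetical subclass showing a transform_config override."""

    def transform_config(
        self,
        original_config: PretrainedConfig,
    ) -> PretrainedConfig:
        # Reuse the base transformation (attn_impl, init_device, moe_world_size).
        config = super().transform_config(original_config)
        # Hypothetical extra step: pad the vocab size up to a multiple of 64.
        config.vocab_size = ((config.vocab_size + 63) // 64) * 64
        return config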
@@ -449,13 +473,10 @@ def dtensor_to_tensor_hook(
         if dist.get_global_rank() == 0:
             log.debug('Saving Hugging Face checkpoint in global rank 0')
 
-            # Edit HF config before building 2nd model copy
-            copied_config = copy.deepcopy(original_model.config)
-            if copied_config.model_type == 'mpt':
-                copied_config.attn_config['attn_impl'] = 'torch'
-                copied_config.init_device = 'cpu'
-                if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
-                    copied_config.ffn_config['moe_world_size'] = 1
+            # Transform HF config before building 2nd model copy
+            new_config = self.transform_config(
+                original_config=original_model.config,
+            )
 
             log.debug(f'Creating new model instance')
 
@@ -464,7 +485,7 @@
                 # model, only the adapter weights.
                 active_adapter = original_model.active_adapter
                 base_model = original_model.get_base_model()
-                new_base_model_instance = type(base_model)(copied_config)
+                new_base_model_instance = type(base_model)(new_config)
 
                 new_model_instance = type(original_model)(
                     new_base_model_instance,
@@ -475,7 +496,7 @@
                 # First create the model instance on meta device to avoid the
                 # initialization cost.
                 with init_empty_weights():
-                    new_model_instance = type(original_model)(copied_config)
+                    new_model_instance = type(original_model)(new_config)
                 new_model_instance.generation_config.update(
                     **original_model.generation_config.to_dict(),
                 )
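In the non-PEFT branch above, the second model copy is built under init_empty_weights so its parameters live on the meta device, which skips allocation and weight initialization; real weights are loaded into it later in the checkpointing flow. A rough sketch of that general pattern, using Accelerate's init_empty_weights (llmfoundry imports its own variant from llmfoundry.models.utils) and an illustrative gpt2 config:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative config; the callback uses the transformed training config instead.
config = AutoConfig.from_pretrained('gpt2')

with init_empty_weights():
    # Parameters are created on the meta device: no memory is allocated and
    # no weight initialization runs here.
    model = AutoModelForCausalLM.from_config(config)

# Real weights are attached later, e.g. from a gathered state dict:
# model.load_state_dict(state_dict, assign=True)  # assumes PyTorch >= 2.1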
