diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index d80060d6f6..c2c8cabf4d 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -30,7 +30,11 @@
 from composer.utils.misc import create_interval_scheduler
 from mlflow.transformers import _fetch_model_card, _write_license_information
 from packaging import version
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import (
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.models.utils import init_empty_weights
@@ -338,6 +342,26 @@ def transform_model_and_tokenizer(
         """
         return model, tokenizer
 
+    def transform_config(
+        self,
+        original_config: PretrainedConfig,
+    ) -> PretrainedConfig:
+        """Transform the model config before saving.
+
+        Args:
+            original_config (PretrainedConfig): The original model config.
+
+        Returns:
+            The transformed model config.
+        """
+        copied_config = copy.deepcopy(original_config)
+        if copied_config.model_type == 'mpt':
+            copied_config.attn_config['attn_impl'] = 'torch'
+            copied_config.init_device = 'cpu'
+            if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
+                copied_config.ffn_config['moe_world_size'] = 1
+        return copied_config
+
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
 
@@ -449,13 +473,10 @@ def dtensor_to_tensor_hook(
         if dist.get_global_rank() == 0:
             log.debug('Saving Hugging Face checkpoint in global rank 0')
 
-            # Edit HF config before building 2nd model copy
-            copied_config = copy.deepcopy(original_model.config)
-            if copied_config.model_type == 'mpt':
-                copied_config.attn_config['attn_impl'] = 'torch'
-                copied_config.init_device = 'cpu'
-                if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
-                    copied_config.ffn_config['moe_world_size'] = 1
+            # Transform HF config before building 2nd model copy
+            new_config = self.transform_config(
+                original_config=original_model.config,
+            )
 
             log.debug(f'Creating new model instance')
 
@@ -464,7 +485,7 @@
                 # model, only the adapter weights.
                 active_adapter = original_model.active_adapter
                 base_model = original_model.get_base_model()
-                new_base_model_instance = type(base_model)(copied_config)
+                new_base_model_instance = type(base_model)(new_config)
 
                 new_model_instance = type(original_model)(
                     new_base_model_instance,
@@ -475,7 +496,7 @@
                 # First create the model instance on meta device to avoid the
                 # initialization cost.
                 with init_empty_weights():
-                    new_model_instance = type(original_model)(copied_config)
+                    new_model_instance = type(original_model)(new_config)
                     new_model_instance.generation_config.update(
                         **original_model.generation_config.to_dict(),
                     )
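
With this change, the config edits previously hardcoded in `_save_checkpoint` become an overridable hook, so downstream callbacks can customize the saved Hugging Face config without copying the whole save path. Below is a minimal sketch of such an override; the subclass name and the `use_cache` tweak are illustrative assumptions, not part of this diff:

```python
from transformers import PretrainedConfig

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class NoCacheCheckpointer(HuggingFaceCheckpointer):
    """Hypothetical checkpointer that customizes the exported HF config."""

    def transform_config(
        self,
        original_config: PretrainedConfig,
    ) -> PretrainedConfig:
        # Reuse the base transform so the MPT fixups (attn_impl='torch',
        # init_device='cpu', moe_world_size=1) still apply.
        new_config = super().transform_config(original_config=original_config)
        # Illustrative extra tweak: disable the KV cache in the saved config.
        new_config.use_cache = False
        return new_config
```

Because `_save_checkpoint` now builds the second model copy from `self.transform_config(...)`, any such override flows through both the PEFT and non-PEFT reconstruction branches shown above.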