diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 65bdcb3b6c..4365a5b2e5 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -588,9 +588,10 @@ def tensor_hook( del new_base_model_instance else: new_model_instance = type(original_model)(new_config) - new_model_instance.generation_config.update( - **original_model.generation_config.to_dict(), - ) + if new_model_instance.generation_config is not None: + new_model_instance.generation_config.update( + **original_model.generation_config.to_dict(), + ) # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 66ec739a65..bf5f2a970b 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -8,13 +8,14 @@ import pathlib import shutil from argparse import Namespace -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, Union, cast from unittest import mock from unittest.mock import ANY, MagicMock, patch import catalogue import pytest import torch +import torch.nn as nn import transformers from composer import ComposerModel, Trainer from composer.loggers import MLFlowLogger @@ -23,7 +24,13 @@ from omegaconf import OmegaConf as om from torch.distributed._tensor.api import DTensor from torch.utils.data import DataLoader -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import ( + AutoConfig, + GenerationConfig, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename @@ -1637,3 +1644,48 @@ def test_license_file_finder( found_path = _maybe_get_license_filename(str(tmp_path)) assert (found_path == license_file_name ) if license_file_name is not None else (found_path is None) + + +@pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}]) +def test_generation_config_variants( + generation_config: Optional[Union[dict[str, Any], GenerationConfig]], +): + + class MockModel(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + # Ensure generation_config is always a GenerationConfig object + if isinstance(config.generation_config, dict): + self.generation_config = GenerationConfig( + **config.generation_config, + ) + else: + self.generation_config = config.generation_config + + config = AutoConfig.from_pretrained('gpt2') + # Convert dict to GenerationConfig if needed + if isinstance(generation_config, dict): + generation_config = GenerationConfig(**generation_config) + config.generation_config = generation_config + + mock_model = MockModel(config) + logger = MagicMock() + state = MagicMock() + state.timestamp.batch = 1 + state.is_model_ddp = False + state.model.model = mock_model + state.model.tokenizer = None + + checkpointer = HuggingFaceCheckpointer( + save_folder='test', + save_interval='1ba', + ) + + checkpointer._save_checkpoint( + state=state, + logger=logger, + upload_to_save_folder=False, + register_to_mlflow=False, + )