Hfcheckpointer optional generation config (#1543)
Co-authored-by: v-chen_data <[email protected]>
KuuCi and v-chen_data authored Sep 24, 2024
1 parent f377090 commit d85c83b
Showing 2 changed files with 58 additions and 5 deletions.
7 changes: 4 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -588,9 +588,10 @@ def tensor_hook(
                     del new_base_model_instance
                 else:
                     new_model_instance = type(original_model)(new_config)
-                    new_model_instance.generation_config.update(
-                        **original_model.generation_config.to_dict(),
-                    )
+                    if new_model_instance.generation_config is not None:
+                        new_model_instance.generation_config.update(
+                            **original_model.generation_config.to_dict(),
+                        )
 
             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
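Context for the guard: Transformers leaves generation_config as None on model classes that are not generation-capable, so the previous unconditional .update() call could raise an AttributeError when re-instantiating such a model. Below is a minimal standalone sketch of the guarded copy, assuming the original model does carry a generation config; the helper name is illustrative and not part of the commit.

from transformers import GenerationConfig, PreTrainedModel

def copy_generation_config(
    original_model: PreTrainedModel,
    new_model_instance: PreTrainedModel,
) -> None:
    # Hypothetical helper mirroring the guarded update in the checkpointer.
    # Skip the copy when the new instance has no GenerationConfig at all;
    # calling .update() on None would raise an AttributeError.
    # (Assumes original_model.generation_config is set, as in the caller above.)
    if new_model_instance.generation_config is not None:
        new_model_instance.generation_config.update(
            **original_model.generation_config.to_dict(),
        )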
56 changes: 54 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -8,13 +8,14 @@
 import pathlib
 import shutil
 from argparse import Namespace
-from typing import Any, Callable, Optional, cast
+from typing import Any, Callable, Optional, Union, cast
 from unittest import mock
 from unittest.mock import ANY, MagicMock, patch
 
 import catalogue
 import pytest
 import torch
+import torch.nn as nn
 import transformers
 from composer import ComposerModel, Trainer
 from composer.loggers import MLFlowLogger
@@ -23,7 +24,13 @@
 from omegaconf import OmegaConf as om
 from torch.distributed._tensor.api import DTensor
 from torch.utils.data import DataLoader
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import (
+    AutoConfig,
+    GenerationConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
@@ -1637,3 +1644,48 @@ def test_license_file_finder(
     found_path = _maybe_get_license_filename(str(tmp_path))
     assert (found_path == license_file_name
            ) if license_file_name is not None else (found_path is None)
+
+
+@pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}])
+def test_generation_config_variants(
+    generation_config: Optional[Union[dict[str, Any], GenerationConfig]],
+):
+
+    class MockModel(nn.Module):
+
+        def __init__(self, config: PretrainedConfig):
+            super().__init__()
+            self.config = config
+            # Ensure generation_config is always a GenerationConfig object
+            if isinstance(config.generation_config, dict):
+                self.generation_config = GenerationConfig(
+                    **config.generation_config,
+                )
+            else:
+                self.generation_config = config.generation_config
+
+    config = AutoConfig.from_pretrained('gpt2')
+    # Convert dict to GenerationConfig if needed
+    if isinstance(generation_config, dict):
+        generation_config = GenerationConfig(**generation_config)
+    config.generation_config = generation_config
+
+    mock_model = MockModel(config)
+    logger = MagicMock()
+    state = MagicMock()
+    state.timestamp.batch = 1
+    state.is_model_ddp = False
+    state.model.model = mock_model
+    state.model.tokenizer = None
+
+    checkpointer = HuggingFaceCheckpointer(
+        save_folder='test',
+        save_interval='1ba',
+    )
+
+    checkpointer._save_checkpoint(
+        state=state,
+        logger=logger,
+        upload_to_save_folder=False,
+        register_to_mlflow=False,
+    )
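
As a quick illustration of what the guarded update copies in the non-None cases, here is a small sketch using only the plain transformers API; the values are arbitrary and not taken from the commit.

from transformers import GenerationConfig

source = GenerationConfig(max_length=200)  # settings carried by the original model
target = GenerationConfig()                # defaults on a freshly built instance

# update() overwrites matching attributes in place and returns any keyword
# arguments it did not recognize instead of raising.
target.update(**source.to_dict())
assert target.max_length == 200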
