Skip to content

Commit

Permalink
Use the pretrained generation config if it exists for HF models (#1440)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Aug 9, 2024
1 parent 44b09f0 commit 9cdfd6d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 3 deletions.
21 changes: 20 additions & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
Expand Down Expand Up @@ -232,7 +233,7 @@ def build_inner_model(

# Hugging Face copies the modules into the
# transformers modules cache. On particular systems, this operation seems to cause contention between
# the different processes. To avoid this contention, we first create the config on local rank
# the different processes. To avoid this contention, we first create the config and generation config on local rank
# zero. This will set up the transformers module cache and avoid the future contention.
if dist.get_local_rank() == 0:
AutoConfig.from_pretrained(
Expand All @@ -243,6 +244,13 @@ def build_inner_model(
use_cache=
False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
)
try:
GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
)
except OSError:
pass

dist.barrier()

Expand Down Expand Up @@ -337,6 +345,17 @@ def build_inner_model(
if dist.get_local_rank() == 0:
os.remove(signal_file_path)

# Use the pretrained generation config for the model if it exists.
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
)
except OSError:
log.warning(
f'No existing generation config found for the model with name or path {pretrained_model_name_or_path}. Using default generation config.',
)

# Hugging Face's weight tying does not succeed if the model is inited on meta device
# so we manually apply the weight tying here
if model.config.tie_word_embeddings and resolved_init_device == 'meta':
Expand Down
67 changes: 65 additions & 2 deletions tests/models/hf/test_hf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,28 @@

import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Mapping
from unittest.mock import Mock, patch

import pytest
import torch
from omegaconf import OmegaConf as om
from transformers import PretrainedConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedModel,
)

from llmfoundry.models.hf.hf_fsdp import rgetattr
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import build_tokenizer
from llmfoundry.utils.builders import build_composer_model
from llmfoundry.utils.config_utils import to_dict_container
from llmfoundry.utils.config_utils import (
set_config_overrides,
to_dict_container,
)


def test_remote_code_false_mpt(
Expand Down Expand Up @@ -279,3 +288,57 @@ def test_use_flash():

# Make sure that HF has not cast the parameters to bf16
assert next(model.parameters()).dtype == torch.float32


def test_generation_config(tmp_path: Path):
    """Check that ``hf_causal_lm`` reuses a generation config saved on disk.

    A small llama model is built from an overridden CodeLlama config, the
    ``bos_token_id`` on its generation config is changed so that it no longer
    matches the model config, and the model is saved under ``tmp_path``. The
    composer model built from that directory must expose an inner model whose
    generation config carries the edited ``bos_token_id``.
    """
    # Shrink the CodeLlama config so that we get a small model to edit and save.
    tiny_overrides = {
        'num_hidden_layers': 2,
        'hidden_size': 32,
        'intermediate_size': 64,
    }
    hf_config = AutoConfig.from_pretrained('codellama/CodeLlama-7b-hf')
    set_config_overrides(hf_config, config_overrides=tiny_overrides)
    hf_model = AutoModelForCausalLM.from_config(hf_config)

    assert isinstance(hf_model, PreTrainedModel)
    assert hf_model.generation_config is not None

    # Point the generation config at a different bos token than the model
    # config, so the two can be told apart after reloading.
    expected_bos_token_id = 100
    hf_model.generation_config.bos_token_id = expected_bos_token_id
    assert hf_model.generation_config.bos_token_id != hf_model.config.bos_token_id

    # Write the edited model (and its generation config) to disk.
    checkpoint_dir = tmp_path / 'model'
    hf_model.save_pretrained(checkpoint_dir)

    # Rebuild through hf_causal_lm, pointing it at the saved directory.
    composer_cfg = {
        'name': 'hf_causal_lm',
        'pretrained_model_name_or_path': str(checkpoint_dir),
        'use_auth_token': True,
        'pretrained': False,
        'init_device': 'cpu',
    }
    model_name = composer_cfg.pop('name')
    composer_model = build_composer_model(
        name=model_name,
        cfg=composer_cfg,
        tokenizer=None,  # type: ignore
    )

    reloaded = composer_model.model
    assert isinstance(reloaded, PreTrainedModel)
    assert reloaded.generation_config is not None

    # The reloaded model must carry the bos_token_id that was set before saving.
    assert reloaded.generation_config.bos_token_id == expected_bos_token_id

0 comments on commit 9cdfd6d

Please sign in to comment.