From 96b27c594b31c588420734ba2337ce61385c56bf Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Thu, 4 Apr 2024 13:09:02 -0700
Subject: [PATCH] Allow overrides for nested PretrainedConfig (#1089)

---
 llmfoundry/models/hf/hf_causal_lm.py | 16 ++++++++++++--
 tests/models/hf/test_hf_config.py    | 33 +++++++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 38ed7a7e70..5bca5cb21a 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -11,8 +11,8 @@
 from composer.models.huggingface import peft_installed
 from composer.utils import dist
 from omegaconf import DictConfig
-from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
-                          PreTrainedTokenizerBase)
+from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig,
+                          PreTrainedModel, PreTrainedTokenizerBase)
 
 from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
                                 DEFAULT_CAUSAL_LM_TRAIN_METRICS)
@@ -161,6 +161,18 @@ def _autoset_attn_implementation_monkeypatch(
             elif attr is None and isinstance(v, Mapping):
                 setattr(config, k, {})
                 getattr(config, k).update(v)
+            elif isinstance(attr, PretrainedConfig):
+                if not isinstance(v, Mapping):
+                    raise ValueError(
+                        f'Expected a dictionary for config override {k}, but got {v}.'
+                    )
+
+                for _k, _v in v.items():
+                    if not hasattr(attr, _k):
+                        raise ValueError(
+                            f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).'
+                        )
+                    setattr(attr, _k, _v)
             else:
                 setattr(config, k, v)
 
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index d5de596199..e79756aba3 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -12,7 +12,7 @@
 import torch
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, PretrainedConfig
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
@@ -205,3 +205,34 @@ def test_rope_scaling_override():
     # This would error if the config isn't parsed into a proper dictionary
     model.get_metadata()
     assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
+
+
+@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ,
+                    reason='CI does not have access to Dbrx')
+def test_nested_override():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'databricks/dbrx-instruct',
+        'config_overrides': {
+            'ffn_config': {
+                'ffn_hidden_size': 500,
+            }
+        },
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'meta',
+    }
+    model_cfg = om.create(model_cfg)
+
+    model = build_composer_model(
+        name=model_cfg.name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    # The value we changed
+    assert model.config.ffn_config.ffn_hidden_size == 500
+    # Ensure we still have a config, and haven't replaced it with a dictionary
+    assert isinstance(model.config.ffn_config, PretrainedConfig)
+    # Ensure the other values still exist and are not set back to their defaults
+    assert model.config.ffn_config.moe_num_experts == 16
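
Note (illustration only, not part of the patch): the sketch below shows the behavior the new `elif isinstance(attr, PretrainedConfig)` branch enables, i.e. overriding individual fields of a nested sub-config in place instead of replacing it with a plain dict. It uses two hypothetical toy classes, ToyConfig and ToyFFNConfig, so it runs without access to databricks/dbrx-instruct; only PretrainedConfig is real transformers API here, and the override loop mirrors the patched code in hf_causal_lm.py.

from collections.abc import Mapping

from transformers import PretrainedConfig


class ToyFFNConfig(PretrainedConfig):
    """Hypothetical nested sub-config, standing in for e.g. Dbrx's ffn_config."""

    def __init__(self, ffn_hidden_size: int = 3584, moe_num_experts: int = 16, **kwargs):
        super().__init__(**kwargs)
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts


class ToyConfig(PretrainedConfig):
    """Hypothetical top-level config that holds a nested PretrainedConfig."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.ffn_config = ToyFFNConfig()


config = ToyConfig()
config_overrides = {'ffn_config': {'ffn_hidden_size': 500}}

# Same logic as the new branch in hf_causal_lm.py: update attributes on the
# nested config object rather than replacing the object with a dictionary.
for k, v in config_overrides.items():
    attr = getattr(config, k, None)
    if isinstance(attr, PretrainedConfig):
        if not isinstance(v, Mapping):
            raise ValueError(f'Expected a dictionary for config override {k}, but got {v}.')
        for _k, _v in v.items():
            if not hasattr(attr, _k):
                raise ValueError(f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).')
            setattr(attr, _k, _v)
    else:
        setattr(config, k, v)

assert config.ffn_config.ffn_hidden_size == 500          # overridden value
assert isinstance(config.ffn_config, PretrainedConfig)   # still a config object, not a dict
assert config.ffn_config.moe_num_experts == 16           # untouched default

In a model YAML this corresponds to nesting a mapping under config_overrides (e.g. ffn_config: {ffn_hidden_size: 500}), which is exactly what test_nested_override above exercises end to end.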