From cefd616048d52ed714cd5e982a54eaaf9aa38707 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 23 Jul 2024 10:42:00 -0700 Subject: [PATCH] Add transformation hooks to hf_causal_lm (#1383) --- llmfoundry/models/hf/hf_causal_lm.py | 72 ++++++++------------------ llmfoundry/utils/config_utils.py | 42 +++++++++++++++ tests/models/hf/test_hf_transform.py | 76 ++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 51 deletions(-) create mode 100644 tests/models/hf/test_hf_transform.py diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index a15429aa06..7c0baf0c58 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -11,7 +11,6 @@ Any, Dict, List, - Mapping, Optional, Tuple, Union, @@ -23,7 +22,6 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, - PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, ) @@ -36,7 +34,7 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import get_hf_config_value +from llmfoundry.utils.config_utils import set_config_overrides if TYPE_CHECKING: from peft import PeftConfig, PeftModel @@ -105,9 +103,13 @@ def __init__( config_overrides=config_overrides, load_in_8bit=load_in_8bit, pretrained=pretrained, - prepare_for_fsdp=True, + prepare_for_fsdp=False, ) + model = self.transform_model(model) + + ComposerHFCausalLM.prepare_inner_model(model, init_device) + train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( use_train_metrics=use_train_metrics, additional_train_metrics=additional_train_metrics, @@ -121,7 +123,7 @@ def __init__( peft_config_object = None if peft_config is not None: - peft_config_object = self._get_peft_config(peft_config) + peft_config_object = self.get_peft_config(peft_config) # Set up config args for the model construction and base classes super().__init__( @@ -135,6 +137,17 @@ def __init__( should_save_peft_only=should_save_peft_only, ) + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + """Transforms the model after initialization. + + Args: + model (PreTrainedModel): The model to transform. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + @staticmethod def build_metrics( use_train_metrics: bool, @@ -259,50 +272,7 @@ def _autoset_attn_implementation_monkeypatch( _autoset_attn_implementation_monkeypatch, ) - # set config overrides - for k, v in config_overrides.items(): - if not hasattr(config, k): - raise ValueError( - f'config does not have attribute "{k}" to override ({k}: {v}).', - ) - - attr = getattr(config, k) - # attempt to disallow typos in nested configs - if isinstance(attr, Mapping): - extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] - if extra_keys: - raise ValueError( - f'Config dict override got unknown keys. ' + - f'Extra keys: {extra_keys}. ' + - f'Expected (a subset of) keys: {list(attr.keys())}.', - ) - getattr(config, k).update(v) - # necessary case to allow for rope_scaling to be overriden in llama config - elif attr is None and isinstance(v, Mapping): - setattr(config, k, {}) - getattr(config, k).update(v) - elif isinstance(attr, PretrainedConfig): - if not isinstance(v, Mapping): - raise ValueError( - f'Expected a dictionary for config override {k}, but got {v}.', - ) - - for _k, _v in v.items(): - if not hasattr(attr, _k): - raise ValueError( - f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', - ) - setattr(attr, _k, _v) - else: - setattr(config, k, v) - - if hasattr(config, 'attn_config') and get_hf_config_value( - config.attn_config, - 'seq_parallel_world_size', - ) is not None: - raise NotImplementedError( - 'Sequence Parallelism is not supported for HuggingFace models.', - ) + set_config_overrides(config, config_overrides) # We need to have all non-zero local ranks be not-pretrained # Rank 0 will still be pretrained, and distribute the weights appropriately @@ -395,10 +365,10 @@ def _autoset_attn_implementation_monkeypatch( if prepare_for_fsdp: ComposerHFCausalLM.prepare_inner_model(model, init_device) + return model - @staticmethod - def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig': + def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig': if peft_installed: from peft import LoraConfig peft_type = peft_config_dict.get('peft_type', '') diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 4b86de99b8..48290bd7c5 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -812,3 +812,45 @@ def _verify_uc_path(path: str) -> bool: f'but your `UCVolumeDatasetSource` might be invalid.', ) return False + + +def set_config_overrides( + config: PretrainedConfig, + config_overrides: Dict[str, Any], +): + # set config overrides + for k, v in config_overrides.items(): + if not hasattr(config, k): + raise ValueError( + f'config does not have attribute "{k}" to override ({k}: {v}).', + ) + + attr = getattr(config, k) + # attempt to disallow typos in nested configs + if isinstance(attr, Mapping): + extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] + if extra_keys: + raise ValueError( + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + + f'Expected (a subset of) keys: {list(attr.keys())}.', + ) + getattr(config, k).update(v) + # necessary case to allow for rope_scaling to be overriden in llama config + elif attr is None and isinstance(v, Mapping): + setattr(config, k, {}) + getattr(config, k).update(v) + elif isinstance(attr, PretrainedConfig): + if not isinstance(v, Mapping): + raise ValueError( + f'Expected a dictionary for config override {k}, but got {v}.', + ) + + for _k, _v in v.items(): + if not hasattr(attr, _k): + raise ValueError( + f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', + ) + setattr(attr, _k, _v) + else: + setattr(config, k, v) diff --git a/tests/models/hf/test_hf_transform.py b/tests/models/hf/test_hf_transform.py new file mode 100644 index 0000000000..f479b50f73 --- /dev/null +++ b/tests/models/hf/test_hf_transform.py @@ -0,0 +1,76 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +import pytest +from composer.models.huggingface import maybe_get_underlying_model +from peft import PeftConfig, PeftModel +from transformers import LlamaForCausalLM, PreTrainedModel + +from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM +from llmfoundry.models.utils import init_empty_weights + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'peft_config', + [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'r': 2, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }, + ], +) +def test_hf_transform(peft_config: Optional[dict]): + model_cfg = { + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', + 'config_overrides': { + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + }, + 'pretrained': False, + 'peft_config': peft_config, + 'init_device': 'meta', + 'tokenizer': 'codellama/CodeLlama-7b-hf', + } + + class TransformedHFCausalLM(ComposerHFCausalLM): + + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + assert isinstance(model, LlamaForCausalLM) + with init_empty_weights(): + model.config.num_hidden_layers = 1 + new_model = type(model)(model.config) + return new_model + + def get_peft_config( + self, + peft_config_dict: Dict[str, Any], + ) -> PeftConfig: + peft_config_dict['target_modules'] = ['o_proj'] + return super().get_peft_config(peft_config_dict) + + composer_model = TransformedHFCausalLM(**model_cfg) + model = composer_model.model + inner_model = maybe_get_underlying_model(model) + + if peft_config: + peft_model = composer_model.model + assert isinstance(peft_model, PeftModel) + + target_modules = peft_model.peft_config[peft_model.active_adapter + ].target_modules + assert list(target_modules) == ['o_proj'] + + assert isinstance(inner_model, LlamaForCausalLM) + assert inner_model.config.num_hidden_layers == 1