From 3b316c43bbb56eafcc301e59d9947bee58c8dc4a Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Tue, 24 Sep 2024 21:21:27 -0700
Subject: [PATCH] Revert "deprecations"

This reverts commit 6858db9d3988e35c54275f5bab44404ca95744a3.
---
 llmfoundry/command_utils/eval.py      |   2 +-
 llmfoundry/models/hf/__init__.py      |   2 +
 llmfoundry/models/hf/model_wrapper.py | 103 ++++++++++++++++++++++++++
 tests/models/test_model.py            |   6 +-
 4 files changed, 109 insertions(+), 4 deletions(-)
 create mode 100644 llmfoundry/models/hf/model_wrapper.py

diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py
index 73127e8a07..70c4319ea8 100644
--- a/llmfoundry/command_utils/eval.py
+++ b/llmfoundry/command_utils/eval.py
@@ -82,7 +82,7 @@ def evaluate_model(
         warnings.warn(
             VersionedDeprecationWarning(
                 'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
-                remove_version='0.14.0',
+                remove_version='0.13.0',
             ),
         )
     if fsdp_config and parallelism_config:
diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py
index 03df90e8cd..2f25f92940 100644
--- a/llmfoundry/models/hf/__init__.py
+++ b/llmfoundry/models/hf/__init__.py
@@ -9,6 +9,7 @@
     prepare_hf_model_for_fsdp,
 )
 from llmfoundry.models.hf.hf_t5 import ComposerHFT5
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
 
 __all__ = [
     'BaseHuggingFaceModel',
@@ -17,4 +18,5 @@
     'prepare_hf_causal_lm_model_for_fsdp',
     'prepare_hf_enc_dec_model_for_fsdp',
     'prepare_hf_model_for_fsdp',
+    'HuggingFaceModelWithFSDP',
 ]
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
new file mode 100644
index 0000000000..c8805e5d6d
--- /dev/null
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -0,0 +1,103 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Re-usable :class:`.ComposerModel` for LLM HF Models."""
+
+from __future__ import annotations
+
+import warnings
+from collections import UserDict
+from typing import TYPE_CHECKING, Mapping, Optional, Union
+
+import transformers
+from composer.models.huggingface import HuggingFaceModel
+from torchmetrics import Metric
+from transformers import PreTrainedTokenizerBase
+from transformers.utils.generic import ModelOutput
+
+from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
+if TYPE_CHECKING:
+    from peft import PeftConfig, PeftModel
+
+__all__ = ['HuggingFaceModelWithFSDP']
+
+# HuggingFace hardcodes the ignore index to -100
+_HF_IGNORE_INDEX = -100
+
+
+class HuggingFaceModelWithFSDP(HuggingFaceModel):
+    """Wrapper around HuggingFaceModel.
+
+    Handles preparation for FSDP wrapping.
+    """
+
+    def __init__(
+        self,
+        model: Union[transformers.PreTrainedModel, 'PeftModel'],
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        metrics: Optional[list[Metric]] = None,
+        eval_metrics: Optional[list[Metric]] = None,
+        shift_labels: bool = False,
+        allow_embedding_resizing: bool = False,
+        init_device: Optional[str] = None,
+        peft_config: Optional['PeftConfig'] = None,
+        should_save_peft_only: bool = True,
+    ):
+        warnings.warn(
+            VersionedDeprecationWarning(
+                '`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.',
+                remove_version='0.13.0',
+            ),
+        )
+        super().__init__(
+            model,
+            tokenizer,
+            use_logits=True,
+            metrics=metrics,
+            eval_metrics=eval_metrics,
+            shift_labels=shift_labels,
+            allow_embedding_resizing=allow_embedding_resizing,
+            peft_config=peft_config,
+            should_save_peft_only=should_save_peft_only,
+        )
+
+        self.prepare_inner_model(self.model, init_device)
+
+    def forward(self, batch: Mapping):
+        if isinstance(batch, dict) or isinstance(batch, UserDict):
+            # Further input validation is left to the huggingface forward call
+            batch = {
+                k: v for k, v in batch.items() if k in self.model_forward_args
+            }
+            output = self.model(**batch)  # type: ignore (thirdparty)
+        else:
+            raise ValueError(
+                'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model',
+            )
+        return output
+
+    def loss(self, outputs: ModelOutput, batch: Mapping):
+        if self.config.use_return_dict:
+            return outputs['loss']
+        # loss is at index 0 in the output tuple, logits are at index 1
+        return outputs[:2]
+
+    @staticmethod
+    def prepare_inner_model(
+        model: Union[transformers.PreTrainedModel, 'PeftModel'],
+        init_device: Optional[str] = None,
+    ):
+        """Prepare the inner model for FSDP wrapping.
+
+        Args:
+            model: The model to prepare.
+            init_device: The device to initialize the model on.
+        """
+        # Note: We need to add the FSDP related attributes to the model AFTER the super init,
+        # so that the (possible) embedding resizing doesn't destroy them
+        prepare_hf_model_for_fsdp(model, init_device)
+
+        # This provides support for meta initialization when using FSDP
+        model.param_init_fn = lambda module: model._init_weights(module)
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 92effffdd8..eeb6bf0d90 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -39,7 +39,7 @@
 
 from llmfoundry import ComposerHFCausalLM
 from llmfoundry.layers_registry import norms
-from llmfoundry.models.hf import BaseHuggingFaceModel
+from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
 from llmfoundry.models.layers import build_alibi_bias
 from llmfoundry.models.layers.attention import (
     check_alibi_support,
@@ -2560,7 +2560,7 @@ def test_hf_init(
         False,
     )
 
-    model = BaseHuggingFaceModel(model, tokenizer)
+    model = HuggingFaceModelWithFSDP(model, tokenizer)
 
     batch = gen_random_batch(batch_size, test_cfg)
 
@@ -2609,7 +2609,7 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):
 
     mpt = MPTForCausalLM(hf_config)
 
-    model = BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)
+    model = HuggingFaceModelWithFSDP(mpt, tokenizer, shift_labels=True)
     model = model.to(test_cfg.device)
     batch = gen_random_batch(batch_size, test_cfg)
 
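
Note (illustration only, not part of the reverted commit): a minimal sketch of how the
restored `HuggingFaceModelWithFSDP` wrapper is constructed once this patch is applied.
The 'gpt2' checkpoint and tokenizer names are assumed examples, not taken from the diff
above; constructing the wrapper emits the `VersionedDeprecationWarning` and then prepares
the inner model for FSDP wrapping.

    # Illustrative sketch only; 'gpt2' is an assumed example checkpoint,
    # not something referenced by this patch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from llmfoundry.models.hf import HuggingFaceModelWithFSDP

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    # Emits a VersionedDeprecationWarning recommending BaseHuggingFaceModel,
    # then calls prepare_hf_model_for_fsdp on the wrapped model.
    composer_model = HuggingFaceModelWithFSDP(model, tokenizer)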