From 3fc2efed47dc371110733c8750a8f980ef37dd10 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:59:43 -0800 Subject: [PATCH] Integrate PEFT LoRA with HuggingFaceModel (#2829) --- composer/datasets/utils.py | 2 +- composer/models/huggingface.py | 271 ++++++++++++++++++++++++++----- setup.py | 4 + tests/common/models.py | 102 +++++++++--- tests/conftest.py | 7 +- tests/fixtures/fixtures.py | 56 +++++++ tests/models/test_hf_model.py | 281 ++++++++++++++++++++++++++++++++- 7 files changed, 658 insertions(+), 65 deletions(-) diff --git a/composer/datasets/utils.py b/composer/datasets/utils.py index b627ef85963..6bb376ff300 100644 --- a/composer/datasets/utils.py +++ b/composer/datasets/utils.py @@ -196,7 +196,7 @@ def __init__( self.stop_sequence_id_len = len(self.stop_sequence_ids) + 2 self.tokenizer = tokenizer - def __call__(self, input_ids: torch.Tensor, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool: + def __call__(self, input_ids: torch.LongTensor, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool: # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:] diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index e633db9cb73..e635f0cec75 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy import inspect import json import logging @@ -13,8 +14,9 @@ import string import tempfile import textwrap +import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union import torch from torchmetrics import Metric @@ -23,14 +25,21 @@ from composer.models.base import ComposerModel from composer.utils import MissingConditionalImportError, dist, get_file, import_object, is_model_fsdp, safe_torch_load +try: + from peft import PeftModel, get_peft_model + peft_installed = True +except: + peft_installed = False + if TYPE_CHECKING: import transformers + from peft import PeftConfig, PeftModel from transformers import PretrainedConfig from transformers.models.auto.auto_factory import _BaseAutoModelClass log = logging.getLogger(__name__) -__all__ = ['HuggingFaceModel'] +__all__ = ['HuggingFaceModel', 'peft_installed'] class HuggingFaceModel(ComposerModel): @@ -38,7 +47,7 @@ class HuggingFaceModel(ComposerModel): A wrapper class that converts 🤗 Transformers models to composer models. Args: - model (transformers.PreTrainedModel): A 🤗 Transformers model. + model (Union[transformers.PreTrainedModel, peft.PeftModel)): A 🤗 Transformers model or a PEFT model. tokenizer (transformers.PreTrainedTokenizer, optional): The tokenizer used to prepare the dataset. Default ``None``. .. note:: If the tokenizer is provided, its config will be saved in the composer checkpoint, and it can be reloaded @@ -48,6 +57,8 @@ class HuggingFaceModel(ComposerModel): eval_metrics (list[Metric], optional): list of torchmetrics to compute on the eval_dataloader, or be accessible to :class:`Evaluator`s. Default: ``None``. shift_labels (bool, optional): If True, the batch's labels will be shifted before being used to calculate metrics. This should be set to true for CausalLM models and false otherwise. If not specified, `shift_labels` will be set automatically based on the model class name. Default: ``None``. allow_embedding_resizing (bool, optional): If True, the model's embeddings will be automatically resized when they are smaller than the tokenizer vocab size. Default: ``False``. + peft_config (PeftConfig, optional): Optional PEFT config to apply to the model. If provided, the model will be converted to a PEFT model. Only LoRA is currently supported. + should_save_peft_only (bool, optional): If True _and_ PEFT is active, the state dict will only contain the PEFT weights, not the frozen base model weights. .. note:: To ensure correct behavior, set `shift_labels` manually if using a custom model (i.e., if `model` is not an instance of a registered 🤗 Transformers class). @@ -66,14 +77,16 @@ class HuggingFaceModel(ComposerModel): """ def __init__(self, - model: transformers.PreTrainedModel, + model: Union[transformers.PreTrainedModel, 'PeftModel'], tokenizer: Optional[Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]] = None, use_logits: Optional[bool] = False, metrics: Optional[List[Metric]] = None, eval_metrics: Optional[List[Metric]] = None, shift_labels: Optional[bool] = None, - allow_embedding_resizing: bool = False) -> None: + allow_embedding_resizing: bool = False, + peft_config: Optional['PeftConfig'] = None, + should_save_peft_only: bool = True) -> None: try: import transformers del transformers # unused @@ -82,65 +95,111 @@ def __init__(self, conda_package='transformers', conda_channel='conda-forge') from e + if peft_config is not None: + if not peft_installed: + raise MissingConditionalImportError(extra_deps_group='peft', + conda_package='peft', + conda_channel='conda-forge') + + if peft_config is not None: + # Hugging Face requires the peft type and task type to be upper case, so we do that here + # https://github.com/huggingface/peft/blob/ebbff4023ad276cbcb2466fd7e99be7d3ae0ae11/src/peft/utils/peft_types.py#L22-L51 + if isinstance(peft_config.peft_type, str): + peft_config.peft_type = peft_config.peft_type.upper() + if isinstance(peft_config.task_type, str): + peft_config.task_type = peft_config.task_type.upper() + + if peft_config.peft_type != 'LORA': + raise ValueError( + f'PEFT type {peft_config.peft_type} is not supported by HuggingFaceModel. Only LORA is supported.') + super().__init__() self.model = model - self.config = model.config - self.model_forward_args = inspect.getfullargspec(self.model.forward).args + self.config: PretrainedConfig = model.config + self.model_forward_args = self._get_model_forward_args() self.tokenizer = tokenizer + self.should_save_peft_only = should_save_peft_only + self.use_logits = use_logits + self.labels: Optional[torch.Tensor] = None # set in eval_forward() if exists + self.dummy_forward_called = False # Used to make FSDP generate work, see generate function for more details + self.train_metrics: Optional[Dict] = self._get_metric_dict(metrics) if metrics is not None else None + self.val_metrics: Optional[Dict] = self._get_metric_dict( + eval_metrics) if eval_metrics is not None else copy.deepcopy(self.train_metrics) + + is_causal_lm = _is_registered_causal_lm(self.model) + self.shift_labels = is_causal_lm if shift_labels is None else shift_labels + + self._check_tokenizer_and_maybe_resize_embeddings(allow_embedding_resizing) + + if is_causal_lm and not self.shift_labels: + log.warning('The shift_labels argument was set to False but the model is an instance of a' + ' HuggingFace Causal LM. This may lead to incorrect behavior.') + # Note: No warning if shift_labels and not is_causal_lm, since the model may simply be a custom class. + + if peft_config is not None: + self.model = _maybe_get_peft_model(peft_config, self.model) + self.using_peft = isinstance(self.model, PeftModel) if peft_installed else False + + def _check_tokenizer_and_maybe_resize_embeddings(self, allow_embedding_resizing: bool) -> None: if self.tokenizer is None: log.warning( 'The tokenizer was not provided. This means the tokenizer config will not be saved in the checkpoint.') - if tokenizer is not None and self.config.vocab_size < len(tokenizer): + if self.tokenizer is not None and self.config.vocab_size < len(self.tokenizer): if allow_embedding_resizing: # when the embedding size is smaller than the tokenizer vocab size, # the embeddings should get resized to match the tokenizer vocab size log.warning(f'The number of tokens in the tokenizer is greater than the number of tokens in the model.' f' This would cause an error during training.' - f' Resizing the model embeddings to {len(tokenizer)} from {self.config.vocab_size}.') - self.model.resize_token_embeddings(len(tokenizer)) + f' Resizing the model embeddings to {len(self.tokenizer)} from {self.config.vocab_size}.') + self.model.resize_token_embeddings(len(self.tokenizer)) else: raise ValueError( f'The number of tokens in the tokenizer is greater than the number of tokens in the model.' f' This would cause an error during training.' - f' You can resize the model embeddings to {len(tokenizer)} from {self.config.vocab_size}' + f' You can resize the model embeddings to {len(self.tokenizer)} from {self.config.vocab_size}' f' by calling `model.resize_token_embeddings(len(tokenizer))` before calling the `HuggingFaceModel`' f' constructor, or pass `allow_embedding_resizing=True` to have it done automatically.') - elif tokenizer is not None and self.config.vocab_size > len(tokenizer): + elif self.tokenizer is not None and self.config.vocab_size > len(self.tokenizer): # when the embedding size is greater than the tokenizer vocab size, # the embeddings do not _need_ to be resized to match the tokenizer vocab size, # and should be done by the user if desired log.info( f'The number of tokens in the tokenizer is less than the number of tokens in the model.' - f' You may want to resize the model embeddings to {len(tokenizer)} from {self.config.vocab_size}' + f' You may want to resize the model embeddings to {len(self.tokenizer)} from {self.config.vocab_size}' f' by calling `model.resize_token_embeddings(len(tokenizer))` before calling the `HuggingFaceModel`' f' constructor. The vocab size is sometimes intentionally set to a multiple of 32 or 64 to improve' f' performance.') - self.use_logits = use_logits + def _get_metric_dict(self, metrics: List[Metric]) -> Dict[str, Metric]: + """Returns a dictionary of metrics keyed by their class name.""" + return {metric.__class__.__name__: metric for metric in metrics} - self.train_metrics: Optional[Dict] = None - self.val_metrics: Optional[Dict] = None + def _get_model_forward_args(self) -> Set[str]: + """Returns the arguments to the model's forward function.""" + model_forward_args = inspect.signature(maybe_get_underlying_model(self.model).forward).parameters.keys() - if eval_metrics is not None: - self.val_metrics = {metric.__class__.__name__: metric for metric in eval_metrics} - if metrics is not None: - self.train_metrics = {metric.__class__.__name__: metric for metric in metrics} - # if eval_metrics is None, use the same metrics as train_metrics - if eval_metrics is None: - self.val_metrics = {metric.__class__.__name__: metric for metric in metrics} + if not model_forward_args: + raise ValueError('Could not determine the forward arguments of the model. Please open a GitHub issue.') - self.labels: Optional[torch.Tensor] = None # set in eval_forward() if exists + model_forward_args = set(model_forward_args) - is_causal_lm = _is_registered_causal_lm(model) - self.shift_labels = is_causal_lm if shift_labels is None else shift_labels - if is_causal_lm and not self.shift_labels: - log.warning('The shift_labels argument was set to False but the model is an instance of a' - ' HuggingFace Causal LM. This may lead to incorrect behavior.') - # Note: No warning if shift_labels and not is_causal_lm, since the model may simply be a custom class. + return model_forward_args + + def state_dict(self, *args, **kwargs) -> Dict[str, Any]: + """Returns the state dict of the model.""" + full_state_dict = super().state_dict(*args, **kwargs) - self.dummy_forward_called = False + if self.using_peft and self.should_save_peft_only: + active_adapter = self.model.active_adapter + assert isinstance(active_adapter, str) + full_state_dict = filter_state_dict_peft(full_state_dict, + self.model.peft_config[active_adapter], + adapter_name='default', + remove_adapter_names=False) + + return full_state_dict @staticmethod def load_huggingface_tokenizer_from_saved_state( @@ -433,7 +492,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None): # HF encoder decoder models like T5 expect either decoder_input_ids or labels, # so we add decoder_input_ids to the batch if it is missing - if self.model.config.is_encoder_decoder and 'decoder_input_ids' not in batch: + if self.config.is_encoder_decoder and 'decoder_input_ids' not in batch: if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'): batch['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels=self.labels) else: @@ -489,7 +548,9 @@ def get_metadata(self): tmp_dir = Path(tmp_dir) model_dir = tmp_dir / 'model' tokenizer_dir = tmp_dir / 'tokenizer' - self.model.config.save_pretrained(model_dir) + + original_model_config: PretrainedConfig = self.config + original_model_config.save_pretrained(model_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(tokenizer_dir) @@ -502,6 +563,19 @@ def get_metadata(self): 'class': f'{self.model.__class__.__module__}.{self.model.__class__.__name__}' } + # Also save PEFT config if the model is a peft model + if self.using_peft: + active_adapter = self.model.active_adapter + assert isinstance(active_adapter, str) + self.model.peft_config[active_adapter].save_pretrained(str(model_dir)) + with open(model_dir / 'adapter_config.json') as _peft_config_file: + peft_config = json.load(_peft_config_file) + + model_output['peft_config'] = { + 'file_extension': '.json', + 'content': peft_config, + } + if self.tokenizer is not None: for tokenizer_file_name in tokenizer_dir.iterdir(): tokenizer_file_path = tokenizer_dir / tokenizer_file_name @@ -557,7 +631,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs): if not using_torch_2() and not self.dummy_forward_called and is_model_fsdp(self.model): with torch.no_grad(): maybe_decoder_input_ids = {} - if self.model.config.is_encoder_decoder: + if self.config.is_encoder_decoder: maybe_decoder_input_ids['decoder_input_ids'] = torch.tensor([[0]], dtype=torch.long, device=input_ids.device) @@ -579,7 +653,49 @@ def generate(self, input_ids: torch.Tensor, **kwargs): return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs) -def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool: +def _maybe_get_peft_model( + peft_config: 'PeftConfig', + model: Union[transformers.PreTrainedModel, 'PeftModel'], +) -> 'PeftModel': + """Creates a PEFT model if the model is not already a PEFT model. + + Args: + peft_config (Optional[peft.PeftConfig]): The PEFT config to use to create the PEFT model + model (Union[transformers.PreTrainedModel, 'PeftModel']): The model to create the PEFT model from + + Returns: + PeftModel: The PEFT model + """ + if not peft_installed: + raise MissingConditionalImportError(extra_deps_group='peft', conda_package='peft', conda_channel='conda-forge') + + if not isinstance(model, PeftModel): + log.info('Creating PEFT model') + peft_model = get_peft_model(model, peft_config) + assert isinstance(peft_model, PeftModel) + return peft_model + else: + warnings.warn('PEFT model was passed in directly. Ignoring the provided PEFT config.') + return model + + +def maybe_get_underlying_model( + model: Union[transformers.PreTrainedModel, 'PeftModel']) -> Union[transformers.PreTrainedModel, 'PeftModel']: + """Get the underlying PreTrainedModel from a model if it is a PEFT model + + Args: + model (Union[transformers.PreTrainedModel, 'PeftModel']): The model to get the underlying model from + + Returns: + Union[transformers.PreTrainedModel]: The underlying transformers model + """ + if peft_installed and isinstance(model, PeftModel): + return model.base_model.model + else: + return model + + +def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftModel']) -> bool: """Return True if model class is either a registered 🤗 Causal LM or a subclass of one""" try: from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING @@ -588,6 +704,8 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool: conda_package='transformers', conda_channel='conda-forge') from e + model_to_check = maybe_get_underlying_model(model) + # This try/except is needed until https://github.com/huggingface/transformers/issues/26778 # is resolved in a release. This means that this attempt to automatically detect causal LMs # does not currently work in an environment with flash attention <2 installed. @@ -599,7 +717,7 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool: return False else: raise e - return any(isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes) # type: ignore + return any(isinstance(model_to_check, causal_lm_class) for causal_lm_class in causal_lm_classes) # type: ignore def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any], @@ -642,6 +760,30 @@ def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any], f'config has a valid `_name_or_path`.') +def get_peft_config_from_composer_state_dict(state_dict: Dict[str, Any]) -> Optional['PeftConfig']: + """Get a PEFT config from a composer state dict + + Args: + state_dict (Dict[str, Any]): The state dict to get the config from + + Returns: + Optional[peft.PeftConfig]: The PEFT config. Will be ``None`` if the model is not a PEFT model. + """ + try: + import peft + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='peft', + conda_channel='conda-forge') from e + + hf_model_dict = state_dict['state']['integrations']['huggingface']['model'] + if 'peft_config' not in hf_model_dict: + return None + + peft_config_dict = hf_model_dict['peft_config']['content'] + + return peft.get_peft_config(peft_config_dict) + + def write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path: Union[Path, str], output_folder: Union[Path, str], @@ -718,6 +860,61 @@ def write_huggingface_pretrained_from_composer_checkpoint( config = get_hf_config_from_composer_state_dict(composer_state_dict) config.save_pretrained(output_folder) + peft_config = get_peft_config_from_composer_state_dict(composer_state_dict) + if peft_config is not None: + peft_config.save_pretrained(str(output_folder)) + weights_state_dict = composer_state_dict['state']['model'] torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(weights_state_dict, prefix='model.') - torch.save(weights_state_dict, Path(output_folder) / 'pytorch_model.bin') + + # NOTE: This only works for default adapter name, not multiple adapters + if peft_config is not None: + weights_state_dict = filter_state_dict_peft(weights_state_dict, peft_config, adapter_name='default') + + torch.save(weights_state_dict, Path(output_folder) / 'adapter_model.bin') + else: + torch.save(weights_state_dict, Path(output_folder) / 'pytorch_model.bin') + + +def filter_state_dict_peft(state_dict: Dict[str, Any], + peft_config: 'PeftConfig', + adapter_name: str = 'default', + remove_adapter_names: bool = True) -> Dict[str, Any]: + """Filter a state dict to only include the weights needed for a PEFT model + + Note: This function only works with LORA PEFT models right now. + + Args: + state_dict (Dict[str, Any]): The state dict to filter + peft_config (PeftConfig): The PEFT config to use to filter the state dict + adapter_name (str, optional): The name of the adapter to filter for. Defaults to 'default'. + remove_adapter_names (bool, optional): Whether to remove the adapter names from the state dict keys. Defaults to True. + + Returns: + Dict[str, Any]: The filtered state dict + """ + + if peft_config.peft_type != 'LORA': + raise NotImplementedError(f'Only LoRA PEFT is supported. Got {peft_config.peft_type}') + + # Filtering copied from https://github.com/huggingface/peft/blob/4186c9b104644fd247a4cc0dc2dfc1ede4665204/src/peft/utils/save_and_load.py#L68C1-L86C116 + bias = peft_config.bias # type: ignore + if bias == 'none': + to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k} + elif bias == 'all': + to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k or 'bias' in k} + elif bias == 'lora_only': + to_return = {} + for k in state_dict: + if 'lora_' in k: + to_return[k] = state_dict[k] + bias_name = k.split('lora_')[0] + 'bias' + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + to_return = {k: v for k, v in to_return.items() if (('lora_' in k and adapter_name in k) or ('bias' in k))} + + if remove_adapter_names: + to_return = {k.replace(f'.{adapter_name}', ''): v for k, v in to_return.items()} + return to_return diff --git a/setup.py b/setup.py index 1353df0522b..6754aa9387d 100644 --- a/setup.py +++ b/setup.py @@ -193,6 +193,10 @@ def package_files(prefix: str, directory: str, extension: str): 'datasets>=2.4,<3', ] +extra_deps['peft'] = [ + 'peft>=0.7.0,<0.8', +] + extra_deps['sentencepiece'] = [ 'protobuf<3.21', 'sentencepiece==0.1.99', diff --git a/tests/common/models.py b/tests/common/models.py index 41ea85ab049..a0b66d8929e 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -4,7 +4,7 @@ """Contains commonly used models that are shared across the test suite.""" import copy from functools import partial -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import pytest import torch @@ -14,6 +14,9 @@ from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy from composer.models import ComposerClassifier, HuggingFaceModel +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast + class EmptyModel(ComposerClassifier): """Always predict 0 with no parameters.""" @@ -443,105 +446,160 @@ def forward(self, batch: Tuple[torch.Tensor, Any]) -> torch.Tensor: # As a workaround, we inject objects into the PyTest namespace. Tests should not directly # use pytest.{var}, but instead should import and use these helper copy methods so the # objects in the PyTest namespace do not change. -def configure_tiny_bert_model(): +def configure_tiny_bert_model() -> 'PreTrainedModel': try: + from transformers import PreTrainedModel + assert isinstance(pytest.tiny_bert_model, PreTrainedModel) return copy.deepcopy(pytest.tiny_bert_model) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_bert_tokenizer(): +def configure_tiny_bert_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']: try: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + assert isinstance(pytest.tiny_bert_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) return copy.deepcopy(pytest.tiny_bert_tokenizer) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_bert_config(): +def configure_tiny_bert_config() -> 'PretrainedConfig': try: + from transformers import PretrainedConfig + assert isinstance(pytest.tiny_bert_config, PretrainedConfig) return copy.deepcopy(pytest.tiny_bert_config) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_bert_hf_model(use_logits=True): - return HuggingFaceModel(configure_tiny_bert_model(), configure_tiny_bert_tokenizer(), use_logits) # type: ignore +def configure_tiny_bert_hf_model(use_logits: bool = True) -> HuggingFaceModel: + return HuggingFaceModel(configure_tiny_bert_model(), configure_tiny_bert_tokenizer(), use_logits) -def configure_tiny_deberta_model(): +def configure_tiny_deberta_model() -> 'PreTrainedModel': try: + from transformers import PreTrainedModel + assert isinstance(pytest.tiny_deberta_model, PreTrainedModel) return copy.deepcopy(pytest.tiny_deberta_model) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_deberta_tokenizer(): +def configure_tiny_deberta_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']: try: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + assert isinstance(pytest.tiny_deberta_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) return copy.deepcopy(pytest.tiny_deberta_tokenizer) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_deberta_config(): +def configure_tiny_deberta_config() -> 'PretrainedConfig': try: + from transformers import PretrainedConfig + assert isinstance(pytest.tiny_deberta_config, PretrainedConfig) return copy.deepcopy(pytest.tiny_deberta_config) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_deberta_hf_model(use_logits=True): +def configure_tiny_deberta_hf_model(use_logits: bool = True) -> HuggingFaceModel: return HuggingFaceModel( - configure_tiny_deberta_model(), # type: ignore - configure_tiny_deberta_tokenizer(), # type: ignore + configure_tiny_deberta_model(), + configure_tiny_deberta_tokenizer(), use_logits, ) -def configure_tiny_gpt2_model(): +def configure_tiny_gpt2_model() -> 'PreTrainedModel': try: + from transformers import PreTrainedModel + assert isinstance(pytest.tiny_gpt2_model, PreTrainedModel) return copy.deepcopy(pytest.tiny_gpt2_model) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_gpt2_tokenizer(): +def configure_tiny_gpt2_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']: try: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + assert isinstance(pytest.tiny_gpt2_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) return copy.deepcopy(pytest.tiny_gpt2_tokenizer) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_gpt2_config(): +def configure_tiny_gpt2_config() -> 'PretrainedConfig': try: + from transformers import PretrainedConfig + assert isinstance(pytest.tiny_gpt2_config, PretrainedConfig) return copy.deepcopy(pytest.tiny_gpt2_config) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_gpt2_hf_model(use_logits=True): - return HuggingFaceModel(configure_tiny_gpt2_model(), configure_tiny_gpt2_tokenizer(), use_logits) # type: ignore +def configure_tiny_gpt2_hf_model(use_logits: bool = True) -> HuggingFaceModel: + return HuggingFaceModel(configure_tiny_gpt2_model(), configure_tiny_gpt2_tokenizer(), use_logits) -def configure_tiny_t5_model(): +def configure_tiny_t5_model() -> 'PreTrainedModel': try: + from transformers import PreTrainedModel + assert isinstance(pytest.tiny_t5_model, PreTrainedModel) return copy.deepcopy(pytest.tiny_t5_model) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_t5_tokenizer(): +def configure_tiny_t5_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']: try: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + assert isinstance(pytest.tiny_t5_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) return copy.deepcopy(pytest.tiny_t5_tokenizer) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_t5_config(): +def configure_tiny_t5_config() -> 'PretrainedConfig': try: + from transformers import PretrainedConfig + assert isinstance(pytest.tiny_t5_config, PretrainedConfig) return copy.deepcopy(pytest.tiny_t5_config) except AttributeError: pytest.skip('Composer installed without NLP support') -def configure_tiny_t5_hf_model(use_logits=True): - return HuggingFaceModel(configure_tiny_t5_model(), configure_tiny_t5_tokenizer(), use_logits) # type: ignore +def configure_tiny_t5_hf_model(use_logits: bool = True) -> HuggingFaceModel: + return HuggingFaceModel(configure_tiny_t5_model(), configure_tiny_t5_tokenizer(), use_logits) + + +def configure_tiny_mistral_model() -> 'PreTrainedModel': + try: + from transformers import PreTrainedModel + assert isinstance(pytest.tiny_mistral_model, PreTrainedModel) + return copy.deepcopy(pytest.tiny_mistral_model) + except AttributeError: + pytest.skip('Composer installed without NLP support') + + +def configure_tiny_mistral_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']: + try: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + assert isinstance(pytest.tiny_mistral_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) + return copy.deepcopy(pytest.tiny_mistral_tokenizer) + except AttributeError: + pytest.skip('Composer installed without NLP support') + + +def configure_tiny_mistral_config() -> 'PretrainedConfig': + try: + from transformers import PretrainedConfig + assert isinstance(pytest.tiny_mistral_config, PretrainedConfig) + return copy.deepcopy(pytest.tiny_mistral_config) + except AttributeError: + pytest.skip('Composer installed without NLP support') + + +def configure_tiny_mistral_hf_model(use_logits: bool = True) -> HuggingFaceModel: + return HuggingFaceModel(configure_tiny_mistral_model(), configure_tiny_mistral_tokenizer(), use_logits) diff --git a/tests/conftest.py b/tests/conftest.py index bcd063d9c74..bb923e88703 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,7 +111,9 @@ def pytest_configure(): if TRANSFORMERS_INSTALLED: from tests.fixtures.fixtures import (tiny_bert_config_helper, tiny_bert_model_helper, tiny_bert_tokenizer_helper, tiny_gpt2_config_helper, - tiny_gpt2_model_helper, tiny_gpt2_tokenizer_helper, tiny_opt_config_helper, + tiny_gpt2_model_helper, tiny_gpt2_tokenizer_helper, + tiny_mistral_config_helper, tiny_mistral_model_helper, + tiny_mistral_tokenizer_helper, tiny_opt_config_helper, tiny_opt_model_helper, tiny_opt_tokenizer_helper, tiny_t5_config_helper, tiny_t5_model_helper, tiny_t5_tokenizer_helper) pytest.tiny_bert_config = tiny_bert_config_helper() # type: ignore @@ -126,6 +128,9 @@ def pytest_configure(): pytest.tiny_t5_config = tiny_t5_config_helper() # type: ignore pytest.tiny_t5_model = tiny_t5_model_helper(pytest.tiny_t5_config) # type: ignore pytest.tiny_t5_tokenizer = tiny_t5_tokenizer_helper() # type: ignore + pytest.tiny_mistral_config = tiny_mistral_config_helper() # type: ignore + pytest.tiny_mistral_model = tiny_mistral_model_helper(pytest.tiny_mistral_config) # type: ignore + pytest.tiny_mistral_tokenizer = tiny_mistral_tokenizer_helper() # type: ignore def pytest_sessionfinish(session: pytest.Session, exitstatus: int): diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index cfd8674338f..17bc272b1ed 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -320,6 +320,47 @@ def _session_tiny_t5_model(_session_tiny_t5_config): # type: ignore return tiny_t5_model_helper(_session_tiny_t5_config) +def tiny_mistral_config_helper(): + transformers = pytest.importorskip('transformers') + + tiny_overrides = { + 'hidden_size': 128, + 'intermediate_size': 256, + 'num_attention_heads': 8, + 'num_hidden_layers': 2, + 'num_kv_heads': 4 + } + return transformers.AutoConfig.from_pretrained('mistralai/Mistral-7B-v0.1', **tiny_overrides) + + +@pytest.fixture(scope='session') +def _session_tiny_mistral_config(): # type: ignore + return tiny_mistral_config_helper() + + +def tiny_mistral_tokenizer_helper(): + transformers = pytest.importorskip('transformers') + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1', model_max_length=512) + return hf_tokenizer + + +@pytest.fixture(scope='session') +def _session_tiny_mistral_tokenizer(): # type: ignore + return tiny_mistral_tokenizer_helper() + + +def tiny_mistral_model_helper(config): + transformers = pytest.importorskip('transformers') + + return transformers.AutoModelForCausalLM.from_config(config) + + +@pytest.fixture(scope='session') +def _session_tiny_t5_model(_session_tiny_t5_config): # type: ignore + return tiny_t5_model_helper(_session_tiny_t5_config) + + @pytest.fixture def tiny_bert_model(_session_tiny_bert_model): return copy.deepcopy(_session_tiny_bert_model) @@ -393,3 +434,18 @@ def tiny_t5_tokenizer(_session_tiny_t5_tokenizer): @pytest.fixture def tiny_t5_model(_session_tiny_t5_model): return copy.deepcopy(_session_tiny_t5_model) + + +@pytest.fixture +def tiny_mistral_config(_session_tiny_mistral_config): + return copy.deepcopy(_session_tiny_mistral_config) + + +@pytest.fixture +def tiny_mistral_tokenizer(_session_tiny_mistral_tokenizer): + return copy.deepcopy(_session_tiny_mistral_tokenizer) + + +@pytest.fixture +def tiny_mistral_model(_session_tiny_mistral_model): + return copy.deepcopy(_session_tiny_mistral_model) diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py index 4092eb2c195..aeb42cd513d 100644 --- a/tests/models/test_hf_model.py +++ b/tests/models/test_hf_model.py @@ -6,7 +6,7 @@ import tempfile from contextlib import nullcontext from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from unittest.mock import patch from urllib.parse import urlparse @@ -26,9 +26,48 @@ from tests.common.datasets import RandomTextClassificationDataset, RandomTextLMDataset, RandomTextRegressionDataset from tests.common.markers import device, world_size from tests.common.models import (configure_tiny_bert_model, configure_tiny_bert_tokenizer, configure_tiny_gpt2_model, - configure_tiny_gpt2_tokenizer, configure_tiny_t5_model, configure_tiny_t5_tokenizer) + configure_tiny_gpt2_tokenizer, configure_tiny_mistral_model, + configure_tiny_mistral_tokenizer, configure_tiny_t5_model, configure_tiny_t5_tokenizer) from tests.loggers.test_remote_uploader_downloader import DummyObjectStore +if TYPE_CHECKING: + from peft import PeftConfig + + +def _gpt2_peft_config(): + pytest.importorskip('peft') + from peft import get_peft_config + + peft_config = get_peft_config({ + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'target_modules': ['c_attn'], + 'fan_in_fan_out': True, + }) + return peft_config + + +@pytest.fixture +def gpt2_peft_config(): + return _gpt2_peft_config() + + +def _mistral_peft_config(): + pytest.importorskip('peft') + from peft import get_peft_config + + peft_config = get_peft_config({ + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'target_modules': ['up_proj'], + }) + return peft_config + + +@pytest.fixture +def mistral_peft_config(): + return _mistral_peft_config() + def test_hf_tokenizer_save(tmp_path: Path, tiny_bert_model, tiny_bert_tokenizer): transformers = pytest.importorskip('transformers') @@ -433,14 +472,23 @@ def get_lm_trainer(hf_model, device_train_microbatch_size: Optional[int] = None, batch_size: int = 4, sequence_length: int = 4, - size: int = 4): + size: int = 4, + peft_config: Optional['PeftConfig'] = None, + should_save_peft_only: bool = False): transformers = pytest.importorskip('transformers') metrics: List[Metric] = [LanguageCrossEntropy(ignore_index=-100)] if not is_conditional_generation: metrics.append(MaskedAccuracy(ignore_index=-100)) - model = HuggingFaceModel(hf_model, tokenizer=hf_tokenizer, metrics=metrics, use_logits=True) + model = HuggingFaceModel( + hf_model, + tokenizer=hf_tokenizer, + metrics=metrics, + use_logits=True, + peft_config=peft_config, + should_save_peft_only=should_save_peft_only, + ) vocab_size = hf_model.config.vocab_size sequence_length = 4 @@ -477,8 +525,13 @@ def get_lm_trainer(hf_model, collate_fn=collator, sampler=dist.get_sampler(train_dataset)) + from composer.optim import DecoupledAdamW + + optimizer = DecoupledAdamW(model.parameters(), lr=1e-3) + in_memory_logger = InMemoryLogger() trainer = Trainer(model=model, + optimizers=optimizer, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, max_duration='1ep', @@ -1150,3 +1203,223 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f assert len(generation1) == len(generation2) == 2 assert all(isinstance(decoded_generation, str) for decoded_generation in generation1) assert all(isinstance(decoded_generation, str) for decoded_generation in generation2) + + +@pytest.mark.parametrize('peft_type', ['LORA', 'loRa']) +@pytest.mark.parametrize('task_type', ['CAUSAL_LM', 'causal_lm']) +def test_peft_init(peft_type: str, task_type: str, tiny_gpt2_model, gpt2_peft_config): + pytest.importorskip('peft') + from peft import PeftModelForCausalLM + + peft_config = copy.deepcopy(gpt2_peft_config) + peft_config.peft_type = peft_type + peft_config.task_type = task_type + + original_model = copy.deepcopy(tiny_gpt2_model) + hf_model = HuggingFaceModel(tiny_gpt2_model, peft_config=peft_config) + assert isinstance(hf_model.model, PeftModelForCausalLM) + assert hf_model.model.peft_config['default'].peft_type == 'LORA' + assert hf_model.model.peft_config['default'].task_type == 'CAUSAL_LM' + assert hf_model.model.config == original_model.config + + +def test_peft_init_errors(tiny_gpt2_model, gpt2_peft_config): + pytest.importorskip('peft') + peft_config = copy.deepcopy(gpt2_peft_config) + peft_config.peft_type = 'NOT_LORA' + + with pytest.raises(ValueError): + _ = HuggingFaceModel(tiny_gpt2_model, peft_config=peft_config) + + +def test_peft_init_not_installed(tiny_gpt2_model, gpt2_peft_config): + pytest.importorskip('peft') + + with patch('composer.models.huggingface.peft_installed', False): + with pytest.raises(ImportError): + from composer.models import HuggingFaceModel + _ = HuggingFaceModel(tiny_gpt2_model, peft_config=gpt2_peft_config) + + +@pytest.mark.parametrize('should_save_peft_only', [True, False]) +def test_peft_trains_and_loads(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path, should_save_peft_only): + pytest.importorskip('peft') + + trainer = get_lm_trainer( + tiny_gpt2_model, + tiny_gpt2_tokenizer, + str(tmp_path), + peft_config=gpt2_peft_config, + device_train_microbatch_size=1, + mlm=False, + should_save_peft_only=should_save_peft_only, + ) + trainer.fit() + + load_trainer = get_lm_trainer( + tiny_gpt2_model, + tiny_gpt2_tokenizer, + str(tmp_path), + peft_config=gpt2_peft_config, + device_train_microbatch_size=1, + mlm=False, + load_path=str(tmp_path / 'hf-checkpoint.pt'), + should_save_peft_only=should_save_peft_only, + ) + + for p1, p2 in zip(trainer.state.model.parameters(), load_trainer.state.model.parameters()): + torch.testing.assert_close(p1, p2) + + +@pytest.mark.parametrize('model,tokenizer,peft_config', [ + (configure_tiny_gpt2_model, configure_tiny_gpt2_tokenizer, _gpt2_peft_config()), + (configure_tiny_mistral_model, configure_tiny_mistral_tokenizer, _mistral_peft_config()), +]) +def test_peft_generate(model, tokenizer, peft_config): + pytest.importorskip('peft') + + model = model() + tokenizer = tokenizer() + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + hf_model = HuggingFaceModel(model, tokenizer=tokenizer, peft_config=peft_config) + + input_dict = tokenizer(['hello', 'goodbyes'], return_tensors='pt', padding=True) + hf_model.generate(**input_dict, max_new_tokens=5, pad_token_id=tokenizer.pad_token_id) + + +def test_peft_metadata(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config): + pytest.importorskip('peft') + + from peft import get_peft_config + + hf_model = HuggingFaceModel(tiny_gpt2_model, tokenizer=tiny_gpt2_tokenizer, peft_config=gpt2_peft_config) + metadata = hf_model.get_metadata() + loaded_peft_config = get_peft_config(metadata['model']['peft_config']['content']) + + assert loaded_peft_config == gpt2_peft_config + + +@pytest.mark.parametrize('should_save_peft_only', [True, False]) +def test_peft_write_hf_from_composer(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path, + should_save_peft_only): + peft = pytest.importorskip('peft') + transformers = pytest.importorskip('transformers') + + # Simulate a local model instead of a hub model + tiny_gpt2_model.save_pretrained(tmp_path / 'hf-save-to-load') + tiny_gpt2_model = transformers.AutoModelForCausalLM.from_pretrained(tmp_path / 'hf-save-to-load') + + trainer = get_lm_trainer( + tiny_gpt2_model, + tiny_gpt2_tokenizer, + str(tmp_path), + peft_config=gpt2_peft_config, + device_train_microbatch_size=1, + mlm=False, + should_save_peft_only=should_save_peft_only, + ) + trainer.fit() + + from composer.models.huggingface import write_huggingface_pretrained_from_composer_checkpoint + write_huggingface_pretrained_from_composer_checkpoint(str(tmp_path / 'hf-checkpoint.pt'), + tmp_path / 'hf-save-pretrained') + + # Test we can load back in using transformers interface + loaded_hf_model = transformers.AutoModelForCausalLM.from_pretrained(str(tmp_path / 'hf-save-pretrained')) + for p1, p2 in zip(trainer.state.model.model.parameters(), loaded_hf_model.parameters()): + torch.testing.assert_close(p1, p2) + + # Test we can load back in using peft interface + loaded_peft_model = peft.PeftModelForCausalLM.from_pretrained(tiny_gpt2_model, str(tmp_path / 'hf-save-pretrained')) + for p1, p2 in zip(trainer.state.model.model.parameters(), loaded_peft_model.parameters()): + torch.testing.assert_close(p1, p2) + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('should_save_peft_only', [True, False]) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), + reason='requires PyTorch 1.13 or higher') +def test_peft_fsdp_trains(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path, world_size, + should_save_peft_only): + pytest.importorskip('peft') + + fsdp_config = { + 'sharding_strategy': 'FULL_SHARD', + 'cpu_offload': False, + 'mixed_precision': 'PURE', + 'backward_prefetch': 'BACKWARD_PRE', + 'activation_checkpointing': False, + 'activation_cpu_offload': False, + 'verbose': False + } + + stashed_model = copy.deepcopy(tiny_gpt2_model) + + trainer = get_lm_trainer( + tiny_gpt2_model, + tiny_gpt2_tokenizer, + str(tmp_path / 'trainer1'), + peft_config=gpt2_peft_config, + device_train_microbatch_size=1, + mlm=False, + fsdp_config=fsdp_config, + should_save_peft_only=should_save_peft_only, + ) + + for n, p in trainer.state.model.model.named_parameters(): + if 'lora' in n: + assert p.requires_grad + else: + assert not p.requires_grad + + trainer.fit() + trainer.close() + + load_trainer = get_lm_trainer( + stashed_model, + tiny_gpt2_tokenizer, + str(tmp_path / 'trainer2'), + peft_config=gpt2_peft_config, + device_train_microbatch_size=1, + mlm=False, + load_path=str(tmp_path / 'trainer1' / 'hf-checkpoint.pt'), + fsdp_config=fsdp_config, + should_save_peft_only=should_save_peft_only, + ) + + for n, p in load_trainer.state.model.model.named_parameters(): + if 'lora' in n: + assert p.requires_grad + else: + assert not p.requires_grad + + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + with FSDP.summon_full_params(trainer.state.model), FSDP.summon_full_params(load_trainer.state.model): + for p1, p2 in zip(trainer.state.model.parameters(), load_trainer.state.model.parameters()): + torch.testing.assert_close(p1, p2) + + if dist.get_global_rank() == 0: + loaded_ckpt_1 = torch.load(str(tmp_path / 'trainer1' / 'hf-checkpoint.pt')) + + # Check that only the LoRA parameters were saved + if should_save_peft_only: + assert all('lora' in k for k in loaded_ckpt_1['state']['model'].keys()) + else: + assert not all('lora' in k for k in loaded_ckpt_1['state']['model'].keys()) + + +def test_filtered_state_dict(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path): + pytest.importorskip('peft') + + hf_model = HuggingFaceModel(tiny_gpt2_model, + tokenizer=tiny_gpt2_tokenizer, + peft_config=gpt2_peft_config, + should_save_peft_only=True) + state_dict = hf_model.state_dict() + + assert len(state_dict.keys()) == 4