From e9d4c4c9ba6b02484a5055139a1f631d4eaf80ce Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Thu, 18 Jan 2024 01:05:50 -0800
Subject: [PATCH] precommit

---
 composer/datasets/utils.py                    |  6 ++--
 composer/models/huggingface.py                | 34 ++++++++++++++-----
 .../test_in_context_learning_datasets.py      |  4 +--
 3 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/composer/datasets/utils.py b/composer/datasets/utils.py
index 431a860900..6bb376ff30 100644
--- a/composer/datasets/utils.py
+++ b/composer/datasets/utils.py
@@ -179,7 +179,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     def __init__(
         self,
         stop_sequence: str,
-        tokenizer: transformers.PreTrainedTokenizer,
+        tokenizer: transformers.PreTrainedTokenizerBase,
         batch_size: int,
     ) -> None:
         self.done_tracker = [False] * batch_size
@@ -196,7 +196,7 @@ def __init__(
         self.stop_sequence_id_len = len(self.stop_sequence_ids) + 2
         self.tokenizer = tokenizer
 
-    def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
+    def __call__(self, input_ids: torch.LongTensor, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
         # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
         lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:]
 
@@ -213,7 +213,7 @@ def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwar
         return False not in self.done_tracker
 
 def stop_sequences_criteria(
-    tokenizer: transformers.PreTrainedTokenizer,
+    tokenizer: transformers.PreTrainedTokenizerBase,
     stop_sequences: List[str],
     batch_size: int,
 ) -> transformers.StoppingCriteriaList:
diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py
index dbeda8a6a6..2dbeabbe34 100644
--- a/composer/models/huggingface.py
+++ b/composer/models/huggingface.py
@@ -13,6 +13,7 @@
 import string
 import tempfile
 import textwrap
+import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Type, Union
 
@@ -159,7 +160,8 @@ def __init__(self,
 
         self.labels: Optional[torch.Tensor] = None  # set in eval_forward() if exists
 
-        is_causal_lm = _is_registered_causal_lm(model)
+        is_causal_lm = _is_registered_causal_lm(self.model)
+
         self.shift_labels = is_causal_lm if shift_labels is None else shift_labels
         if is_causal_lm and not self.shift_labels:
             log.warning('The shift_labels argument was set to False but the model is an instance of a'
@@ -169,16 +171,21 @@
         self.dummy_forward_called = False
 
         if peft_config is not None:
-            self.model = get_peft_model(self.model, peft_config)
-            log.info(f'PEFT model created. {self.model}')
+            from peft import PeftModel
+            if isinstance(self.model, PeftModel):
+                warnings.warn('PEFT model was passed in directly. Ignoring the provided PEFT config.')
+            else:
+                self.model = get_peft_model(self.model, peft_config)
+                log.info(f'PEFT model created. {self.model}')
 
     def state_dict(self, *args, **kwargs) -> Dict[str, Any]:
         """Returns the state dict of the model."""
         full_state_dict = super().state_dict(*args, **kwargs)
 
         if self.peft_filter_state_dict_trainable:
-            full_state_dict = filter_state_dict_peft(full_state_dict, self.model.peft_config[self.model.active_adapter],
-                                                     False)
+            active_adapter = self.model.active_adapter
+            assert isinstance(active_adapter, str)
+            full_state_dict = filter_state_dict_peft(full_state_dict, self.model.peft_config[active_adapter], False)
 
         return full_state_dict
 
@@ -474,6 +481,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
 
         # HF encoder decoder models like T5 expect either decoder_input_ids or labels,
         # so we add decoder_input_ids to the batch if it is missing
+        assert isinstance(self.model.config, PretrainedConfig)
         model_config: PretrainedConfig = self.model.config
         if model_config.is_encoder_decoder and 'decoder_input_ids' not in batch:
             if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
@@ -532,6 +540,7 @@ def get_metadata(self):
             model_dir = tmp_dir / 'model'
             tokenizer_dir = tmp_dir / 'tokenizer'
 
+            assert isinstance(self.model.config, PretrainedConfig)
             original_model_config: PretrainedConfig = self.model.config
             original_model_config.save_pretrained(model_dir)
             if self.tokenizer is not None:
@@ -615,6 +624,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
         if not using_torch_2() and not self.dummy_forward_called and is_model_fsdp(self.model):
             with torch.no_grad():
                 maybe_decoder_input_ids = {}
+                assert isinstance(self.model.config, PretrainedConfig)
                 model_config: PretrainedConfig = self.model.config
                 if model_config.is_encoder_decoder:
                     maybe_decoder_input_ids['decoder_input_ids'] = torch.tensor([[0]],
@@ -638,7 +648,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
 
         return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs)
 
 
-def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
+def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftModel']) -> bool:
     """Return True if model class is either a registered 🤗 Causal LM or a subclass of one"""
     try:
         from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
@@ -647,6 +657,11 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
                                             conda_package='transformers',
                                             conda_channel='conda-forge') from e
 
+    if _peft_installed and isinstance(model, PeftModel):
+        model_to_check = model.base_model.model
+    else:
+        model_to_check = model
+
     # This try/except is needed until https://github.com/huggingface/transformers/issues/26778
     # is resolved in a release. This means that this attempt to automatically detect causal LMs
     # does not currently work in an environment with flash attention <2 installed.
@@ -658,7 +673,7 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
             return False
         else:
             raise e
-    return any(isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes)  # type: ignore
+    return any(isinstance(model_to_check, causal_lm_class) for causal_lm_class in causal_lm_classes)  # type: ignore
 
 
 def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any],
@@ -819,8 +834,11 @@ def write_huggingface_pretrained_from_composer_checkpoint(
 def filter_state_dict_peft(state_dict: Dict[str, Any],
                            peft_config: 'PeftConfig',
                            remove_adapter_names: bool = True) -> Dict[str, Any]:
+    if peft_config.peft_type != 'LORA':
+        raise NotImplementedError(f'Only LoRA PEFT is supported. Got {peft_config.peft_type}')
+
     # Filtering copied from https://github.com/huggingface/peft/blob/4186c9b104644fd247a4cc0dc2dfc1ede4665204/src/peft/utils/save_and_load.py#L68C1-L86C116
-    bias = peft_config.bias
+    bias = peft_config.bias  # type: ignore
     if bias == 'none':
         to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k}
     elif bias == 'all':
diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py
index ec7df306d6..2a3ff87884 100644
--- a/tests/datasets/test_in_context_learning_datasets.py
+++ b/tests/datasets/test_in_context_learning_datasets.py
@@ -73,13 +73,13 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
     seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq1 = [50257] * (len(seq2) - len(seq1)) + seq1
-    input_ids = torch.tensor([seq1, seq2])
+    input_ids = torch.LongTensor([seq1, seq2])
     assert not eos_criteria(input_ids, None)
 
     eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
     seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
-    input_ids = torch.tensor([seq1, seq2])
+    input_ids = torch.LongTensor([seq1, seq2])
     assert eos_criteria(input_ids, None)
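
The composer/datasets/utils.py hunks only tighten type hints: PreTrainedTokenizerBase covers both slow and fast tokenizers, and input_ids is annotated as torch.LongTensor to match the integer token ids that generation passes in. A minimal sketch of how the criteria behave, mirroring the updated test; the 'gpt2' tokenizer here is an illustrative stand-in for the tiny_gpt2_tokenizer fixture:

    import torch
    import transformers

    from composer.datasets.utils import MultiTokenEOSCriteria

    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2)

    seq = tokenizer('Dogs are furry\n\n')['input_ids']
    # Integer token ids, as the new LongTensor annotation requires
    input_ids = torch.LongTensor([seq, seq])
    # True: every sequence in the batch ends with the stop sequence
    assert eos_criteria(input_ids, None)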
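In composer/models/huggingface.py, the user-visible change is in __init__: a model that is already a PeftModel is no longer wrapped a second time; the provided peft_config is ignored with a warning instead. A hedged sketch of both paths, assuming peft is installed; the model name and LoRA hyperparameters below are placeholders:

    import transformers
    from peft import LoraConfig, get_peft_model

    from composer.models import HuggingFaceModel

    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    peft_config = LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16)

    # Plain HF model plus a config: Composer calls get_peft_model and wraps it.
    model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    composer_model = HuggingFaceModel(model, tokenizer=tokenizer, peft_config=peft_config)

    # Already-wrapped PeftModel plus a config: the config is now ignored with a
    # warning rather than the model being double-wrapped.
    peft_model = get_peft_model(transformers.AutoModelForCausalLM.from_pretrained('gpt2'), peft_config)
    composer_model = HuggingFaceModel(peft_model, tokenizer=tokenizer, peft_config=peft_config)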
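filter_state_dict_peft now fails fast on non-LoRA configs instead of silently applying LoRA-specific filtering (the bias logic it copies from peft assumes LoRA key names). A sketch of the guard and the bias='none' branch; the state-dict keys below are illustrative, not taken from a real checkpoint:

    import torch
    from peft import LoraConfig

    from composer.models.huggingface import filter_state_dict_peft

    lora_config = LoraConfig(task_type='CAUSAL_LM', bias='none')
    state_dict = {
        'model.transformer.h.0.attn.c_attn.lora_A.default.weight': torch.zeros(8, 768),
        'model.transformer.h.0.attn.c_attn.weight': torch.zeros(2304, 768),
    }

    # With bias='none', only keys containing 'lora_' survive; the frozen base
    # weight is dropped from the returned state dict.
    filtered = filter_state_dict_peft(state_dict, lora_config, remove_adapter_names=False)

    # Any config whose peft_type is not 'LORA' (e.g. prompt tuning) now raises
    # NotImplementedError instead of returning a wrongly filtered state dict.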