Commit e9d4c4c
precommit
dakinggg committed Jan 18, 2024
1 parent 4da1d94 commit e9d4c4c
Showing 3 changed files with 31 additions and 13 deletions.
6 changes: 3 additions & 3 deletions composer/datasets/utils.py
@@ -179,7 +179,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     def __init__(
         self,
         stop_sequence: str,
-        tokenizer: transformers.PreTrainedTokenizer,
+        tokenizer: transformers.PreTrainedTokenizerBase,
         batch_size: int,
     ) -> None:
         self.done_tracker = [False] * batch_size
@@ -196,7 +196,7 @@ def __init__(
         self.stop_sequence_id_len = len(self.stop_sequence_ids) + 2
         self.tokenizer = tokenizer

-    def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
+    def __call__(self, input_ids: torch.LongTensor, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
         # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
         lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:]

@@ -213,7 +213,7 @@ def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
         return False not in self.done_tracker

 def stop_sequences_criteria(
-    tokenizer: transformers.PreTrainedTokenizer,
+    tokenizer: transformers.PreTrainedTokenizerBase,
     stop_sequences: List[str],
     batch_size: int,
 ) -> transformers.StoppingCriteriaList:
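For orientation, a minimal usage sketch of these stopping criteria with Hugging Face generation follows; the 'gpt2' checkpoint, prompt, and generation settings are illustrative assumptions, not part of this commit.

# Hedged usage sketch: checkpoint, prompt, and settings below are assumptions.
import transformers

from composer.datasets.utils import stop_sequences_criteria

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('Q: Are dogs furry?\nA:', return_tensors='pt')
# Stop generating once the model emits a double newline; batch_size must match
# the number of sequences being generated.
criteria = stop_sequences_criteria(tokenizer, stop_sequences=['\n\n'], batch_size=1)
output_ids = model.generate(**inputs, max_new_tokens=32, stopping_criteria=criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))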
34 changes: 26 additions & 8 deletions composer/models/huggingface.py
@@ -13,6 +13,7 @@
 import string
 import tempfile
 import textwrap
+import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Type, Union

@@ -159,7 +160,8 @@ def __init__(self,

         self.labels: Optional[torch.Tensor] = None  # set in eval_forward() if exists

-        is_causal_lm = _is_registered_causal_lm(model)
+        is_causal_lm = _is_registered_causal_lm(self.model)
+
         self.shift_labels = is_causal_lm if shift_labels is None else shift_labels
         if is_causal_lm and not self.shift_labels:
             log.warning('The shift_labels argument was set to False but the model is an instance of a'
@@ -169,16 +171,21 @@
         self.dummy_forward_called = False

         if peft_config is not None:
-            self.model = get_peft_model(self.model, peft_config)
-            log.info(f'PEFT model created. {self.model}')
+            from peft import PeftModel
+            if isinstance(self.model, PeftModel):
+                warnings.warn('PEFT model was passed in directly. Ignoring the provided PEFT config.')
+            else:
+                self.model = get_peft_model(self.model, peft_config)
+                log.info(f'PEFT model created. {self.model}')

     def state_dict(self, *args, **kwargs) -> Dict[str, Any]:
         """Returns the state dict of the model."""
         full_state_dict = super().state_dict(*args, **kwargs)

         if self.peft_filter_state_dict_trainable:
-            full_state_dict = filter_state_dict_peft(full_state_dict, self.model.peft_config[self.model.active_adapter],
-                                                     False)
+            active_adapter = self.model.active_adapter
+            assert isinstance(active_adapter, str)
+            full_state_dict = filter_state_dict_peft(full_state_dict, self.model.peft_config[active_adapter], False)

         return full_state_dict
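Because this hunk changes how PEFT wrapping is applied, a short hedged sketch of the two construction paths may help; the checkpoint name and LoRA hyperparameters are illustrative assumptions rather than values from this commit.

# Hedged sketch: 'gpt2' and the LoRA settings are placeholder assumptions.
import transformers
from peft import LoraConfig, get_peft_model

from composer.models import HuggingFaceModel

base = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
lora = LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16)

# Path 1: pass a plain HF model and let HuggingFaceModel apply the PEFT config.
composer_model = HuggingFaceModel(base, peft_config=lora)

# Path 2: pass an already-wrapped PeftModel; after this change the provided
# peft_config is ignored with a warning instead of double-wrapping the model.
composer_model = HuggingFaceModel(get_peft_model(base, lora), peft_config=lora)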

@@ -474,6 +481,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):

         # HF encoder decoder models like T5 expect either decoder_input_ids or labels,
         # so we add decoder_input_ids to the batch if it is missing
+        assert isinstance(self.model.config, PretrainedConfig)
         model_config: PretrainedConfig = self.model.config
         if model_config.is_encoder_decoder and 'decoder_input_ids' not in batch:
             if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
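The assert added here (and in the later hunks) is a type-narrowing guard for static checkers rather than a behavioral change; a minimal standalone illustration, with an assumed checkpoint name, follows.

# Standalone sketch of assert-based type narrowing; 'gpt2' is an assumption.
from transformers import AutoConfig, PretrainedConfig

config = AutoConfig.from_pretrained('gpt2')
assert isinstance(config, PretrainedConfig)  # narrows the type for the checker
print(config.is_encoder_decoder)  # attribute access now type-checks cleanly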
@@ -532,6 +540,7 @@ def get_metadata(self):
             model_dir = tmp_dir / 'model'
             tokenizer_dir = tmp_dir / 'tokenizer'

+            assert isinstance(self.model.config, PretrainedConfig)
             original_model_config: PretrainedConfig = self.model.config
             original_model_config.save_pretrained(model_dir)
             if self.tokenizer is not None:
@@ -615,6 +624,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
         if not using_torch_2() and not self.dummy_forward_called and is_model_fsdp(self.model):
             with torch.no_grad():
                 maybe_decoder_input_ids = {}
+                assert isinstance(self.model.config, PretrainedConfig)
                 model_config: PretrainedConfig = self.model.config
                 if model_config.is_encoder_decoder:
                     maybe_decoder_input_ids['decoder_input_ids'] = torch.tensor([[0]],
@@ -638,7 +648,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
         return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs)


-def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
+def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftModel']) -> bool:
     """Return True if model class is either a registered 🤗 Causal LM or a subclass of one"""
     try:
         from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
@@ -647,6 +657,11 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
                                             conda_package='transformers',
                                             conda_channel='conda-forge') from e

+    if _peft_installed and isinstance(model, PeftModel):
+        model_to_check = model.base_model.model
+    else:
+        model_to_check = model
+
     # This try/except is needed until https://github.com/huggingface/transformers/issues/26778
     # is resolved in a release. This means that this attempt to automatically detect causal LMs
     # does not currently work in an environment with flash attention <2 installed.
@@ -658,7 +673,7 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
             return False
         else:
             raise e
-    return any(isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes)  # type: ignore
+    return any(isinstance(model_to_check, causal_lm_class) for causal_lm_class in causal_lm_classes)  # type: ignore


 def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any],
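The unwrap-before-isinstance pattern introduced above can be shown in isolation. This is a hedged sketch assuming peft is installed; the checkpoint name is illustrative, and per the comment in the diff, iterating the mapping can raise ImportError in environments with flash attention <2.

# Hedged sketch of the PeftModel unwrapping pattern; 'gpt2' is an assumption.
import transformers
from peft import LoraConfig, PeftModel, get_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING

base = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
wrapped = get_peft_model(base, LoraConfig(task_type='CAUSAL_LM'))

# A PeftModel is not itself a registered causal LM class, so the isinstance
# check must target the wrapped base transformer instead.
model_to_check = wrapped.base_model.model if isinstance(wrapped, PeftModel) else wrapped
causal_lm_classes = list(MODEL_FOR_CAUSAL_LM_MAPPING.values())  # may ImportError, see above
print(any(isinstance(model_to_check, cls) for cls in causal_lm_classes))  # True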
@@ -819,8 +834,11 @@ def write_huggingface_pretrained_from_composer_checkpoint(
 def filter_state_dict_peft(state_dict: Dict[str, Any],
                            peft_config: 'PeftConfig',
                            remove_adapter_names: bool = True) -> Dict[str, Any]:
+    if peft_config.peft_type != 'LORA':
+        raise NotImplementedError(f'Only LoRA PEFT is supported. Got {peft_config.peft_type}')
+
     # Filtering copied from https://github.com/huggingface/peft/blob/4186c9b104644fd247a4cc0dc2dfc1ede4665204/src/peft/utils/save_and_load.py#L68C1-L86C116
-    bias = peft_config.bias
+    bias = peft_config.bias  # type: ignore
     if bias == 'none':
         to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k}
     elif bias == 'all':
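To make the filtering concrete, here is a hedged, self-contained sketch of the bias == 'none' branch; the state-dict keys are representative of peft's LoRA naming, not copied from this diff.

# Self-contained sketch of the 'none'-bias branch; keys are representative examples.
from typing import Any, Dict

state_dict: Dict[str, Any] = {
    'transformer.h.0.attn.c_attn.weight': 0,                 # frozen base weight
    'transformer.h.0.attn.c_attn.lora_A.default.weight': 0,  # LoRA adapter
    'transformer.h.0.attn.c_attn.lora_B.default.weight': 0,  # LoRA adapter
}

# With bias == 'none', only LoRA adapter weights survive the filter.
to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k}
assert len(to_return) == 2 and all('lora_' in k for k in to_return)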
4 changes: 2 additions & 2 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -73,13 +73,13 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
     seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq1 = [50257] * (len(seq2) - len(seq1)) + seq1
-    input_ids = torch.tensor([seq1, seq2])
+    input_ids = torch.LongTensor([seq1, seq2])
     assert not eos_criteria(input_ids, None)

     eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
     seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
-    input_ids = torch.tensor([seq1, seq2])
+    input_ids = torch.LongTensor([seq1, seq2])
     assert eos_criteria(input_ids, None)


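As a hedged aside on why the tests switched constructors: torch.LongTensor always produces an int64 tensor, matching the new input_ids: torch.LongTensor annotation, whereas torch.tensor infers its dtype from the data.

import torch

seq = [[1, 2, 3], [4, 5, 6]]
a = torch.tensor(seq)      # dtype inferred from the Python ints (int64 here)
b = torch.LongTensor(seq)  # dtype is always torch.int64 by construction
assert a.dtype == b.dtype == torch.int64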
