From 544ae05bfc918f8b95a8864446d22882d6df1737 Mon Sep 17 00:00:00 2001 From: Henning Buhl Date: Sat, 22 Jul 2023 20:34:11 +0200 Subject: [PATCH] Added adapters to whisper --- adapter_docs/classes/models/whisper.rst | 17 ++ adapter_docs/index.rst | 1 + adapter_docs/model_overview.md | 1 + src/transformers/__init__.py | 4 + src/transformers/adapters/__init__.py | 5 + src/transformers/adapters/composition.py | 1 + src/transformers/adapters/head_utils.py | 7 + src/transformers/adapters/mixins/whisper.py | 51 ++++ .../adapters/models/auto/adapter_model.py | 2 + .../adapters/models/whisper/__init__.py | 24 ++ .../adapters/models/whisper/adapter_model.py | 227 ++++++++++++++++++ src/transformers/adapters/prefix_tuning.py | 2 +- .../adapters/wrappers/configuration.py | 4 + .../models/whisper/modeling_whisper.py | 107 +++++++-- tests_adapters/models/test_whisper.py | 11 + tests_adapters/test_whisper.py | 60 +++++ utils/check_adapters.py | 3 + 17 files changed, 501 insertions(+), 26 deletions(-) create mode 100644 adapter_docs/classes/models/whisper.rst create mode 100644 src/transformers/adapters/mixins/whisper.py create mode 100644 src/transformers/adapters/models/whisper/__init__.py create mode 100644 src/transformers/adapters/models/whisper/adapter_model.py create mode 100644 tests_adapters/models/test_whisper.py create mode 100644 tests_adapters/test_whisper.py diff --git a/adapter_docs/classes/models/whisper.rst b/adapter_docs/classes/models/whisper.rst new file mode 100644 index 0000000000..3d78b15184 --- /dev/null +++ b/adapter_docs/classes/models/whisper.rst @@ -0,0 +1,17 @@ +Whisper +----------------------------------------------------------------------------------------------------------------------- + +The Whisper model was presented in `Robust Speech Recognition via Large-Scale Weak Supervision +`_ by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine +McLeavey, Ilya Sutskever. + +According to the abstract, Whisper is trained on 680,000 hours of multilingual and multitask data. This +scale was previously unseen. Whisper is able to approach the accuracy and robustness of humans. + + +WhisperAdapterModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.adapters.WhisperAdapterModel + :members: + :inherited-members: WhisperPreTrainedModel diff --git a/adapter_docs/index.rst b/adapter_docs/index.rst index bbf0de28c3..b06a23e9d4 100644 --- a/adapter_docs/index.rst +++ b/adapter_docs/index.rst @@ -70,6 +70,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/gpt2 classes/models/gptj classes/models/mbart + classes/models/whisper classes/models/roberta classes/models/t5 classes/models/vit diff --git a/adapter_docs/model_overview.md b/adapter_docs/model_overview.md index bb54cca7e0..fdab565866 100644 --- a/adapter_docs/model_overview.md +++ b/adapter_docs/model_overview.md @@ -25,6 +25,7 @@ The table below further shows which model architectures support which adaptation | [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Whisper](classes/models/whisper.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a507a51e20..a6d5ac2c60 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2604,6 +2604,8 @@ "MAMConfig", "MBartAdapterModel", "MBartModelWithHeads", + "WhisperAdapterModel", + "WhisperModelWithHeads", "ModelAdaptersConfig", "ModelAdaptersMixin", "ModelWithFlexibleHeadsAdaptersMixin", @@ -5708,6 +5710,8 @@ MAMConfig, MBartAdapterModel, MBartModelWithHeads, + WhisperAdapterModel, + WhisperModelWithHeads, ModelAdaptersConfig, ModelAdaptersMixin, ModelWithFlexibleHeadsAdaptersMixin, diff --git a/src/transformers/adapters/__init__.py b/src/transformers/adapters/__init__.py index fea1f63968..6417361046 100644 --- a/src/transformers/adapters/__init__.py +++ b/src/transformers/adapters/__init__.py @@ -118,6 +118,10 @@ "MBartAdapterModel", "MBartModelWithHeads", ], + "models.whisper": [ + "WhisperAdapterModel", + "WhisperModelWithHeads", + ], "models.roberta": [ "RobertaAdapterModel", "RobertaModelWithHeads", @@ -219,6 +223,7 @@ from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads from .models.gptj import GPTJAdapterModel from .models.mbart import MBartAdapterModel, MBartModelWithHeads + from .models.whisper import WhisperAdapterModel, WhisperModelWithHeads from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads from .models.t5 import T5AdapterModel, T5ModelWithHeads from .models.vit import ViTAdapterModel diff --git a/src/transformers/adapters/composition.py b/src/transformers/adapters/composition.py index 32ab776dc0..5cd9972501 100644 --- a/src/transformers/adapters/composition.py +++ b/src/transformers/adapters/composition.py @@ -108,6 +108,7 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], b "deberta", "bart", "mbart", + "whisper", "gpt2", "gptj", "t5", diff --git a/src/transformers/adapters/head_utils.py b/src/transformers/adapters/head_utils.py index 05f3a54566..6e1b77e582 100644 --- a/src/transformers/adapters/head_utils.py +++ b/src/transformers/adapters/head_utils.py @@ -314,6 +314,13 @@ }, "layers": ["lm_head"], }, + # Whisper + "WhisperForConditionalGeneration": { + "config": { + "head_type": "seq2seq_lm", + }, + "layers": ["proj_out"], + }, # DistilBERT "DistilBertForSequenceClassification": { "config": { diff --git a/src/transformers/adapters/mixins/whisper.py b/src/transformers/adapters/mixins/whisper.py new file mode 100644 index 0000000000..3bdf23208e --- /dev/null +++ b/src/transformers/adapters/mixins/whisper.py @@ -0,0 +1,51 @@ +from typing import Iterable, Tuple + +import torch.nn as nn + +from ..layer import AdapterLayer +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersWrapperMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) + + +class WhisperEncoderLayerAdaptersMixin: + """Adds adapters to the WhisperEncoderLayer module of WHISPER.""" + + def _init_adapter_modules(self): + self.attention_adapters = AdapterLayer("mh_adapter", self.config) + self.output_adapters = AdapterLayer("output_adapter", self.config) + self.attention_adapters._init_adapter_modules() + self.output_adapters._init_adapter_modules() + + +class WhisperDecoderLayerAdaptersMixin(WhisperEncoderLayerAdaptersMixin): + """Adds adapters to the WhisperDecoderLayer module of WHISPER.""" + + def _init_adapter_modules(self): + super()._init_adapter_modules() + self.cross_attention_adapters = AdapterLayer("cross_adapter", self.config) + self.cross_attention_adapters._init_adapter_modules() + + +class WhisperModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin): + """Adds adapters to the WhisperModel class.""" + + invertible_adapters_base_name = "encoder" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "encoder"): + for i, layer in enumerate(self.encoder.layers): + yield i, layer + for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): + yield i, layer + else: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + +class WhisperModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/models/auto/adapter_model.py b/src/transformers/adapters/models/auto/adapter_model.py index cfd159bad6..5fcf2755c6 100644 --- a/src/transformers/adapters/models/auto/adapter_model.py +++ b/src/transformers/adapters/models/auto/adapter_model.py @@ -19,6 +19,7 @@ ("deberta", "DebertaAdapterModel"), ("bart", "BartAdapterModel"), ("mbart", "MBartAdapterModel"), + ("whisper", "WhisperAdapterModel"), ("gpt2", "GPT2AdapterModel"), ("gptj", "GPTJAdapterModel"), ("t5", "T5AdapterModel"), @@ -33,6 +34,7 @@ ("distilbert", "DistilBertModelWithHeads"), ("bart", "BartModelWithHeads"), ("mbart", "MBartModelWithHeads"), + ("whisper", "WhisperModelWithHeads"), ("gpt2", "GPT2ModelWithHeads"), ("t5", "T5ModelWithHeads"), ] diff --git a/src/transformers/adapters/models/whisper/__init__.py b/src/transformers/adapters/models/whisper/__init__.py new file mode 100644 index 0000000000..8ac73c1e36 --- /dev/null +++ b/src/transformers/adapters/models/whisper/__init__.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from ....utils import _LazyModule + + +_import_structure = { + "adapter_model": [ + "WhisperAdapterModel", + "WhisperModelWithHeads", + ], +} + + +if TYPE_CHECKING: + from .adapter_model import WhisperAdapterModel, WhisperModelWithHeads + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/transformers/adapters/models/whisper/adapter_model.py b/src/transformers/adapters/models/whisper/adapter_model.py new file mode 100644 index 0000000000..458f04c224 --- /dev/null +++ b/src/transformers/adapters/models/whisper/adapter_model.py @@ -0,0 +1,227 @@ +import warnings + +import torch + +from ....models.whisper.modeling_whisper import ( + WHISPER_INPUTS_DOCSTRING, + WHISPER_START_DOCSTRING, + WhisperConfig, + WhisperModel, + WhisperPreTrainedModel, + shift_tokens_right, +) +from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...composition import adjust_tensors_for_parallel +from ...heads import ( + ClassificationHead, + ModelWithFlexibleHeadsAdaptersMixin, + MultiLabelClassificationHead, + QuestionAnsweringHead, + Seq2SeqLMHead, +) +from ...model_mixin import EmbeddingAdaptersWrapperMixin + + +@add_start_docstrings( + "WHISPER Model with the option to add multiple flexible prediction heads on top.", WHISPER_START_DOCSTRING +) +class WhisperAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, WhisperPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"proj_out.weight"] + + def __init__(self, config: WhisperConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = WhisperModel(config) + + self._init_head_modules() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs + ): + # TODO What should be done before the original model.forward call? + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + ) + + # TODO What should be done after the original model.forward call? + # TODO Language detection? + + head_outputs = self.forward_head( + outputs, + head_name=head, + attention_mask=attention_mask, + return_dict=return_dict, + **kwargs, + ) + + return head_outputs + + # Copied from WhisperForConditionalGeneration + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "use_cache": use_cache, + "decoder_attention_mask": None, + } + + # Copied from WhisperForConditionalGeneration + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + + head_types = { + "classification": ClassificationHead, + "multilabel_classification": MultiLabelClassificationHead, + "question_answering": QuestionAnsweringHead, + "seq2seq_lm": Seq2SeqLMHead, + } + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. + """ + + if multilabel: + head = MultiLabelClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) + else: + head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_qa_head( + self, + head_name, + num_labels=2, + layers=1, + activation_function="tanh", + overwrite_ok=False, + id2label=None, + ): + head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_seq2seq_lm_head( + self, + head_name, + overwrite_ok=False, + ): + """ + Adds a sequence-to-sequence language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = Seq2SeqLMHead(self, head_name) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + +class WhisperModelWithHeads(WhisperAdapterModel): + def __init__(self, *args, **kwargs): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + self.__class__.__bases__[0].__name__ + ), + FutureWarning, + ) + super().__init__(*args, **kwargs) + + @classmethod + def from_config(cls, config): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + cls.__bases__[0].__name__ + ), + FutureWarning, + ) + return super().from_config(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + cls.__bases__[0].__name__ + ), + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) diff --git a/src/transformers/adapters/prefix_tuning.py b/src/transformers/adapters/prefix_tuning.py index d3967f634e..580b82e9cc 100644 --- a/src/transformers/adapters/prefix_tuning.py +++ b/src/transformers/adapters/prefix_tuning.py @@ -202,7 +202,7 @@ def forward(self, *args, **kwargs): prefix_states = {} if adapter_setup is not None: # Infer batch size - input_tensor_names = ["input_ids", "decoder_input_ids", "attention_mask", "inputs_embeds", "pixel_values"] + input_tensor_names = ["input_ids", "decoder_input_ids", "attention_mask", "inputs_embeds", "pixel_values", "input_features"] batch_size = None for name in input_tensor_names: if kwargs.get(name, None) is not None: diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py index 3506d93f70..724532fa8f 100644 --- a/src/transformers/adapters/wrappers/configuration.py +++ b/src/transformers/adapters/wrappers/configuration.py @@ -45,6 +45,10 @@ "hidden_dropout_prob": "dropout", "attention_probs_dropout_prob": "attention_dropout", }, + "whisper": { + "hidden_dropout_prob": "dropout", + "attention_probs_dropout_prob": "attention_dropout", + }, "roberta": {}, "t5": { "hidden_size": "d_model", diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index fb03b63de1..75f4c9976e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Whisper model.""" - - +import copy import math import random from typing import Optional, Tuple, Union @@ -25,14 +24,35 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext +from ...adapters.lora import Linear as LoRALinear +from ...adapters.mixins.whisper import ( + WhisperDecoderLayerAdaptersMixin, + WhisperEncoderLayerAdaptersMixin, + WhisperModelAdaptersMixin, + WhisperModelWithHeadsAdaptersMixin, +) +from ...adapters.model_mixin import InvertibleAdaptersMixin +from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_whisper import WhisperConfig @@ -110,11 +130,13 @@ class WhisperAttention(nn.Module): def __init__( self, + config: WhisperConfig, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + location_key: Optional[str] = None, ): super().__init__() self.embed_dim = embed_dim @@ -130,10 +152,12 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="k", bias=False) + self.v_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="v", bias=bias) + self.q_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="q", bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.prefix_tuning = PrefixTuningShim(location_key + "_prefix" if location_key else None, config) # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -197,6 +221,12 @@ def forward( past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -253,7 +283,7 @@ def forward( attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. + # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) @@ -262,23 +292,29 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper -class WhisperEncoderLayer(nn.Module): +class WhisperEncoderLayer(WhisperEncoderLayerAdaptersMixin, nn.Module): def __init__(self, config: WhisperConfig): super().__init__() + self.config = config + self.embed_dim = config.d_model self.self_attn = WhisperAttention( + config, embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + location_key="encoder", ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.fc1 = LoRALinear(self.embed_dim, config.encoder_ffn_dim, "intermediate", config) + self.fc2 = LoRALinear(config.encoder_ffn_dim, self.embed_dim, "output", config) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self._init_adapter_modules() + def forward( self, hidden_states: torch.Tensor, @@ -306,7 +342,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + hidden_states = self.attention_adapters(hidden_states, residual, None) residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -314,7 +350,7 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + hidden_states = self.attention_adapters(hidden_states, residual, None) if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() @@ -331,16 +367,20 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper -class WhisperDecoderLayer(nn.Module): +class WhisperDecoderLayer(WhisperDecoderLayerAdaptersMixin, nn.Module): def __init__(self, config: WhisperConfig): super().__init__() + self.config = config + self.embed_dim = config.d_model self.self_attn = WhisperAttention( + config, embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + location_key="self", ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -348,16 +388,20 @@ def __init__(self, config: WhisperConfig): self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn = WhisperAttention( + config, self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + location_key="cross", ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.fc1 = LoRALinear(self.embed_dim, config.encoder_ffn_dim, "intermediate", config) + self.fc2 = LoRALinear(config.encoder_ffn_dim, self.embed_dim, "output", config) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self._init_adapter_modules() + def forward( self, hidden_states: torch.Tensor, @@ -403,7 +447,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + hidden_states = self.attention_adapters(hidden_states, residual, None) # Cross-Attention Block cross_attn_present_key_value = None @@ -423,7 +467,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + hidden_states = self.cross_attention_adapters(hidden_states, residual, None) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value @@ -435,7 +479,7 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + hidden_states = self.output_adapters(hidden_states, residual, None) outputs = (hidden_states,) @@ -453,7 +497,6 @@ class WhisperPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True - _no_split_modules = ["WhisperEncoderLayer"] def _init_weights(self, module): std = self.config.init_std @@ -573,7 +616,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ -class WhisperEncoder(WhisperPreTrainedModel): +class WhisperEncoder(InvertibleAdaptersMixin, WhisperPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`WhisperEncoderLayer`]. @@ -585,6 +628,8 @@ class WhisperEncoder(WhisperPreTrainedModel): def __init__(self, config: WhisperConfig): super().__init__(config) + self.config = config + self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop @@ -659,6 +704,8 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.invertible_adapters_forward(hidden_states) + encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -668,6 +715,7 @@ def forward( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) @@ -699,11 +747,13 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) @@ -752,7 +802,6 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None - if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length @@ -760,7 +809,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) @@ -936,6 +987,7 @@ def custom_forward(*inputs): use_cache=use_cache, ) hidden_states = layer_outputs[0] + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if use_cache: next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) @@ -947,6 +999,7 @@ def custom_forward(*inputs): all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -971,7 +1024,7 @@ def custom_forward(*inputs): "The bare Whisper Model outputting raw hidden-states without any specific head on top.", WHISPER_START_DOCSTRING, ) -class WhisperModel(WhisperPreTrainedModel): +class WhisperModel(WhisperModelAdaptersMixin, WhisperPreTrainedModel): _keys_to_ignore_on_load_missing = [r"proj_out.weight"] def __init__(self, config: WhisperConfig): @@ -979,7 +1032,9 @@ def __init__(self, config: WhisperConfig): self.encoder = WhisperEncoder(config) self.decoder = WhisperDecoder(config) - # Initialize weights and apply final processing + + self._init_adapter_modules() + self.post_init() def get_input_embeddings(self): @@ -1003,6 +1058,7 @@ def freeze_encoder(self): @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + @ForwardContext.wrap def forward( self, input_features: Optional[torch.LongTensor] = None, @@ -1095,7 +1151,7 @@ def forward( "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.", WHISPER_START_DOCSTRING, ) -class WhisperForConditionalGeneration(WhisperPreTrainedModel): +class WhisperForConditionalGeneration(WhisperModelWithHeadsAdaptersMixin, WhisperPreTrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ r"encoder.version", @@ -1208,6 +1264,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) + lm_logits = self.model.encoder.invertible_adapters_forward(outputs[0], rev=True) lm_logits = self.proj_out(outputs[0]) loss = None diff --git a/tests_adapters/models/test_whisper.py b/tests_adapters/models/test_whisper.py new file mode 100644 index 0000000000..a992f3ecd7 --- /dev/null +++ b/tests_adapters/models/test_whisper.py @@ -0,0 +1,11 @@ +from tests.models.whisper.test_modeling_whisper import * +from transformers import WhisperAdapterModel +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class WhisperAdapterModelTest(AdapterModelTesterMixin, WhisperModelTest): + all_model_classes = (WhisperAdapterModel,) + fx_compatible = False diff --git a/tests_adapters/test_whisper.py b/tests_adapters/test_whisper.py new file mode 100644 index 0000000000..8d2d2fdbf9 --- /dev/null +++ b/tests_adapters/test_whisper.py @@ -0,0 +1,60 @@ +import unittest + +from transformers import WhisperConfig +from transformers.testing_utils import require_torch + +from .methods import ( + BottleneckAdapterTestMixin, + UniPELTTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .composition.test_parallel import ParallelAdapterInferenceTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class WhisperAdapterTestBase(AdapterTestBase): + config_class = WhisperConfig + config = make_config( + WhisperConfig, + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + vocab_size=51865, + ) + tokenizer_name = "openai/whisper-tiny" + + +@require_torch +class WhisperAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + WhisperAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class WhisperClassConversionTest( + ModelClassConversionTestMixin, + WhisperAdapterTestBase, + unittest.TestCase, +): + pass diff --git a/utils/check_adapters.py b/utils/check_adapters.py index fe8c902c45..f8b07ba89b 100644 --- a/utils/check_adapters.py +++ b/utils/check_adapters.py @@ -12,6 +12,7 @@ "distilbert", "bart", "mbart", + "whisper", "gpt2", "gptj", "encoder_decoder", @@ -27,6 +28,8 @@ "BartDecoder", "MBartEncoder", "MBartDecoder", + "WhisperEncoder", + "WhisperDecoder", "T5Stack", ]