Added adapters to whisper
HenningBuhl committed Jul 23, 2023
1 parent accf70f commit 544ae05
Showing 17 changed files with 501 additions and 26 deletions.
17 changes: 17 additions & 0 deletions adapter_docs/classes/models/whisper.rst
@@ -0,0 +1,17 @@
Whisper
-----------------------------------------------------------------------------------------------------------------------

The Whisper model was presented in `Robust Speech Recognition via Large-Scale Weak Supervision
<https://arxiv.org/abs/2212.04356>`_ by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine
McLeavey, Ilya Sutskever.

According to the abstract, Whisper is trained on 680,000 hours of multilingual and multitask data, a previously
unseen scale of weak supervision, and approaches the accuracy and robustness of human speech recognition.


WhisperAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.adapters.WhisperAdapterModel
:members:
:inherited-members: WhisperPreTrainedModel
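
For reference, a minimal usage sketch of the new model class (hedged: the checkpoint and all adapter/head names below are illustrative, and the calls assume the standard adapter-transformers flex-head API):

```python
from transformers.adapters import WhisperAdapterModel

# Illustrative checkpoint; any pretrained Whisper checkpoint should work.
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")

# "asr_adapter" is an arbitrary example name.
model.add_adapter("asr_adapter")
model.add_seq2seq_lm_head("asr_adapter")
model.train_adapter("asr_adapter")  # freeze the base model, train only the adapter
```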
1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -70,6 +70,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/gpt2
classes/models/gptj
classes/models/mbart
classes/models/whisper
classes/models/roberta
classes/models/t5
classes/models/vit
1 change: 1 addition & 0 deletions adapter_docs/model_overview.md
@@ -25,6 +25,7 @@ The table below further shows which model architectures support which adaptation
| [GPT-2](classes/models/gpt2.html) ||||||||
| [GPT-J](classes/models/gptj.html) ||||||||
| [MBart](classes/models/mbart.html) ||||||||
| [Whisper](classes/models/whisper.html) ||||||||
| [RoBERTa](classes/models/roberta.html) ||||||||
| [T5](classes/models/t5.html) ||||||||
| [ViT](classes/models/vit.html) ||||||||
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -2604,6 +2604,8 @@
"MAMConfig",
"MBartAdapterModel",
"MBartModelWithHeads",
"WhisperAdapterModel",
"WhisperModelWithHeads",
"ModelAdaptersConfig",
"ModelAdaptersMixin",
"ModelWithFlexibleHeadsAdaptersMixin",
@@ -5708,6 +5710,8 @@
MAMConfig,
MBartAdapterModel,
MBartModelWithHeads,
WhisperAdapterModel,
WhisperModelWithHeads,
ModelAdaptersConfig,
ModelAdaptersMixin,
ModelWithFlexibleHeadsAdaptersMixin,
5 changes: 5 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -118,6 +118,10 @@
"MBartAdapterModel",
"MBartModelWithHeads",
],
"models.whisper": [
"WhisperAdapterModel",
"WhisperModelWithHeads",
],
"models.roberta": [
"RobertaAdapterModel",
"RobertaModelWithHeads",
@@ -219,6 +223,7 @@
from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads
from .models.gptj import GPTJAdapterModel
from .models.mbart import MBartAdapterModel, MBartModelWithHeads
from .models.whisper import WhisperAdapterModel, WhisperModelWithHeads
from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
from .models.t5 import T5AdapterModel, T5ModelWithHeads
from .models.vit import ViTAdapterModel
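
With these re-exports in place, both classes become importable from the top-level adapters namespace; a quick sanity check (sketch):

```python
# Both names should resolve after this change.
from transformers.adapters import WhisperAdapterModel, WhisperModelWithHeads
```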
1 change: 1 addition & 0 deletions src/transformers/adapters/composition.py
@@ -108,6 +108,7 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], b
"deberta",
"bart",
"mbart",
"whisper",
"gpt2",
"gptj",
"t5",
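
Registering "whisper" in this list enables parallel adapter composition for the model type. A sketch of what that unlocks (checkpoint and adapter names are illustrative):

```python
from transformers.adapters import WhisperAdapterModel
from transformers.adapters.composition import Parallel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
model.add_adapter("task_a")
model.add_adapter("task_b")
# Run both adapters side by side in a single forward pass.
model.set_active_adapters(Parallel("task_a", "task_b"))
```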
7 changes: 7 additions & 0 deletions src/transformers/adapters/head_utils.py
@@ -314,6 +314,13 @@
        },
        "layers": ["lm_head"],
    },
    # Whisper
    "WhisperForConditionalGeneration": {
        "config": {
            "head_type": "seq2seq_lm",
        },
        "layers": ["proj_out"],
    },
    # DistilBERT
    "DistilBertForSequenceClassification": {
        "config": {
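
This entry maps the static WhisperForConditionalGeneration head onto the flex-head seq2seq_lm head type, so loading such a checkpoint into the adapter model should recreate its proj_out layer as a prediction head. A sketch, assuming head conversion behaves as for the other seq2seq models:

```python
from transformers.adapters import WhisperAdapterModel

# A checkpoint saved with WhisperForConditionalGeneration; the static
# proj_out layer should be converted into a "seq2seq_lm" flex head.
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
print(model.heads)  # expected: one converted seq2seq LM head (name is an assumption)
```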
51 changes: 51 additions & 0 deletions src/transformers/adapters/mixins/whisper.py
@@ -0,0 +1,51 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import (
    EmbeddingAdaptersMixin,
    EmbeddingAdaptersWrapperMixin,
    InvertibleAdaptersWrapperMixin,
    ModelAdaptersMixin,
    ModelWithHeadsAdaptersMixin,
)


class WhisperEncoderLayerAdaptersMixin:
    """Adds adapters to the WhisperEncoderLayer module of Whisper."""

    def _init_adapter_modules(self):
        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
        self.output_adapters = AdapterLayer("output_adapter", self.config)
        self.attention_adapters._init_adapter_modules()
        self.output_adapters._init_adapter_modules()


class WhisperDecoderLayerAdaptersMixin(WhisperEncoderLayerAdaptersMixin):
    """Adds adapters to the WhisperDecoderLayer module of Whisper."""

    def _init_adapter_modules(self):
        super()._init_adapter_modules()
        self.cross_attention_adapters = AdapterLayer("cross_adapter", self.config)
        self.cross_attention_adapters._init_adapter_modules()


class WhisperModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin):
    """Adds adapters to the WhisperModel class."""

    invertible_adapters_base_name = "encoder"

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        if hasattr(self, "encoder"):
            for i, layer in enumerate(self.encoder.layers):
                yield i, layer
            for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)):
                yield i, layer
        else:
            for i, layer in enumerate(self.decoder.layers):
                yield i, layer


class WhisperModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
    pass
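
Note the layer indexing implied by iter_layers: encoder layers come first, then decoder layers, which matters for per-layer options such as leave_out. A quick sketch (checkpoint name is illustrative; whisper-tiny has 4 encoder and 4 decoder layers):

```python
from transformers.adapters import WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
# Expect indices 0-3 for encoder layers and 4-7 for decoder layers.
for i, layer in model.base_model.iter_layers():
    print(i, type(layer).__name__)
```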
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -19,6 +19,7 @@
("deberta", "DebertaAdapterModel"),
("bart", "BartAdapterModel"),
("mbart", "MBartAdapterModel"),
("whisper", "WhisperAdapterModel"),
("gpt2", "GPT2AdapterModel"),
("gptj", "GPTJAdapterModel"),
("t5", "T5AdapterModel"),
@@ -33,6 +34,7 @@
("distilbert", "DistilBertModelWithHeads"),
("bart", "BartModelWithHeads"),
("mbart", "MBartModelWithHeads"),
("whisper", "WhisperModelWithHeads"),
("gpt2", "GPT2ModelWithHeads"),
("t5", "T5ModelWithHeads"),
]
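
With these registrations, the auto classes can dispatch on Whisper configs; a sketch:

```python
from transformers.adapters import AutoAdapterModel

# Resolves to WhisperAdapterModel via the config's model_type.
model = AutoAdapterModel.from_pretrained("openai/whisper-tiny")
print(type(model).__name__)  # WhisperAdapterModel
```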
24 changes: 24 additions & 0 deletions src/transformers/adapters/models/whisper/__init__.py
@@ -0,0 +1,24 @@
from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
    "adapter_model": [
        "WhisperAdapterModel",
        "WhisperModelWithHeads",
    ],
}


if TYPE_CHECKING:
    from .adapter_model import WhisperAdapterModel, WhisperModelWithHeads

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
    )