Refactor adapter composition implementation (#591)
Refactors the implementation of composition blocks in the model forward
pass such that more of the logic is shared between all adapter methods.

### Changes

- Move adapter methods into new `methods` folder
- Introduce new `ComposableAdapterLayerBase`, a subclass of
`AdapterLayerBase`, as the shared base class of all adapter methods that
support composition.
- This class provides default implementations for several composition
blocks (currently `Stack`, `Parallel`, `BatchSplit`, `Average`) which can
be used by all subclasses.
- To enable these composition blocks for deriving methods, a small set of
helper methods defined in `ComposableAdapterLayerBase` must be
implemented. See
https://github.com/calpt/adapter-transformers/blob/55fdc0cbe2f695914108a9c0e208127b13bc617e/src/adapters/methods/adapter_layer_base.py#L132-L222.
- Different adapter methods require passing different inputs to each
composition block. Thus, the input states are abstracted as a
`NamedTuple` in the base class. Deriving classes should define concrete
`NamedTuple`-derived state classes. E.g., see
https://github.com/calpt/adapter-transformers/blob/55fdc0cbe2f695914108a9c0e208127b13bc617e/src/adapters/methods/bottleneck.py#L22
- Update the `Split` composition block to support more than two child
blocks. Splits are now defined as a list of split sizes (one per child),
e.g. `Split("a", "b", "c", splits=[64, 64, 64])`. **Breaking change**
(see the migration example after this list)
- Renamings: `AdapterLayer` -> `BottleneckLayer`; `PrefixTuningShim` ->
`PrefixTuningLayer`
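
For existing setups, migrating a `Split` configuration is a small change. A sketch (adapter names are placeholders):

```python
import adapters.composition as ac

# Before this change: a single split index
# split = ac.Split("g", "h", split_index=64)

# After this change: one split size per child block
split = ac.Split("g", "h", splits=[64, 64])

# A single int is broadcast to all children, so this is equivalent:
split = ac.Split("g", "h", splits=64)
```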

---------

Co-authored-by: Leon Engländer <[email protected]>
calpt and lenglaender authored Oct 29, 2023
1 parent 6da07d1 commit dfe17e9
Showing 42 changed files with 1,150 additions and 1,200 deletions.
10 changes: 5 additions & 5 deletions docs/adapter_composition.md
@@ -160,10 +160,10 @@ In the example, `attention_scores` holds a dictionary of the following form:
Splitting the input between two adapters using the 'Split' block.
```

The `Split` block can be used to split an input sequence between two adapters.
This is done by specifying a split index, at which the sequences should be divided.
The `Split` block can be used to split an input sequence between multiple adapters.
This is done by specifying split indices at which the sequences should be divided.
In the following example, we split each input sequence between adapters `g` and `h`.
For each sequence, all tokens from 0 up to 63 are forwarded through `g` while all tokens beginning at index 64 are forwarded through `h`:
For each sequence, all tokens from 0 up to 63 are forwarded through `g` while the next 64 tokens are forwarded through `h`:

```python
import adapters.composition as ac
@@ -173,7 +173,7 @@ import adapters.composition as ac
model.add_adapter("g")
model.add_adapter("h")

model.active_adapters = ac.Split("g", "h", split_index=64)
model.active_adapters = ac.Split("g", "h", splits=[64, 64])
```

## `BatchSplit`
@@ -286,7 +286,7 @@ E.g., we can nest a `Split` block within a `Stack` of adapters:
```python
import adapters.composition as ac

model.active_adapters = ac.Stack("a", ac.Split("b", "c", split_index=60))
model.active_adapters = ac.Stack("a", ac.Split("b", "c", splits=60))
```

However, combinations of adapter composition blocks cannot be arbitrarily deep. All currently supported possibilities are visualized in the table below.
11 changes: 9 additions & 2 deletions docs/classes/adapter_layer.rst
@@ -1,5 +1,12 @@
AdapterLayer
Adapter Implementation
=======================

.. autoclass:: adapters.AdapterLayer
The following classes define the common interfaces for all adapter methods.
They further hold logic shared by all adapter implementations.
All newly added adapter methods should inherit from either one of these classes.

.. autoclass:: adapters.AdapterLayerBase
:members:

.. autoclass:: adapters.ComposableAdapterLayerBase
:members:
7 changes: 0 additions & 7 deletions docs/classes/adapter_modules.rst

This file was deleted.

53 changes: 37 additions & 16 deletions docs/contributing/adding_adapter_methods.md
@@ -20,28 +20,49 @@ Thus, each adapter method implementation at least should provide two classes:

- a configuration class deriving from `AdapterConfigBase` that provides attributes for all configuration options of the method
- a module class deriving from the abstract `AdapterLayerBase` that provides the method parameters and a set of standard adapter management functions
- modules supporting [adapter composition](https://docs.adapterhub.ml/adapter_composition.html) should instead derive from `ComposableAdapterLayerBase`

**📝 Steps**
### Configuration

- All configuration classes reside in `src/transformers/adapters/configuration.py`.
To add a new configuration class for a new method, create a new subclass of `AdapterConfigBase`.
All configuration classes reside in `src/adapters/configuration/adapter_config.py`.
- To add a new configuration class for a new method, create a new subclass of [`AdapterConfigBase`](adapters.AdapterConfigBase).
Make sure to set the `architecture` attribute in your class; an illustrative sketch of such a config class is given below.
- Finally, also make sure the config class is added to the `__init__.py` files in `src/transformers/adapters` and `src/transformers`.
- The `AdapterLayerBase` class from which any new adapter modules should derive resides in `src/transformers/adapters/layer.py`.
- This abstract base class defines a set of methods that should be implemented by each deriving class,
including methods for adding, enabling and deleting adapter weights.
- Most importantly, the module classes deriving from this base class should implement the forward pass through an adaptation component.
- The concrete implementation of these classes heavily depends on the specifics of the adapter method.
For a reference implementation, have a look at `AdapterLayer` for bottleneck adapters.
- To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations.
- This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see `src/transformers/adapters/mixins`) or directly as submodules of the respective model components.
- The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation.
Please try to integrate any new adapter method into every model class when it's reasonable.
You can find all currently supported model classes at https://docs.adapterhub.ml/model_overview.html.
- Finally, also make sure the config class is added to the `__init__.py` files in `src/adapters`.
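
For illustration, a minimal configuration class for a hypothetical new method could look roughly as follows (the class name, architecture string and options are made up for this sketch; the dataclass pattern mirrors the existing config classes in `adapter_config.py`):

```python
from dataclasses import dataclass

from adapters import AdapterConfigBase


@dataclass(eq=False)
class MyMethodConfig(AdapterConfigBase):
    """Hypothetical configuration for a new adapter method."""

    # identifies which adapter method this config belongs to
    architecture: str = "my_method"

    # method-specific configuration options
    reduction_factor: int = 16
    dropout: float = 0.0
```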

### Modeling

All adapter method implementations reside in `src/adapters/methods`.

#### For methods **without** composition support

The [`AdapterLayerBase`](adapters.AdapterLayerBase) class, from which any new adapter modules should derive, resides in `src/adapters/methods/adapter_layer_base.py`.
- This abstract base class defines a set of methods that should be implemented by each deriving class,
including methods for adding, enabling, and deleting adapter weights. These methods are marked as abstract in the base class. See [`AdapterLayerBase`](adapters.AdapterLayerBase) for details.
- Most importantly, however, the module classes deriving from this base class should implement the forward pass through an adaptation component.
- The concrete implementation of these classes heavily depends on the specifics of the adapter method; a rough skeleton is sketched below.
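
For example (the class and method bodies below are illustrative only; the authoritative set of abstract methods and their exact signatures is defined on [`AdapterLayerBase`](adapters.AdapterLayerBase)):

```python
import torch
import torch.nn as nn

from adapters.methods.adapter_layer_base import AdapterLayerBase


class MyMethodLayer(AdapterLayerBase, nn.Module):
    """Hypothetical adapter module for a method without composition support."""

    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
        # instantiate and register this method's parameters for `adapter_name`;
        # return whether an adapter was actually added at this layer
        raise NotImplementedError

    def delete_adapter(self, adapter_name: str):
        # drop the parameters registered for `adapter_name`
        raise NotImplementedError

    # ... further management methods required by AdapterLayerBase
    #     (e.g. for enabling adapter weights) go here ...

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # the forward pass through the adaptation component
        raise NotImplementedError
```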

#### For methods **with** composition support

The [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) class (a subclass of [`AdapterLayerBase`](adapters.AdapterLayerBase)), which resides in `src/adapters/methods/adapter_layer_base.py`, provides the basic skeleton for implementing adapter composition.
- Your deriving module class should first implement all methods required by [`AdapterLayerBase`](adapters.AdapterLayerBase). See the section above for details.
- For adapter composition, the pre-implemented `compose()` method constitutes the main entry point. This method should be called during the forward pass of your adapter module.
- `compose()` expects a `state` object, which is a generic named tuple defined by your adapter method. This state object should hold all tensors (such as hidden states, attention masks, etc.) and state attributes required for your adapter implementation. See `BottleneckState` for an example.
- Implementations for specific composition blocks are given in methods starting with `compose_`. Some composition blocks come with generic default implementations; others must be implemented by the deriving class if they are to be supported. Make sure to list all supported composition blocks in the `supported_compositions` class attribute of your deriving module.
- In any case, a small set of helper methods should be implemented by any deriving module to support basic composition logic. These are marked as abstract methods in [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) and currently consist of the following: `vslice()`, `pad_and_concat()`, `repeat()`, `mean()`, `compose_single()`. See [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) for details and the sketch below for an illustration.

For a reference implementation, have a look at `BottleneckLayer` for bottleneck adapters.
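
To make the pieces concrete, here is a rough sketch of a composable method module (state fields, class names and tensor handling are illustrative; refer to [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) and `BottleneckLayer` for the exact signatures):

```python
from typing import List, NamedTuple

import torch
import torch.nn as nn

from adapters.composition import Average, BatchSplit, Parallel, Stack
from adapters.methods.adapter_layer_base import ComposableAdapterLayerBase


class MyMethodState(NamedTuple):
    """Illustrative state tuple; hold whatever tensors the forward pass needs."""

    hidden_states: torch.Tensor


class MyComposableLayer(ComposableAdapterLayerBase, nn.Module):
    # composition blocks supported by this method (illustrative selection)
    supported_compositions = [Stack, Parallel, BatchSplit, Average]

    # management methods required by AdapterLayerBase omitted for brevity

    def vslice(self, state: MyMethodState, slice_obj: slice) -> MyMethodState:
        # slice all tensors in the state along the batch dimension
        return MyMethodState(state.hidden_states[slice_obj])

    def pad_and_concat(self, states: List[MyMethodState]) -> MyMethodState:
        # concatenate multiple states along the batch dimension
        return MyMethodState(torch.cat([s.hidden_states for s in states], dim=0))

    def repeat(self, state: MyMethodState, channels: int) -> MyMethodState:
        # repeat the state along the batch dimension, e.g. for Parallel composition
        return MyMethodState(state.hidden_states.repeat(channels, 1, 1))

    def mean(self, states: List[MyMethodState], weights: torch.Tensor) -> MyMethodState:
        # combine multiple states into one, e.g. for the Average block
        # (this sketch ignores `weights` and takes a plain mean)
        stacked = torch.stack([s.hidden_states for s in states], dim=0)
        return MyMethodState(stacked.mean(dim=0))

    def compose_single(self, adapter_setup: str, state: MyMethodState, lvl: int = 0) -> MyMethodState:
        # forward the state through the single adapter module named by `adapter_setup`
        raise NotImplementedError

    def forward(self, hidden_states: torch.Tensor, adapter_setup):
        # wrap the inputs into the state tuple and hand them to the pre-implemented
        # composition dispatch; in the real implementations, the active setup is
        # typically retrieved from the forward context rather than passed in
        state = MyMethodState(hidden_states)
        return self.compose(adapter_setup, state)
```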

#### For all methods

To actually make use of the newly implemented classes, it is finally necessary to integrate the forward calls to the new modules into the actual model implementations.
- This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see modules starting with "mixin" in `src/adapters/models`) or directly as submodules of the respective model components.
- The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation.
Please try to integrate any new adapter method into every model class when it's reasonable.
You can find all currently supported model classes at https://docs.adapterhub.ml/model_overview.html.

**Additional things to consider**

- New adapter methods typically also require some changes in the `AdapterLoader` class in `src/transformers/adapters/loading.py` (also see [here](https://docs.adapterhub.ml/extending.html#loading-custom-module-weights)).
- New adapter methods typically also require some changes in the `AdapterLoader` class in `src/adapters/loading.py` (also see [here](https://docs.adapterhub.ml/extending.html#loading-custom-module-weights)).
- Depending on the method to be integrated, further changes in other classes might be necessary.

## Testing
4 changes: 2 additions & 2 deletions docs/contributing/adding_adapters_to_a_model.md
@@ -27,8 +27,8 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- To figure out which classes to change, think about where to insert LoRA, Prefix Tuning, and bottleneck adapters.
- You can use similar model implementations for guidance.
- Often, existing mixins of another class can be reused. E.g. `BertLayer`, `RobertaLayer`, `XLMRobertaLayer`, `DebertaLayer`, `DebertaV2Layer` and `BertGenerationLayer` (all models derived from BERT) use the `BertLayerAdaptersMixin`.
- To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningShim` module in the respective attention layer (see step 3 for how to modify the code of a Hugging Face class).
- Make sure the calls to `adapter_layer_forward()` are added in the right places.
- To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningLayer` module in the respective attention layer (see step 3 for how to modify the code of a Hugging Face class).
- Make sure the calls to `bottleneck_layer_forward()` are added in the right places.
- The mixin for the whole base model class (e.g., `BertModel`) should derive from `ModelBaseAdaptersMixin` and (if possible) `EmbeddingAdaptersMixin` and/or `InvertibleAdaptersMixin`. This mixin should at least implement the `iter_layers()` method but might require additional modifications depending on the architecture.
- If the model is a combination of different models, such as the EncoderDecoderModel, use `ModelUsingSubmodelsAdaptersMixin` instead of `ModelBaseAdaptersMixin`.
3. **Copied functions:**
4 changes: 2 additions & 2 deletions src/adapters/__init__.py
@@ -78,7 +78,7 @@
"Seq2SeqLMHead",
"TaggingHead",
],
"layer": ["AdapterLayer", "AdapterLayerBase"],
"methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"],
"model_mixin": [
"EmbeddingAdaptersMixin",
"InvertibleAdaptersMixin",
@@ -182,7 +182,7 @@
Seq2SeqLMHead,
TaggingHead,
)
from .layer import AdapterLayer, AdapterLayerBase
from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
9 changes: 3 additions & 6 deletions src/adapters/composition.py
@@ -73,12 +73,9 @@ def name(self):


class Split(AdapterCompositionBlock):
def __init__(self, left: str, right: str, split_index: int):
super().__init__(left, right)
assert split_index > 0
self.left = left
self.right = right
self.split_index = split_index
def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], splits: Union[List[int], int]):
super().__init__(*split_adapters)
self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters)


class BatchSplit(AdapterCompositionBlock):
2 changes: 1 addition & 1 deletion src/adapters/heads/base.py
@@ -20,8 +20,8 @@

from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
from ..context import AdapterSetup, ForwardContext
from ..methods.modeling import Activation_Function_Class
from ..model_mixin import ModelWithHeadsAdaptersMixin
from ..modeling import Activation_Function_Class


logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion src/adapters/heads/language_modeling.py
@@ -2,7 +2,7 @@

from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput

from ..modeling import Activation_Function_Class
from ..methods.modeling import Activation_Function_Class
from .base import PredictionHead

