From dfe17e9560e1dcec98503ff5ccbd8f0d82c4458f Mon Sep 17 00:00:00 2001
From: calpt
Date: Sun, 29 Oct 2023 12:00:32 +0100
Subject: [PATCH] Refactor adapter composition implementation (#591)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Refactors the implementation of composition blocks in the model forward pass such that more of the logic is shared between all adapter methods.

### Changes

- Move adapter methods into new `methods` folder
- Introduce new `ComposableAdapterLayerBase` as subclass of `AdapterLayerBase` as shared base class of all adapter methods that support composition.
- This class provides default implementations for a couple of composition blocks (currently `Stack`, `Parallel`, `BatchSplit`, `Average`) which can be used by all subclasses.
- To enable these composition blocks for deriving methods, a couple of helper methods defined in `ComposableAdapterLayerBase` must be implemented. See https://github.com/calpt/adapter-transformers/blob/55fdc0cbe2f695914108a9c0e208127b13bc617e/src/adapters/methods/adapter_layer_base.py#L132-L222.
- Different adapter methods require passing different inputs to each composition block. Thus, the input states are abstracted as a `NamedTuple` in the base class. Deriving classes should define concrete `NamedTuple`-derived state classes. E.g., see https://github.com/calpt/adapter-transformers/blob/55fdc0cbe2f695914108a9c0e208127b13bc617e/src/adapters/methods/bottleneck.py#L22
- Update `Split` composition block to support more than two child blocks. Splits are defined as a list of split indices, i.e. `Split("a", "b", "c", splits=[64, 64, 64])`. **Breaking change**
- Renamings: `AdapterLayer` -> `BottleneckLayer`; `PrefixTuningShim` -> `PrefixTuningLayer`

---------

Co-authored-by: Leon Engländer
---
 docs/adapter_composition.md | 10 +-
 docs/classes/adapter_layer.rst | 11 +-
 docs/classes/adapter_modules.rst | 7 -
 docs/contributing/adding_adapter_methods.md | 53 +-
 .../adding_adapters_to_a_model.md | 4 +-
 src/adapters/__init__.py | 4 +-
 src/adapters/composition.py | 9 +-
 src/adapters/heads/base.py | 2 +-
 src/adapters/heads/language_modeling.py | 2 +-
 src/adapters/layer.py | 702 ------------------
 src/adapters/methods/__init__.py | 0
 src/adapters/methods/adapter_layer_base.py | 471 ++++++++++++
 src/adapters/methods/bottleneck.py | 372 ++++++++++
 src/adapters/{ => methods}/lora.py | 12 +-
 src/adapters/{ => methods}/modeling.py | 4 +-
 src/adapters/{ => methods}/prefix_tuning.py | 460 +++---------
 src/adapters/model_mixin.py | 13 +-
 src/adapters/models/albert/mixin_albert.py | 12 +-
 src/adapters/models/bart/mixin_bart.py | 14 +-
 src/adapters/models/beit/mixin_beit.py | 12 +-
 src/adapters/models/beit/modeling_beit.py | 4 +-
 src/adapters/models/bert/mixin_bert.py | 16 +-
 src/adapters/models/bert/modeling_bert.py | 4 +-
 .../modeling_bert_generation.py | 4 +-
 src/adapters/models/clip/mixin_clip.py | 14 +-
 src/adapters/models/deberta/mixin_deberta.py | 6 +-
 .../models/deberta/modeling_deberta.py | 4 +-
 .../models/deberta_v2/mixin_deberta_v2.py | 6 +-
 .../models/deberta_v2/modeling_deberta_v2.py | 4 +-
 .../models/distilbert/mixin_distilbert.py | 12 +-
 .../models/electra/modeling_electra.py | 4 +-
 src/adapters/models/gpt2/mixin_gpt2.py | 14 +-
 src/adapters/models/gptj/mixin_gptj.py | 12 +-
 src/adapters/models/llama/mixin_llama.py | 12 +-
 .../models/roberta/modeling_roberta.py | 4 +-
 src/adapters/models/t5/mixin_t5.py | 14 +-
 src/adapters/models/t5/modeling_t5.py | 6 +-
 src/adapters/models/vit/mixin_vit.py | 12 +-
src/adapters/models/vit/modeling_vit.py | 4 +- .../xlm_roberta/modeling_xlm_roberta.py | 4 +- src/adapters/models/xmod/modeling_xmod.py | 4 +- .../composition/test_adapter_composition.py | 12 +- 42 files changed, 1150 insertions(+), 1200 deletions(-) delete mode 100644 docs/classes/adapter_modules.rst delete mode 100644 src/adapters/layer.py create mode 100644 src/adapters/methods/__init__.py create mode 100644 src/adapters/methods/adapter_layer_base.py create mode 100644 src/adapters/methods/bottleneck.py rename src/adapters/{ => methods}/lora.py (98%) rename src/adapters/{ => methods}/modeling.py (99%) rename src/adapters/{ => methods}/prefix_tuning.py (53%) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 1ee26806ea..05f85f3fdb 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -160,10 +160,10 @@ In the example, `attention_scores` holds a dictionary of the following form: Splitting the input between two adapters using the 'Split' block. ``` -The `Split` block can be used to split an input sequence between two adapters. -This is done by specifying a split index, at which the sequences should be divided. +The `Split` block can be used to split an input sequence between multiple adapters. +This is done by specifying split indices at which the sequences should be divided. In the following example, we split each input sequence between adapters `g` and `h`. -For each sequence, all tokens from 0 up to 63 are forwarded through `g` while all tokens beginning at index 64 are forwarded through `h`: +For each sequence, all tokens from 0 up to 63 are forwarded through `g` while the next 64 tokens are forwarded through `h`: ```python import adapters.composition as ac @@ -173,7 +173,7 @@ import adapters.composition as ac model.add_adapter("g") model.add_adapter("h") -model.active_adapters = ac.Split("g", "h", split_index=64) +model.active_adapters = ac.Split("g", "h", splits=[64, 64]) ``` ## `BatchSplit` @@ -286,7 +286,7 @@ E.g., we can nest a `Split` block within a `Stack` of adapters: ```python import adapters.composition as ac -model.active_adapters = ac.Stack("a", ac.Split("b", "c", split_index=60)) +model.active_adapters = ac.Stack("a", ac.Split("b", "c", splits=60)) ``` However, combinations of adapter composition blocks cannot be arbitrarily deep. All currently supported possibilities are visualized in the table below. diff --git a/docs/classes/adapter_layer.rst b/docs/classes/adapter_layer.rst index 2b54475994..01233d6328 100644 --- a/docs/classes/adapter_layer.rst +++ b/docs/classes/adapter_layer.rst @@ -1,5 +1,12 @@ -AdapterLayer +Adapter Implementation ======================= -.. autoclass:: adapters.AdapterLayer +The following classes define the common interfaces for all adapter methods. +They further hold logic shared by all adapter implementations. +All newly added adapter methods should inherit from either one of these classes. + +.. autoclass:: adapters.AdapterLayerBase + :members: + +.. autoclass:: adapters.ComposableAdapterLayerBase :members: diff --git a/docs/classes/adapter_modules.rst b/docs/classes/adapter_modules.rst deleted file mode 100644 index 46056142bd..0000000000 --- a/docs/classes/adapter_modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -Adapter Modules -=============== - -Classes implementing task and language adapters. - -.. 
automodule:: adapters.modeling - :members: diff --git a/docs/contributing/adding_adapter_methods.md b/docs/contributing/adding_adapter_methods.md index e968c90b7c..29a7801579 100644 --- a/docs/contributing/adding_adapter_methods.md +++ b/docs/contributing/adding_adapter_methods.md @@ -20,28 +20,49 @@ Thus, each adapter method implementation at least should provide two classes: - a configuration class deriving from `AdapterConfigBase` that provides attributes for all configuration options of the method - a module class deriving from the abstract `AdapterLayerBase` that provides the method parameters and a set of standard adapter management functions + - modules supporting [adapter composition](https://docs.adapterhub.ml/adapter_composition.html) should instead derive from `ComposableAdapterLayerBase` -**📝 Steps** +### Configuration -- All configuration classes reside in `src/transformers/adapters/configuration.py`. - To add a new configuration class for a new method, create a new subclass of `AdapterConfigBase`. +All configuration classes reside in `src/adapters/configuration/adapter_config.py`. +- To add a new configuration class for a new method, create a new subclass of [`AdapterConfigBase`](adapters.AdapterConfigBase). Make sure to set the `architecture` attribute in your class. - - Finally, also make sure the config class is added to the `__init__.py` files in `src/transformers/adapters` and `src/transformers`. -- The `AdapterLayerBase` class from which any new adapter modules should derive resides in `src/transformers/adapters/layer.py`. - - This abstract base class defines a set of methods that should be implemented by each deriving class, - including methods for adding, enabling and deleting adapter weights. - - Most importantly, the module classes deriving from this base class should implement the forward pass through an adaptation component. - - The concrete implementation of these classes heavily depends on the specifics of the adapter method. - For a reference implementation, have a look at `AdapterLayer` for bottleneck adapters. -- To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations. - - This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see `src/transformers/adapters/mixins`) or directly as submodules of the respective model components. - - The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation. - Please try to integrate any new adapter method into every model class when it's reasonable. - You can find all currently supported model classes at https://docs.adapterhub.ml/model_overview.html. +- Finally, also make sure the config class is added to the `__init__.py` files in `src/adapters`. + +### Modeling + +All adapter method implementations reside in `src/adapters/methods`. + +#### For methods **without** composition support + +The [`AdapterLayerBase`](adapters.AdapterLayerBase) class from which any new adapter modules should derive resides in `src/adapters/methods/adapter_layer_base.py`. 
+- This abstract base class defines a set of methods that should be implemented by each deriving class, +including methods for adding, enabling and deleting adapter weights. These methods are marked as abstract in the base class. See [`AdapterLayerBase`](adapters.AdapterLayerBase) for details. +- Most importantly however, the module classes deriving from this base class should implement the forward pass through an adaptation component. +- The concrete implementation of these classes heavily depends on the specifics of the adapter method. + +#### For methods **with** composition support + +The [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) class (as subclass of [`AdapterLayerBase`](adapters.AdapterLayerBase)), which resides in `src/adapters/methods/adapter_layer_base.py` provides the basic skeleton for implementing adapter composition. +- Your deriving module class firstly should implement all methods required by [`AdapterLayerBase`](adapters.AdapterLayerBase). See section above for details. +- For adapter composition, the pre-implemented `compose()` method constitutes the main entry-point. This method should be called during the forward pass of your adapter module. +- `compose()` expects a `state` object, which is a generic named tuple object defined by your adapter method. This state object should hold all tensors (such as hidden states, attention masks etc.) and state attributes required for your adapter implementation. See `BottleneckState` for an example. +- Implementations for specific composition blocks are given in methods starting with `compose_`. Some composition blocks provide generic default implementations, some must be implemented by the deriving class if they should be supported. Make sure to list all supported composition blocks in the `supported_compositions` class attribute of your deriving module. +- In any case, a small set of helper methods should be implemented by any deriving module to support basic composition logic. These are marked as abstract methods in [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) and currently consist of the following: vslice(), pad_and_concat(), repeat(), mean(), compose_single(). See [`ComposableAdapterLayerBase`](adapters.ComposableAdapterLayerBase) for details. + +For a reference implementation, have a look at `BottleneckLayer` for bottleneck adapters. + +#### For all methods + +To actually make use of the newly implemented classes, it's finally necessary to integrate the forward calls to the modules in the actual model implementations. +- This, again, is highly dependent on how the adapter method interacts with the base model classes. Typically, module classes can be integrated either via mixins (see modules starting with "mixin" in `src/adapters/models`) or directly as submodules of the respective model components. +- The model class integration has to be repeated for each supported Transformer model, as they typically don't share a codebase. At this point it is often important to consider where the adapters need to be added to the transformer model and whether there is an implementation that does not require more copying of classes than the current implementation. +Please try to integrate any new adapter method into every model class when it's reasonable. +You can find all currently supported model classes at https://docs.adapterhub.ml/model_overview.html. 
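
To make the pieces described above concrete, the following is a minimal, hypothetical sketch of an adapter method with composition support. All names (`MyMethodState`, `MyMethodLayer`, `my_modules`) are invented for illustration and are not part of the library; the per-adapter module is a placeholder `nn.Identity()`, and the adapter-management methods are reduced to trivial versions. It is meant to show the intended structure (state `NamedTuple`, required helper methods, `compose()` in the forward pass), not a production implementation such as `BottleneckLayer`:

```python
from typing import Dict, List, NamedTuple

import torch
from torch import nn

from adapters import ComposableAdapterLayerBase
from adapters.composition import Average, BatchSplit, Parallel, Stack


class MyMethodState(NamedTuple):
    # Holds every tensor the composition logic needs to slice, repeat and concatenate.
    hidden_states: torch.Tensor


class MyMethodLayer(ComposableAdapterLayerBase, nn.Module):
    # Name of the container attribute holding this method's per-adapter modules.
    adapter_modules_name = "my_modules"
    # Composition blocks this method supports (here: only the generic defaults).
    supported_compositions = [Stack, Parallel, BatchSplit, Average]

    def __init__(self, location_key: str):
        super().__init__()
        self.location_key = location_key
        # In the library, model_config/adapters_config are injected by the init machinery;
        # the module dict is created here to keep the sketch self-contained.
        self.my_modules = nn.ModuleDict()

    # --- adapter management required by AdapterLayerBase (minimal placeholders) ---

    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
        # A real method would match adapter_name against self.adapters_config and
        # instantiate its actual parameter module here.
        self.layer_idx = layer_idx
        self.my_modules[adapter_name] = nn.Identity()
        return True

    def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool:
        return False  # weight averaging omitted in this sketch

    def delete_adapter(self, adapter_name: str):
        if adapter_name in self.my_modules:
            del self.my_modules[adapter_name]

    def add_fusion_layer(self, adapter_names):
        pass  # only relevant for methods that support Fuse

    def delete_fusion_layer(self, adapter_names):
        pass

    def enable_adapters(self, adapter_setup, unfreeze_adapters: bool, unfreeze_fusion: bool):
        if unfreeze_adapters:
            for name in adapter_setup.flatten():
                if name in self.my_modules:
                    for param in self.my_modules[name].parameters():
                        param.requires_grad = True

    def get_adapter(self, adapter_name: str):
        return self.my_modules[adapter_name] if adapter_name in self.my_modules else None

    # --- composition helpers required by ComposableAdapterLayerBase ---

    def vslice(self, state: MyMethodState, slice_obj: slice) -> MyMethodState:
        # Slice all state tensors along the batch dimension.
        return MyMethodState(state.hidden_states[slice_obj])

    def pad_and_concat(self, states: List[MyMethodState]) -> MyMethodState:
        # Concatenate child states along the batch dimension (pad first if shapes differ).
        return MyMethodState(torch.cat([s.hidden_states for s in states], dim=0))

    def repeat(self, state: MyMethodState, channels: int) -> MyMethodState:
        # Blow up the batch for Parallel composition.
        return MyMethodState(state.hidden_states.repeat(channels, 1, 1))

    def mean(self, states: List[MyMethodState], weights: torch.Tensor) -> MyMethodState:
        # Weighted average of child outputs for the Average block.
        stacked = torch.stack([s.hidden_states for s in states], dim=0)
        return MyMethodState(torch.mean(stacked * weights, dim=0))

    def compose_single(self, adapter_setup: str, state: MyMethodState, lvl: int = 0) -> MyMethodState:
        # Forward through one adapter module of this method.
        return MyMethodState(self.my_modules[adapter_setup](state.hidden_states))

    # --- forward pass: wrap the inputs into the state tuple and dispatch via compose() ---

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        adapter_setup = self.get_active_setup()
        if adapter_setup is not None:
            state = self.compose(adapter_setup, MyMethodState(hidden_states))
            hidden_states = state.hidden_states
        return hidden_states
```

In a real method, `compose_single()` would apply the method's actual transformation, and `pre_block()` could additionally be overridden to handle residuals and layer norms, as `BottleneckLayer` does with its `BottleneckState`.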
**Additional things to consider** -- New adapter methods typically also require some changes in the `AdapterLoader` class in `src/transformers/adapters/loading.py` (also see [here](https://docs.adapterhub.ml/extending.html#loading-custom-module-weights)). +- New adapter methods typically also require some changes in the `AdapterLoader` class in `src/adapters/loading.py` (also see [here](https://docs.adapterhub.ml/extending.html#loading-custom-module-weights)). - Depending on the method to be integrated, further changes in other classes might be necessary. ## Testing diff --git a/docs/contributing/adding_adapters_to_a_model.md b/docs/contributing/adding_adapters_to_a_model.md index f574bb806a..9306e8d92f 100644 --- a/docs/contributing/adding_adapters_to_a_model.md +++ b/docs/contributing/adding_adapters_to_a_model.md @@ -27,8 +27,8 @@ Now that we have discussed the purpose of every file in `src/adapters/models/ 0 - self.left = left - self.right = right - self.split_index = split_index + def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], splits: Union[List[int], int]): + super().__init__(*split_adapters) + self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters) class BatchSplit(AdapterCompositionBlock): diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index dd43a4e658..d9c7386fec 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -20,8 +20,8 @@ from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition from ..context import AdapterSetup, ForwardContext +from ..methods.modeling import Activation_Function_Class from ..model_mixin import ModelWithHeadsAdaptersMixin -from ..modeling import Activation_Function_Class logger = logging.getLogger(__name__) diff --git a/src/adapters/heads/language_modeling.py b/src/adapters/heads/language_modeling.py index 7e6ec95ccb..3e0cda610a 100644 --- a/src/adapters/heads/language_modeling.py +++ b/src/adapters/heads/language_modeling.py @@ -2,7 +2,7 @@ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput -from ..modeling import Activation_Function_Class +from ..methods.modeling import Activation_Function_Class from .base import PredictionHead diff --git a/src/adapters/layer.py b/src/adapters/layer.py deleted file mode 100644 index 99b2e151ac..0000000000 --- a/src/adapters/layer.py +++ /dev/null @@ -1,702 +0,0 @@ -from abc import ABCMeta, abstractmethod -from typing import Dict, List, Mapping, Union - -import numpy as np -import torch -from torch import nn - -from .composition import ( - AdapterCompositionBlock, - Average, - BatchSplit, - Fuse, - Parallel, - Split, - Stack, - adjust_tensors_for_parallel, -) -from .configuration import BnConfig -from .context import AdapterSetup, ForwardContext -from .modeling import Adapter, BertFusion, ParallelAdapter - - -# We don't inherit from ABC because __slots__ changes object layout -class AdapterLayerBase(metaclass=ABCMeta): - """ - Base class for all adaptation methods that require per-layer modules. 
- """ - - @property - def layer_idx(self): - return getattr(self, "_layer_idx", -1) - - @layer_idx.setter - def layer_idx(self, layer_idx): - idx = getattr(self, "_layer_idx", layer_idx) - assert idx == layer_idx - setattr(self, "_layer_idx", idx) - - def get_active_setup(self, module_dict): - if hasattr(self, "adapters_config"): - # First check current context before falling back to defined setup - context = AdapterSetup.get_context() - if context is not None: - adapter_setup = context.adapter_setup - else: - adapter_setup = self.adapters_config.active_setup - else: - adapter_setup = None - skip_adapters = adapter_setup is None or ( - self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers - ) - if not skip_adapters and (len(set(module_dict.keys()) & adapter_setup.flatten()) > 0): - return adapter_setup - else: - return None - - def _store_gating_score(self, adapter_name, gating_score): - context = ForwardContext.get_context() - if context.output_adapter_gating_scores: - gating_cache = context.adapter_gating_scores - if self.layer_idx not in gating_cache[adapter_name]: - gating_cache[adapter_name][self.layer_idx] = {} - gating_score = gating_score.detach().squeeze().cpu().numpy() - if len(gating_score.shape) == 0: - gating_score = np.expand_dims(gating_score, axis=0) - cache_score = gating_cache[adapter_name][self.layer_idx].get(self.location_key, None) - if cache_score is not None: - gating_cache[adapter_name][self.layer_idx][self.location_key] = np.column_stack( - (cache_score, gating_score) - ) - else: - gating_cache[adapter_name][self.layer_idx][self.location_key] = gating_score - - def _store_fusion_attentions(self, fusion_name, attentions): - context = ForwardContext.get_context() - if context.output_adapter_fusion_attentions: - attention_cache = context.adapter_fusion_attentions - if self.layer_idx not in attention_cache[fusion_name]: - attention_cache[fusion_name][self.layer_idx] = {} - attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions - - @abstractmethod - def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: - raise NotImplementedError() - - @abstractmethod - def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: - raise NotImplementedError() - - @abstractmethod - def delete_adapter(self, adapter_name: str): - raise NotImplementedError() - - @abstractmethod - def add_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() - - @abstractmethod - def delete_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() - - @abstractmethod - def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): - raise NotImplementedError() - - @abstractmethod - def get_adapter(self, adapter_name: str) -> nn.Module: - raise NotImplementedError() - - -class AdapterLayer(AdapterLayerBase, nn.Module): - def __init__(self, location_key: str): - super().__init__() - self.location_key = location_key - - def init_adapters(self, model_config, adapters_config): - self.model_config = model_config - self.adapters_config = adapters_config - self.adapters = nn.ModuleDict(dict()) - self.adapter_fusion_layer = nn.ModuleDict(dict()) - - def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: - self.layer_idx = layer_idx - adapter_config = self.adapters_config.match( - adapter_name, - config_type=BnConfig, - layer_idx=self.layer_idx, - location_key=self.location_key, - ) - if 
adapter_config is not None: - reduction_factor = adapter_config["reduction_factor"] - if isinstance(reduction_factor, Mapping): - if str(self.layer_idx) in reduction_factor: - reduction_factor = reduction_factor[str(self.layer_idx)] - elif "default" in reduction_factor: - reduction_factor = reduction_factor["default"] - else: - raise KeyError( - "The given reduction factor mapping does not give a default value and does not specify each " - "reduction factor individually. You need to provide a default value like this: " - '{"1": 16, "default": 16}' - ) - - if adapter_config.is_parallel: - adapter_class = ParallelAdapter - else: - adapter_class = Adapter - adapter = adapter_class( - adapter_name=adapter_name, - input_size=self.model_config.hidden_size, - down_sample=int(self.model_config.hidden_size // reduction_factor), - config=adapter_config, - ) - adapter.train(self.training) # make sure training mode is consistent - self.adapters[adapter_name] = adapter - return True - - return False - - def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: - # add new adapter - if self.add_adapter(adapter_name, self.layer_idx): - # average weights - avg_state_dict = {} - for name, weight in input_adapters.items(): - if name in self.adapters: - module = self.adapters[name] - for k, v in module.state_dict().items(): - if k in avg_state_dict: - avg_state_dict[k] += weight * v - else: - avg_state_dict[k] = weight * v - else: - self.delete_adapter(adapter_name) # clean up before raising error - raise ValueError("Adapter {} not found.".format(name)) - # load averaged weights - self.adapters[adapter_name].load_state_dict(avg_state_dict) - return True - - return False - - def delete_adapter(self, adapter_name: str): - if adapter_name in self.adapters: - del self.adapters[adapter_name] - - def add_fusion_layer(self, adapter_names: Union[List, str]): - """See BertModel.add_fusion_layer""" - adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") - if self.adapters_config.common_config_value(adapter_names, self.location_key): - fusion_config = self.adapters_config.get_fusion(adapter_names) - dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0) - fusion = BertFusion( - fusion_config, - self.model_config.hidden_size, - dropout_prob, - ) - fusion.train(self.training) # make sure training mode is consistent - self.adapter_fusion_layer[",".join(adapter_names)] = fusion - - def delete_fusion_layer(self, adapter_names: Union[List, str]): - adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names) - if adapter_names in self.adapter_fusion_layer: - del self.adapter_fusion_layer[adapter_names] - - def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): - """ - Unfreezes a given list of adapters, the adapter fusion layer, or both - - Args: - adapter_names: names of adapters to unfreeze (or names of adapters part of the fusion layer to unfreeze) - unfreeze_adapters: whether the adapter weights should be activated - unfreeze_fusion: whether the adapter fusion layer for the given adapters should be activated - """ - if unfreeze_adapters: - for adapter_name in adapter_setup.flatten(): - if adapter_name in self.adapters: - for param in self.adapters[adapter_name].parameters(): - param.requires_grad = True - if unfreeze_fusion: - if isinstance(adapter_setup, Fuse): - if adapter_setup.name in 
self.adapter_fusion_layer: - for param in self.adapter_fusion_layer[adapter_setup.name].parameters(): - param.requires_grad = True - for sub_setup in adapter_setup: - if isinstance(sub_setup, Fuse): - if sub_setup.name in self.adapter_fusion_layer: - for param in self.adapter_fusion_layer[sub_setup.name].parameters(): - param.requires_grad = True - - def freeze_adapter(self, adapter_name: str, freeze: bool = True): - if adapter_name in self.adapters: - self.adapters[adapter_name].train(not freeze) - for param in self.adapters[adapter_name].parameters(): - param.requires_grad = not freeze - - def get_adapter(self, adapter_name: str): - if adapter_name in self.adapters: - return self.adapters[adapter_name] - else: - return None - - def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer_norm, lvl=0): - """ - Forwards the given input through the given stack of adapters. - """ - for i, adapter_stack_layer in enumerate(adapter_setup): - # Break if setup is too deep - if isinstance(adapter_stack_layer, AdapterCompositionBlock) and lvl >= 1: - raise ValueError( - "Specified adapter setup is too deep. Cannot have {} at level {}".format( - adapter_stack_layer.__class__.__name__, lvl - ) - ) - # Case 1: We have a nested fusion layer -> call fusion method - if isinstance(adapter_stack_layer, Fuse): - hidden_states = self.adapter_fusion( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 2: We have a nested split layer -> call split method - elif isinstance(adapter_stack_layer, Split): - hidden_states = self.adapter_split( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 3: We have a nested parallel layer -> call parallel method - elif isinstance(adapter_stack_layer, Parallel): - hidden_states, input_tensor = self.adapter_parallel( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 4: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_stack_layer, BatchSplit): - hidden_states = self.adapter_batchsplit( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 5: We have a nested average block -> call average method - elif isinstance(adapter_stack_layer, Average): - hidden_states = self.adapter_average_output( - adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 - ) - # Case 6: We have a single adapter which is part of this module -> forward pass - elif adapter_stack_layer in self.adapters: - adapter_layer = self.adapters[adapter_stack_layer] - hidden_states, _, residual = adapter_layer.pre_forward(hidden_states, input_tensor, layer_norm) - context = ForwardContext.get_context() - layer_output = adapter_layer( - hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores - ) - hidden_states, up = layer_output[0], layer_output[2] - self._store_gating_score(adapter_stack_layer, layer_output[-1]) - # as this stack might be part of a fusion block, return the adapter up-projection output here - # together with the final output (with potential residuals & norms) if we reached the last block of the stack - if i == len(adapter_setup) - 1: - return hidden_states, up, input_tensor - # Case X: No adapter which is part of this module -> ignore - - # If we got here, we either had another nested composition block - # or no adapter was found. 
In both cases, we don't need to set the second return value for fusion - return hidden_states, None, input_tensor - - def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer_norm, lvl=0): - """ - Performs adapter fusion with the given adapters for the given input. - """ - context = ForwardContext.get_context() - - # config of _last_ fused adapter is significant - fusion_config = self.adapters_config.get_fusion(adapter_setup.name) - last_adapter = self.adapters[adapter_setup.last()] - hidden_states, query, residual = last_adapter.pre_forward( - hidden_states, input_tensor, layer_norm, fusion_config=fusion_config - ) - - up_list = [] - - for adapter_block in adapter_setup: - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - _, up, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - if up is not None: # could be none if stack is empty - up_list.append(up) - # Case 2: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - adapter_layer = self.adapters[adapter_block] - layer_output = adapter_layer( - hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores - ) - up = layer_output[2] - self._store_gating_score(adapter_block, layer_output[-1]) - up_list.append(up) - # Case 3: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - - if len(up_list) > 0: - up_list = torch.stack(up_list) - up_list = up_list.permute(1, 2, 0, 3) - - fusion_output = self.adapter_fusion_layer[adapter_setup.name]( - query, - up_list, - up_list, - residual, - output_attentions=context.output_adapter_fusion_attentions, - ) - if context.output_adapter_fusion_attentions: - hidden_states = fusion_output[0] - self._store_fusion_attentions(adapter_setup.name, fusion_output[-1]) - else: - hidden_states = fusion_output - - return hidden_states - - def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer_norm, lvl=0): - """ - Splits the given input between the given adapters. 
- """ - # config of _first_ of splitted adapters is significant - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, query, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - - # split hidden representations and residuals at split index - split_hidden_states = [ - hidden_states[:, : adapter_setup.split_index, :], - hidden_states[:, adapter_setup.split_index :, :], - ] - split_input_tensor = [ - input_tensor[:, : adapter_setup.split_index, :], - input_tensor[:, adapter_setup.split_index :, :], - ] - split_residual = [ - residual[:, : adapter_setup.split_index, :], - residual[:, adapter_setup.split_index :, :], - ] - - for i, adapter_block in enumerate(adapter_setup): - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - split_hidden_states[i], _, _ = self.adapter_stack( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 2: We have a nested split -> recursively call split - elif isinstance(adapter_block, Split): - split_hidden_states[i] = self.adapter_split( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 3: We have a nested batch split -> call batch split method - elif isinstance(adapter_block, BatchSplit): - split_hidden_states[i] = self.adapter_batchsplit( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 4: We have a nested average -> call average method - elif isinstance(adapter_block, Average): - split_hidden_states[i] = self.adapter_average_output( - adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 - ) - # Case 5: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - adapter_layer = self.adapters[adapter_block] - context = ForwardContext.get_context() - layer_output = adapter_layer( - split_hidden_states[i], - residual_input=split_residual[i], - output_gating=context.output_adapter_gating_scores, - ) - split_hidden_states[i] = layer_output[0] - self._store_gating_score(adapter_block, layer_output[-1]) - # Case 6: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - - hidden_states = torch.cat(split_hidden_states, dim=1) - return hidden_states - - def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, layer_norm, lvl=0): - """ - For parallel execution of the adapters on the same input. This means that the input is repeated N times before - feeding it to the adapters (where N is the number of adapters). - """ - - context = ForwardContext.get_context() - if not context.adapters_parallelized: - orig_batch_size = input_tensor.shape[0] - input_tensor = input_tensor.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1) - hidden_states = hidden_states.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1) - context.adapters_parallelized = True - else: - # The base model should handle replication of input. - # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. 
- if hidden_states.shape[0] % adapter_setup.parallel_channels != 0: - raise ValueError( - "The total input batch size in a Parallel adapter block must be divisible by the number of" - " parallel channels." - ) - orig_batch_size = hidden_states.shape[0] // adapter_setup.parallel_channels - - # We assume all adapters have the same config - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - - # sequentially feed different parts of the blown-up batch into different adapters - children_hidden = [] - for i, child in enumerate(adapter_setup): - # Case 1: We have a nested stack -> call stack method - if isinstance(child, Stack): - child_hidden_states, _, _ = self.adapter_stack( - child, - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child_hidden_states) - # Case 2: We have a nested batchsplit block -> call batchsplit method - elif isinstance(child, BatchSplit): - child_hidden_states = self.adapter_batchsplit( - child, - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child_hidden_states) - # Case 3: We have a nested average block -> call average method - elif isinstance(child, Average): - child_hidden_states = self.adapter_average_output( - child, - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child_hidden_states) - # Case 4: We have a single adapter which is part of this module -> forward pass - elif child in self.adapters: - adapter_layer = self.adapters[child] - context = ForwardContext.get_context() - layer_output = adapter_layer( - hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], - residual_input=residual[i * orig_batch_size : (i + 1) * orig_batch_size], - output_gating=context.output_adapter_gating_scores, - ) - child_hidden_states = layer_output[0] - self._store_gating_score(child, layer_output[-1]) - children_hidden.append(child_hidden_states) - # Case 5: nesting other composition blocks is invalid - elif isinstance(child, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. 
Cannot nest {} in {}".format( - child.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - else: - children_hidden.append(hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size]) - - # concatenate all outputs and return - hidden_states = torch.cat(children_hidden, 0) - return hidden_states, input_tensor - - def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_tensor, layer_norm, lvl=0): - if not sum(adapter_setup.batch_sizes) == hidden_states.shape[0]: - raise IndexError( - "The given batch has a size of {} which is not compatible with batch_sizes {}".format( - hidden_states.shape[0], adapter_setup.batch_sizes - ) - ) - - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - children_hidden = [] - for i, adapter_block in enumerate(adapter_setup): - # compute ids of sequences thet should be passed to the ith adapter - batch_idx = ( - sum(adapter_setup.batch_sizes[:i]), - sum(adapter_setup.batch_sizes[: i + 1]), - ) - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - child, _, _ = self.adapter_stack( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 2: We have a nested split -> recursively call split - elif isinstance(adapter_block, Split): - child = self.adapter_split( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 3: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_block, BatchSplit): - child = self.adapter_batchsplit( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 4: We have a nested average block -> call average method - elif isinstance(adapter_block, Average): - child = self.adapter_average_output( - adapter_block, - hidden_states[batch_idx[0] : batch_idx[1]], - input_tensor[batch_idx[0] : batch_idx[1]], - layer_norm, - lvl=lvl + 1, - ) - children_hidden.append(child) - # Case 5: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - - adapter_layer = self.adapters[adapter_block] - context = ForwardContext.get_context() - layer_output = adapter_layer( - hidden_states[batch_idx[0] : batch_idx[1]], - residual_input=residual[batch_idx[0] : batch_idx[1]], - output_gating=context.output_adapter_gating_scores, - ) - children_hidden.append(layer_output[0]) - self._store_gating_score(adapter_block, layer_output[-1]) - # Case 6: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - else: - children_hidden.append(hidden_states[batch_idx]) - - hidden_states = torch.cat(children_hidden, 0) - return hidden_states - - def adapter_average_output(self, adapter_setup: Average, hidden_states, input_tensor, layer_norm, lvl=0): - """ - For averaging the output representations of multiple adapters. 
- """ - context = ForwardContext.get_context() - - # We assume all adapters have the same config - first_adapter = self.adapters[adapter_setup.first()] - hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm) - - children_hidden = [] - - for adapter_block in adapter_setup: - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - child, _, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - children_hidden.append(child) - # Case 2: We have a nested split block -> call split method - elif isinstance(adapter_block, Split): - child = self.adapter_split(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - children_hidden.append(child) - # Case 3: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_block, BatchSplit): - child = self.adapter_batchsplit(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) - children_hidden.append(child) - # Case 4: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.adapters: - adapter_layer = self.adapters[adapter_block] - layer_output = adapter_layer( - hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores - ) - children_hidden.append(layer_output[0]) - self._store_gating_score(adapter_block, layer_output[-1]) - # Case 5: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # Case X: No adapter which is part of this module -> ignore - - weights = torch.tensor(adapter_setup.weights)[:, None, None, None].to(hidden_states.device) - hidden_states = torch.mean(torch.stack(children_hidden, 0) * weights, 0) - - return hidden_states - - def adapter_layer_forward(self, hidden_states, residual_input, layer_norm): - """Forward pass through the adapter layer. - NOTE: This method should only be called if the calling module directly inherits from AdapterLayer. Otherwise, - call the regular forward() method. - - Args: - hidden_states (torch.Tensor): Input hidden states to the adapter layer. - residual_input (torch.Tensor): Residual input to the adapter layer. - layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer. - - Returns: - torch.Tensor: Output hidden states of the adapter layer. - """ - # Batch sizes might be different due to prefix tuning w. Parallel block - (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input) - # Replicate in both directions as residual might be larger (e.g. 
GPT-J) - (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states) - adapter_setup = self.get_active_setup(self.adapters) - if adapter_setup is not None: - input_hidden_states = hidden_states - - if isinstance(adapter_setup, Stack): - hidden_states, _, residual_input = self.adapter_stack( - adapter_setup, hidden_states, residual_input, layer_norm - ) - elif isinstance(adapter_setup, Fuse): - hidden_states = self.adapter_fusion(adapter_setup, hidden_states, residual_input, layer_norm) - elif isinstance(adapter_setup, Split): - hidden_states = self.adapter_split(adapter_setup, hidden_states, residual_input, layer_norm) - elif isinstance(adapter_setup, Parallel): - # notice that we are overriding input tensor here to keep the same dim as hidden_states for the residual - # in case we were blowing up the batch for parallel processing of multiple adapters for the same input - hidden_states, residual_input = self.adapter_parallel( - adapter_setup, hidden_states, residual_input, layer_norm - ) - elif isinstance(adapter_setup, BatchSplit): - hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, residual_input, layer_norm) - elif isinstance(adapter_setup, Average): - hidden_states = self.adapter_average_output(adapter_setup, hidden_states, residual_input, layer_norm) - else: - raise ValueError(f"Invalid adapter setup {adapter_setup}") - - last_adapter = self.adapters[adapter_setup.last()] - hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm) - - elif layer_norm: - hidden_states = layer_norm(hidden_states + residual_input) - else: - hidden_states = hidden_states + residual_input - - return hidden_states - - def forward(self, hidden_states, residual_input, layer_norm): - """Forward pass through the adapter layer. - - Args: - hidden_states (torch.Tensor): Input hidden states to the adapter layer. - residual_input (torch.Tensor): Residual input to the adapter layer. - layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer. - - Returns: - torch.Tensor: Output hidden states of the adapter layer. - """ - return self.adapter_layer_forward(hidden_states, residual_input, layer_norm) diff --git a/src/adapters/methods/__init__.py b/src/adapters/methods/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py new file mode 100644 index 0000000000..b89b75cb14 --- /dev/null +++ b/src/adapters/methods/adapter_layer_base.py @@ -0,0 +1,471 @@ +from abc import ABCMeta, abstractmethod +from typing import Collection, Dict, List, NamedTuple, Union + +import numpy as np +import torch +from torch import nn + +from ..composition import ALLOWED_NESTINGS, AdapterCompositionBlock, Average, BatchSplit, Fuse, Parallel, Split, Stack +from ..context import AdapterSetup, ForwardContext + + +# We don't inherit from ABC because __slots__ changes object layout +class AdapterLayerBase(metaclass=ABCMeta): + """ + Base class for all adaptation methods that require per-layer modules. + + Make sure the 'adapter_modules_name' attribute is overriden in derived classes. 
+ """ + + adapter_modules_name = "" + + @property + def adapter_modules(self) -> Collection: + return getattr(self, self.adapter_modules_name) + + @property + def layer_idx(self): + return getattr(self, "_layer_idx", -1) + + @layer_idx.setter + def layer_idx(self, layer_idx): + idx = getattr(self, "_layer_idx", layer_idx) + assert idx == layer_idx + setattr(self, "_layer_idx", idx) + + def get_active_setup(self): + if hasattr(self, "adapters_config"): + # First check current context before falling back to defined setup + context = AdapterSetup.get_context() + if context is not None: + adapter_setup = context.adapter_setup + else: + adapter_setup = self.adapters_config.active_setup + else: + adapter_setup = None + skip_adapters = adapter_setup is None or ( + self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers + ) + if not skip_adapters and (len(set(self.adapter_modules.keys()) & adapter_setup.flatten()) > 0): + return adapter_setup + else: + return None + + def _store_gating_score(self, adapter_name, gating_score): + context = ForwardContext.get_context() + if context.output_adapter_gating_scores: + gating_cache = context.adapter_gating_scores + if self.layer_idx not in gating_cache[adapter_name]: + gating_cache[adapter_name][self.layer_idx] = {} + gating_score = gating_score.detach().squeeze().cpu().numpy() + if len(gating_score.shape) == 0: + gating_score = np.expand_dims(gating_score, axis=0) + cache_score = gating_cache[adapter_name][self.layer_idx].get(self.location_key, None) + if cache_score is not None: + gating_cache[adapter_name][self.layer_idx][self.location_key] = np.column_stack( + (cache_score, gating_score) + ) + else: + gating_cache[adapter_name][self.layer_idx][self.location_key] = gating_score + + def _store_fusion_attentions(self, fusion_name, attentions): + context = ForwardContext.get_context() + if context.output_adapter_fusion_attentions: + attention_cache = context.adapter_fusion_attentions + if self.layer_idx not in attention_cache[fusion_name]: + attention_cache[fusion_name][self.layer_idx] = {} + attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions + + @abstractmethod + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + """Adds a new adapter module to the layer. + + Args: + adapter_name (str): The name of the new adapter to add. + layer_idx (int): + The index of the adapters layer (this should be set once by the first added adapter and the kept fix). + + Returns: + bool: True if the adapter was added, False otherwise. + """ + raise NotImplementedError() + + @abstractmethod + def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + """Averages a set of adapter modules into a new adapter module. + + Args: + adapter_name (str): The name of the new (averaged) adapter module to add. + input_adapters (Dict[str, float]): Either: + - a list of adapter names (with equal weighting). + - a dictionary of adapter names and their corresponding weights. + + Returns: + bool: True if the adapter was added, False otherwise. + """ + raise NotImplementedError() + + @abstractmethod + def delete_adapter(self, adapter_name: str): + """Deletes an adapter module from the layer. + + Args: + adapter_name (str): The name of the adapter to delete. 
+ """ + raise NotImplementedError() + + @abstractmethod + def add_fusion_layer(self, adapter_names: Union[List, str]): + # TODO remove this method from the base class + raise NotImplementedError() + + @abstractmethod + def delete_fusion_layer(self, adapter_names: Union[List, str]): + # TODO remove this method from the base class + raise NotImplementedError() + + @abstractmethod + def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): + """Enables/ disables a set of adapter modules within the layer. + + Args: + adapter_setup (AdapterCompositionBlock): The adapter setup to enable/ disable. + unfreeze_adapters (bool): Whether to unfreeze the adapters. + unfreeze_fusion (bool): Whether to unfreeze the fusion layers. + """ + raise NotImplementedError() + + @abstractmethod + def get_adapter(self, adapter_name: str) -> nn.Module: + """Returns the adapter module with the given name. + + Args: + adapter_name (str): The name of the adapter module. + """ + raise NotImplementedError() + + +class ComposableAdapterLayerBase(AdapterLayerBase): + """ + Base class for all adapter methods that support composition. + + Make sure the 'adapter_modules_name' and 'supported_compositions' attributes as well as all abstract methods are + overriden in derived classes. + """ + + supported_compositions = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_mapping() + + def _init_mapping(self): + self.composition_to_func_map = { + Stack: self.compose_stack, + Fuse: self.compose_fuse, + Split: self.compose_split, + BatchSplit: self.compose_batch_split, + Parallel: self.compose_parallel, + Average: self.compose_average, + } + + # START CUSTOMIZABLE METHODS # + # The following methods should be implemented in derived classes. + + def _bsz(self, state: NamedTuple) -> int: + """ + Returns the batch size of the given state. + """ + return state[0].shape[0] + + def pre_block(self, adapter_setup: Union[AdapterCompositionBlock, str], state: NamedTuple) -> NamedTuple: + """ + Optional state pre-processing method which is invoked before passing the state to the first child block of a + composition. By default, this method does not contain any logic. E.g. used for bottleneck adapters to implement + residuals and LNs. + + Args: + adapter_setup (Union[AdapterCompositionBlock, str]): The current composition or single adapter. + state (NamedTuple): The current state. + + Returns: + NamedTuple: The pre-processed state. + """ + return state + + @abstractmethod + def vslice(self, state: NamedTuple, slice_obj: slice) -> NamedTuple: + """Slices the given state along the batch size (vertical) dimension. + This is e.g. used by the BatchSplit and Parallel composition blocks. IMPORTANT: Has to be implemented by all + derived classes. + + Args: + state (NamedTuple): The state to be sliced. + slice_obj (slice): The slice object. + + Returns: + NamedTuple: The sliced state. + """ + raise NotImplementedError() + + @abstractmethod + def pad_and_concat(self, states: List[NamedTuple]) -> NamedTuple: + """Concatenates the given states along the batch size dimension. + Pads the states before concatenation if necessary. This is e.g. used by the BatchSplit and Parallel composition + blocks. IMPORTANT: Has to be implemented by all derived classes. + + Args: + states (List[NamedTuple]): The states to be concatenated. + + Returns: + NamedTuple: The concatenated state. 
+ """ + raise NotImplementedError() + + @abstractmethod + def repeat(self, state: NamedTuple, channels: int) -> NamedTuple: + """Repeats the given state along the batch size dimension for the given number of times. + This is e.g. used by the Parallel composition block. IMPORTANT: Has to be implemented by all derived classes. + + Args: + state (NamedTuple): The state to be repeated. + channels (int): The number of times the state should be repeated. + + Returns: + NamedTuple: The repeated state. + """ + raise NotImplementedError() + + @abstractmethod + def mean(self, states: List[NamedTuple], weights: torch.Tensor) -> NamedTuple: + """Averages the given states along the batch size dimension by the given weights. + This is e.g. used by the Average composition block. IMPORTANT: Has to be implemented by all derived classes. + + Args: + states (List[NamedTuple]): The states to be averaged. + weights (torch.Tensor): The averaging weights. + + Returns: + NamedTuple: The averaged state. + """ + raise NotImplementedError() + + @abstractmethod + def compose_single(self, adapter_setup: str, state: NamedTuple, lvl: int = 0) -> NamedTuple: + """Forwards the given state through the given single adapter. + + Args: + adapter_setup (str): The name of the adapter. + state (NamedTuple): The state to be forwarded. + lvl (int, optional): The composition depth. Defaults to 0. + + Returns: + NamedTuple: The state after forwarding through the adapter. + """ + raise NotImplementedError() + + # END CUSTOMIZABLE METHODS # + + def check_composition_valid(self, parent: AdapterCompositionBlock, child: AdapterCompositionBlock, lvl: int): + """Checks whether the given composition is valid. + + Args: + parent (AdapterCompositionBlock): The parent composition block. + child (AdapterCompositionBlock): The child composition block. + lvl (int): The composition depth. + + Raises: + ValueError: If the composition is invalid. + """ + # Break if setup is too deep + if isinstance(parent, Stack) and lvl >= 1: + raise ValueError( + "Specified adapter setup is too deep. Cannot have {} at level {}".format(child.__class__.__name__, lvl) + ) + elif type(child) not in ALLOWED_NESTINGS[type(parent)]: + raise ValueError( + "Cannot nest {} inside {}. Only the following nestings are allowed: {}".format( + child.__class__.__name__, + parent.__class__.__name__, + ", ".join([t.__name__ for t in ALLOWED_NESTINGS[type(parent)]]), + ) + ) + + def compose_stack(self, adapter_setup: Stack, state: NamedTuple, lvl: int = 0) -> NamedTuple: + """ + For sequentially stacking multiple adapters. + """ + for i, adapter_stack_layer in enumerate(adapter_setup): + if isinstance(adapter_stack_layer, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, adapter_stack_layer, lvl) + composition_func = self.composition_to_func_map[type(adapter_stack_layer)] + state = composition_func(adapter_stack_layer, state, lvl=lvl + 1) + elif adapter_stack_layer in self.adapter_modules: + state = self.pre_block(adapter_stack_layer, state) + state = self.compose_single(adapter_stack_layer, state, lvl=lvl + 1) + else: + raise ValueError( + "Invalid adapter setup: {} is not a valid adapter name or composition block.".format( + adapter_stack_layer.__class__.__name__ + ) + ) + + return state + + def compose_fuse(self, adapter_setup: Fuse, state: NamedTuple, lvl: int = 0): + """ + For fusing multiple adapters using adapter fusion. NOTE: This method has no default implementation. 
+ """ + # Fuse is currently only applicable to bottleneck adapters, thus don't provide a default implementation + raise NotImplementedError() + + def compose_split(self, adapter_setup: Split, state: NamedTuple, lvl: int = 0): + """ + For splitting to multiple adapters along the sequence length dimension. NOTE: This method has no default + implementation. + """ + # Split is currently only applicable to bottleneck adapters, thus don't provide a default implementation + raise NotImplementedError() + + def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl: int = 0): + """ + For splitting to multiple adapters along the batch size dimension. + """ + if sum(adapter_setup.batch_sizes) != self._bsz(state): + raise IndexError( + "The given batch has a size of {} which is not equal to the sum of batch_sizes {}".format( + self._bsz(state), adapter_setup.batch_sizes + ) + ) + + state = self.pre_block(adapter_setup, state) + + # sequentially feed different parts of the blown-up batch into different adapters + children_states = [] + for i, child in enumerate(adapter_setup): + # compute ids of sequences thet should be passed to the ith adapter + batch_idx = ( + sum(adapter_setup.batch_sizes[:i]), + sum(adapter_setup.batch_sizes[: i + 1]), + ) + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func( + child, + self.vslice(state, slice(*batch_idx)), + lvl=lvl + 1, + ) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single( + child, + self.vslice(state, slice(*batch_idx)), + lvl=lvl + 1, + ) + children_states.append(child_state) + else: + children_states.append(self.vslice(state, slice(*batch_idx))) + + # concatenate all outputs and return + state = self.pad_and_concat(children_states) + return state + + def compose_parallel(self, adapter_setup: Parallel, state: NamedTuple, lvl: int = 0): + """ + For parallel execution of the adapters on the same input. This means that the input is repeated N times before + feeding it to the adapters (where N is the number of adapters). + """ + + context = ForwardContext.get_context() + if not context.adapters_parallelized: + orig_batch_size = self._bsz(state) + state = self.repeat(state, adapter_setup.parallel_channels) + context.adapters_parallelized = True + else: + # The base model should handle replication of input. + # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. + if self._bsz(state) % adapter_setup.parallel_channels != 0: + raise ValueError( + "The total input batch size in a Parallel adapter block must be divisible by the number of" + " parallel channels." 
+ ) + orig_batch_size = self._bsz(state) // adapter_setup.parallel_channels + + state = self.pre_block(adapter_setup, state) + + # sequentially feed different parts of the blown-up batch into different adapters + children_states = [] + for i, child in enumerate(adapter_setup): + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func( + child, + self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size)), + lvl=lvl + 1, + ) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single( + child, + self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size)), + lvl=lvl + 1, + ) + children_states.append(child_state) + else: + children_states.append(self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size))) + + # concatenate all outputs and return + state = self.pad_and_concat(children_states) + return state + + def compose_average(self, adapter_setup: Average, state: NamedTuple, lvl: int = 0): + """ + For averaging the output representations of multiple adapters. + """ + + state = self.pre_block(adapter_setup, state) + + children_states = [] + for i, child in enumerate(adapter_setup): + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func(child, state, lvl=lvl + 1) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single(child, state, lvl=lvl + 1) + children_states.append(child_state) + else: + pass + + weights = torch.tensor(adapter_setup.weights)[:, None, None, None].to(state[0].device) + state = self.mean(children_states, weights) + + return state + + def compose(self, adapter_setup: Union[AdapterCompositionBlock, str], state: NamedTuple) -> NamedTuple: + """The main composition forward method which recursively calls the composition blocks forward methods. + This method should be called by the forward method of the derived class. + + Args: + adapter_setup (Union[AdapterCompositionBlock, str]): The adapter setup to be used. + state (NamedTuple): The current state. + + Returns: + NamedTuple: The state after forwarding through the adapter setup. 
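Concretely, a deriving layer wraps its inputs into its state tuple, hands it to `compose()`, and unpacks the result, as the bottleneck and prefix tuning layers further below do. Continuing the hypothetical `ToyLayer` sketch from above (a fragment to be added to that class; only `get_active_setup()` and `compose()` come from this patch):

```python
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Resolve the currently active adapter setup; None means no adapter is active here.
        adapter_setup = self.get_active_setup()
        if adapter_setup is not None:
            state = ToyState(hidden_states)
            state = self.compose(adapter_setup, state)
            hidden_states = state.hidden_states
        return hidden_states
```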
+ """ + if isinstance(adapter_setup, AdapterCompositionBlock): + composition_func = self.composition_to_func_map[type(adapter_setup)] + state = composition_func(adapter_setup, state, lvl=0) + elif adapter_setup in self.adapter_modules: + state = self.compose_single(adapter_setup, state, lvl=0) + else: + raise ValueError( + "Invalid adapter setup: {} is not a valid adapter name or composition block.".format( + adapter_setup.__class__.__name__ + ) + ) + + return state diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py new file mode 100644 index 0000000000..c150191695 --- /dev/null +++ b/src/adapters/methods/bottleneck.py @@ -0,0 +1,372 @@ +from typing import Dict, List, Mapping, NamedTuple, Optional, Union + +import torch +from torch import nn + +from ..composition import ( + AdapterCompositionBlock, + Average, + BatchSplit, + Fuse, + Parallel, + Split, + Stack, + adjust_tensors_for_parallel, +) +from ..configuration import BnConfig +from ..context import ForwardContext +from .adapter_layer_base import ComposableAdapterLayerBase +from .modeling import Adapter, BertFusion, ParallelAdapter + + +class BottleneckState(NamedTuple): + """ + Models the input and output states of a bottleneck adapter layer. + + Args: + hidden_states (torch.Tensor): The layer input/ output hidden states. + input_tensor (torch.Tensor): The Transformer sub-block residual connection inputs. + adapter_residual (torch.Tensor): The adapter residual connection inputs. + layer_norm (torch.nn.Module, optional): The Transformer layer norm module. + bottleneck_up (torch.Tensor, optional): + The up-projected bottleneck MLP output. This is only for Fuse compositions. + """ + + hidden_states: torch.Tensor + input_tensor: torch.Tensor + adapter_residual: torch.Tensor + layer_norm: Optional[torch.nn.Module] + bottleneck_up: Optional[torch.Tensor] = None + + +class BottleneckLayer(ComposableAdapterLayerBase, nn.Module): + adapter_modules_name = "adapters" + supported_compositions = [Stack, Fuse, Split, Parallel, BatchSplit, Average] + + def __init__(self, location_key: str): + super().__init__() + self.location_key = location_key + + def init_adapters(self, model_config, adapters_config): + self._init_mapping() + self.model_config = model_config + self.adapters_config = adapters_config + self.adapters = nn.ModuleDict(dict()) + self.adapter_fusion_layer = nn.ModuleDict(dict()) + + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + self.layer_idx = layer_idx + adapter_config = self.adapters_config.match( + adapter_name, + config_type=BnConfig, + layer_idx=self.layer_idx, + location_key=self.location_key, + ) + if adapter_config is not None: + reduction_factor = adapter_config["reduction_factor"] + if isinstance(reduction_factor, Mapping): + if str(self.layer_idx) in reduction_factor: + reduction_factor = reduction_factor[str(self.layer_idx)] + elif "default" in reduction_factor: + reduction_factor = reduction_factor["default"] + else: + raise KeyError( + "The given reduction factor mapping does not give a default value and does not specify each " + "reduction factor individually. 
You need to provide a default value like this: " + '{"1": 16, "default": 16}' + ) + + if adapter_config.is_parallel: + adapter_class = ParallelAdapter + else: + adapter_class = Adapter + adapter = adapter_class( + adapter_name=adapter_name, + input_size=self.model_config.hidden_size, + down_sample=int(self.model_config.hidden_size // reduction_factor), + config=adapter_config, + ) + adapter.train(self.training) # make sure training mode is consistent + self.adapters[adapter_name] = adapter + return True + + return False + + def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + # add new adapter + if self.add_adapter(adapter_name, self.layer_idx): + # average weights + avg_state_dict = {} + for name, weight in input_adapters.items(): + if name in self.adapters: + module = self.adapters[name] + for k, v in module.state_dict().items(): + if k in avg_state_dict: + avg_state_dict[k] += weight * v + else: + avg_state_dict[k] = weight * v + else: + self.delete_adapter(adapter_name) # clean up before raising error + raise ValueError("Adapter {} not found.".format(name)) + # load averaged weights + self.adapters[adapter_name].load_state_dict(avg_state_dict) + return True + + return False + + def delete_adapter(self, adapter_name: str): + if adapter_name in self.adapters: + del self.adapters[adapter_name] + + def add_fusion_layer(self, adapter_names: Union[List, str]): + """See BertModel.add_fusion_layer""" + adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") + if self.adapters_config.common_config_value(adapter_names, self.location_key): + fusion_config = self.adapters_config.get_fusion(adapter_names) + dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0) + fusion = BertFusion( + fusion_config, + self.model_config.hidden_size, + dropout_prob, + ) + fusion.train(self.training) # make sure training mode is consistent + self.adapter_fusion_layer[",".join(adapter_names)] = fusion + + def delete_fusion_layer(self, adapter_names: Union[List, str]): + adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names) + if adapter_names in self.adapter_fusion_layer: + del self.adapter_fusion_layer[adapter_names] + + def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): + """ + Unfreezes a given list of adapters, the adapter fusion layer, or both + + Args: + adapter_names: names of adapters to unfreeze (or names of adapters part of the fusion layer to unfreeze) + unfreeze_adapters: whether the adapter weights should be activated + unfreeze_fusion: whether the adapter fusion layer for the given adapters should be activated + """ + if unfreeze_adapters: + for adapter_name in adapter_setup.flatten(): + if adapter_name in self.adapters: + for param in self.adapters[adapter_name].parameters(): + param.requires_grad = True + if unfreeze_fusion: + if isinstance(adapter_setup, Fuse): + if adapter_setup.name in self.adapter_fusion_layer: + for param in self.adapter_fusion_layer[adapter_setup.name].parameters(): + param.requires_grad = True + for sub_setup in adapter_setup: + if isinstance(sub_setup, Fuse): + if sub_setup.name in self.adapter_fusion_layer: + for param in self.adapter_fusion_layer[sub_setup.name].parameters(): + param.requires_grad = True + + def freeze_adapter(self, adapter_name: str, freeze: bool = True): + if adapter_name in self.adapters: + self.adapters[adapter_name].train(not 
freeze) + for param in self.adapters[adapter_name].parameters(): + param.requires_grad = not freeze + + def get_adapter(self, adapter_name: str): + if adapter_name in self.adapters: + return self.adapters[adapter_name] + else: + return None + + def pre_block(self, adapter_setup: Union[AdapterCompositionBlock, str], state: BottleneckState) -> BottleneckState: + if isinstance(adapter_setup, AdapterCompositionBlock): + adapter_name = adapter_setup.first() + else: + adapter_name = adapter_setup + first_adapter = self.adapters[adapter_name] + hidden_states, _, residual = first_adapter.pre_forward( + state.hidden_states, state.input_tensor, state.layer_norm + ) + + return state._replace(hidden_states=hidden_states, adapter_residual=residual) + + def vslice(self, state: BottleneckState, slice_obj: slice) -> BottleneckState: + return BottleneckState( + state.hidden_states[slice_obj], + state.input_tensor[slice_obj], + state.adapter_residual[slice_obj], + state.layer_norm, + state.bottleneck_up[slice_obj] if state.bottleneck_up is not None else None, + ) + + def pad_and_concat(self, states: List[BottleneckState]) -> BottleneckState: + return BottleneckState( + torch.cat([state.hidden_states for state in states], dim=0), + torch.cat([state.input_tensor for state in states], dim=0), + torch.cat([state.adapter_residual for state in states], dim=0), + states[0].layer_norm, + torch.cat([state.bottleneck_up for state in states], dim=0) + if states[0].bottleneck_up is not None + else None, + ) + + def repeat(self, state: BottleneckState, channels: int) -> BottleneckState: + return BottleneckState( + state.hidden_states.repeat(channels, 1, 1), + state.input_tensor.repeat(channels, 1, 1), + state.adapter_residual.repeat(channels, 1, 1), + state.layer_norm, + state.bottleneck_up.repeat(channels, 1, 1) if state.bottleneck_up is not None else None, + ) + + def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> BottleneckState: + return BottleneckState( + torch.mean(torch.stack([s.hidden_states for s in states], 0) * weights, dim=0), + states[0].input_tensor, + states[0].adapter_residual, + states[0].layer_norm, + states[0].bottleneck_up, + ) + + def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int = 0) -> BottleneckState: + adapter_layer = self.adapters[adapter_setup] + context = ForwardContext.get_context() + layer_output = adapter_layer( + state.hidden_states, + residual_input=state.adapter_residual, + output_gating=context.output_adapter_gating_scores, + ) + hidden_states, up = layer_output[0], layer_output[2] + self._store_gating_score(adapter_setup, layer_output[-1]) + + return BottleneckState(hidden_states, state.input_tensor, state.adapter_residual, state.layer_norm, up) + + def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0): + """ + Performs adapter fusion with the given adapters for the given input. 
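Fuse is the one composition block that remains specific to bottleneck adapters, since it relies on the fusion layer and the up-projected `bottleneck_up` states. For orientation, a typical user-level setup exercising this path might look roughly as follows; `model` stands for any adapters-initialized model, and the high-level calls (`add_adapter`, `add_adapter_fusion`) are the library's usual API, assumed here rather than taken from this patch:

```python
import adapters.composition as ac

# Fuse two bottleneck adapters "a" and "b" via an AdapterFusion layer.
model.add_adapter("a")
model.add_adapter("b")
model.add_adapter_fusion(ac.Fuse("a", "b"))
model.active_adapters = ac.Fuse("a", "b")
```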
+ """ + context = ForwardContext.get_context() + + # config of _last_ fused adapter is significant + fusion_config = self.adapters_config.get_fusion(adapter_setup.name) + last_adapter = self.adapters[adapter_setup.last()] + hidden_states, query, residual = last_adapter.pre_forward( + state.hidden_states, state.input_tensor, state.layer_norm, fusion_config=fusion_config + ) + state = state._replace(hidden_states=hidden_states, adapter_residual=residual) + + children_states = [] + for child in adapter_setup: + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func(child, state, lvl=lvl + 1) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single(child, state, lvl=lvl + 1) + children_states.append(child_state) + else: + pass + + if len(children_states) > 0: + up_list = torch.stack([state.bottleneck_up for state in children_states]) + up_list = up_list.permute(1, 2, 0, 3) + + fusion_output = self.adapter_fusion_layer[adapter_setup.name]( + query, + up_list, + up_list, + state.adapter_residual, + output_attentions=context.output_adapter_fusion_attentions, + ) + if context.output_adapter_fusion_attentions: + hidden_states = fusion_output[0] + self._store_fusion_attentions(adapter_setup.name, fusion_output[-1]) + else: + hidden_states = fusion_output + + return state._replace(hidden_states=hidden_states) + + def compose_split(self, adapter_setup: Split, state: BottleneckState, lvl: int = 0): + """ + Splits the given input between the given adapters. + """ + if sum(adapter_setup.splits) != state.hidden_states.shape[1]: + raise IndexError( + "The given input has sequence length {} which is not equal to the sum of splits {}".format( + state.hidden_states.shape[1], adapter_setup.splits + ) + ) + + state = self.pre_block(adapter_setup, state) + + children_states = [] + for i, child in enumerate(adapter_setup): + batch_idx = ( + sum(adapter_setup.splits[:i]), + sum(adapter_setup.splits[: i + 1]), + ) + child_state = BottleneckState( + state.hidden_states[:, batch_idx[0] : batch_idx[1], :], + state.input_tensor[:, batch_idx[0] : batch_idx[1], :], + state.adapter_residual[:, batch_idx[0] : batch_idx[1], :], + state.layer_norm, + state.bottleneck_up[:, batch_idx[0] : batch_idx[1], :] if state.bottleneck_up is not None else None, + ) + if isinstance(child, AdapterCompositionBlock): + self.check_composition_valid(adapter_setup, child, lvl) + composition_func = self.composition_to_func_map[type(child)] + child_state = composition_func(child, child_state, lvl=lvl + 1) + children_states.append(child_state) + elif child in self.adapter_modules: + child_state = self.compose_single(child, child_state, lvl=lvl + 1) + children_states.append(child_state) + else: + pass + + hidden_states = torch.cat([child.hidden_states for child in children_states], dim=1) + return state._replace(hidden_states=hidden_states) + + def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm): + """Forward pass through the adapter layer. + NOTE: This method should only be called if the calling module directly inherits from BottleneckLayer. + Otherwise, call the regular forward() method. + + Args: + hidden_states (torch.Tensor): Input hidden states to the adapter layer. + residual_input (torch.Tensor): Residual input to the adapter layer. 
+ layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer. + + Returns: + torch.Tensor: Output hidden states of the adapter layer. + """ + # Batch sizes might be different due to prefix tuning w. Parallel block + (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input) + # Replicate in both directions as residual might be larger (e.g. GPT-J) + (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states) + adapter_setup = self.get_active_setup() + if adapter_setup is not None: + input_hidden_states = hidden_states + + state = BottleneckState(hidden_states, residual_input, residual_input, layer_norm) + state = self.compose(adapter_setup, state) + hidden_states, residual_input, _, _, _ = state + + last_adapter = self.adapters[adapter_setup.last()] + hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm) + + elif layer_norm: + hidden_states = layer_norm(hidden_states + residual_input) + else: + hidden_states = hidden_states + residual_input + + return hidden_states + + def forward(self, hidden_states, residual_input, layer_norm): + """Forward pass through the adapter layer. + + Args: + hidden_states (torch.Tensor): Input hidden states to the adapter layer. + residual_input (torch.Tensor): Residual input to the adapter layer. + layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer. + + Returns: + torch.Tensor: Output hidden states of the adapter layer. + """ + return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm) diff --git a/src/adapters/lora.py b/src/adapters/methods/lora.py similarity index 98% rename from src/adapters/lora.py rename to src/adapters/methods/lora.py index 3549e7a8fc..a4c66c830c 100644 --- a/src/adapters/lora.py +++ b/src/adapters/methods/lora.py @@ -13,9 +13,9 @@ from transformers.configuration_utils import PretrainedConfig from transformers.pytorch_utils import Conv1D -from .composition import AdapterCompositionBlock -from .configuration import LoRAConfig, ModelAdaptersConfig -from .layer import AdapterLayerBase +from ..composition import AdapterCompositionBlock +from ..configuration import LoRAConfig, ModelAdaptersConfig +from .adapter_layer_base import AdapterLayerBase class LoRA(nn.Module): @@ -94,6 +94,8 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: class LoRALayer(AdapterLayerBase): + adapter_modules_name = "loras" + def __init__( self, location_key: str, model_config: PretrainedConfig, adapters_config: ModelAdaptersConfig, *args, **kwargs ): @@ -313,7 +315,7 @@ def T(w): return torch.transpose(w, -2, -1) if self.fan_in_fan_out else w if not self.merged: - adapter_setup = self.get_active_setup(self.loras) + adapter_setup = self.get_active_setup() if adapter_setup is not None: if len(adapter_setup) == 1: lora = self.loras[adapter_setup[0]] @@ -496,7 +498,7 @@ def T(w): return torch.t(w) if self.fan_in_fan_out else w if not self.merged: - adapter_setup = self.get_active_setup(self.loras) + adapter_setup = self.get_active_setup() if adapter_setup is not None: if len(adapter_setup) == 1: result = F.linear(x, T(self.weight), bias=self.bias) diff --git a/src/adapters/modeling.py b/src/adapters/methods/modeling.py similarity index 99% rename from src/adapters/modeling.py rename to src/adapters/methods/modeling.py index b61419069e..6b265e21f2 100644 --- a/src/adapters/modeling.py +++ b/src/adapters/methods/modeling.py @@ -5,8 +5,8 @@ from 
transformers.activations import get_activation -from .configuration import AdapterFusionConfig, BnConfig -from .context import ForwardContext +from ..configuration import AdapterFusionConfig, BnConfig +from ..context import ForwardContext class Activation_Function_Class(nn.Module): diff --git a/src/adapters/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py similarity index 53% rename from src/adapters/prefix_tuning.py rename to src/adapters/methods/prefix_tuning.py index af9c57e03f..3a8743a3f2 100644 --- a/src/adapters/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Dict, List, NamedTuple, Optional, Union import torch import torch.nn.functional as F @@ -7,10 +7,10 @@ from transformers import PretrainedConfig from transformers.modeling_utils import ModuleUtilsMixin -from .composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel -from .configuration import ModelAdaptersConfig, PrefixTuningConfig -from .context import AdapterSetup, ForwardContext -from .layer import AdapterLayerBase +from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel +from ..configuration import ModelAdaptersConfig, PrefixTuningConfig +from ..context import AdapterSetup, ForwardContext +from .adapter_layer_base import ComposableAdapterLayerBase from .modeling import Activation_Function_Class @@ -126,7 +126,7 @@ class PrefixTuningPool(nn.Module): How it works: - 1. A `PrefixTuningShim` module that sets this module as pool module is added to each layer. + 1. A `PrefixTuningLayer` module that sets this module as pool module is added to each layer. 2. On adding a prefix, each shim module where a prefix should be added increments a counter in `prefix_counts`. 3. Finally, the base model class confirms adding a new prefix by calling `confirm_prefix()`. 4. This module adds a prefix layer that produces outputs corresponding to the indicated number of layers. @@ -135,7 +135,7 @@ class PrefixTuningPool(nn.Module): - The forward call to this layer is executed in the ForwardContext of each model pass. - All other methods of this class (except for `confirm_prefix()`) should be called exclusively by - `PrefixTuningShim`. + `PrefixTuningLayer`. Args: config (:class:`~transformers.PretrainedConfig`): The model config. @@ -244,7 +244,29 @@ def forward(self, *args, **kwargs): return prefix_states -class PrefixTuningShim(AdapterLayerBase, nn.Module): +class PrefixTuningState(NamedTuple): + """ + Models the input and output states of a prefix tuning layer. + + Args: + key_states (torch.Tensor): The key states of the attention layer. + value_states (torch.Tensor): The value states of the attention layer. + residual_input (torch.Tensor): The residual input of the attention layer. + attention_mask (torch.Tensor, optional): The attention mask of the attention layer. + invert_mask (bool): Whether the attention mask is inverted (ie. using '1' for padding). + idx_slice (slice, optional): Id slice for slicing prefix states along the batch size dimension. + + """ + + key_states: torch.Tensor + value_states: torch.Tensor + residual_input: torch.Tensor + attention_mask: Optional[torch.Tensor] + invert_mask: bool + idx_slice: Optional[slice] = None + + +class PrefixTuningLayer(ComposableAdapterLayerBase, nn.Module): """ Representation of a Prefix Tuning layer within one Transformer layer. This class implements `AdapterLayerBase` for compatibility with adapters. 
It uses `PrefixTuningPool` in the background and `set_pool()` must be called after @@ -256,6 +278,9 @@ class PrefixTuningShim(AdapterLayerBase, nn.Module): config (:class:`~transformers.PretrainedConfig`): The model config. """ + adapter_modules_name = "prefixes" + supported_compositions = [Stack, Parallel, BatchSplit] + def __init__( self, location_key: str, @@ -373,63 +398,31 @@ def get_adapter(self, adapter_name): return None - def single_forward( - self, - adapter_name: str, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - ): - prefix_id = self.prefixes[adapter_name] - batch_size = key_states.size(0) - - # Retrieve pre-computed prefix states from context - context = ForwardContext.get_context() - # batch_size x n_heads x prefix_length x n_embd_per_head - prefix_keys, prefix_values = context.prefix_states[adapter_name][self.location_key][prefix_id] - - # select index range for batch split - if idx_range is not None: - prefix_keys = prefix_keys[idx_range] - prefix_values = prefix_values[idx_range] - - if adapter_name in self.prefix_gates: - gate = self.prefix_gates[adapter_name] - gate_output = torch.mean(torch.sigmoid(gate(residual_input)), dim=1) - self._store_gating_score(adapter_name, gate_output) - gate_output_key = gate_output[:, 0].view(-1, 1, 1, 1) - gate_output_value = gate_output[:, -1].view(-1, 1, 1, 1) - prefix_keys = prefix_keys * gate_output_key - prefix_values = prefix_values * gate_output_value - - # replicate for Parallel block - prefix_keys, prefix_values = adjust_tensors_for_parallel(key_states, prefix_keys, prefix_values) - - key_states = torch.cat([prefix_keys, key_states], dim=2) - value_states = torch.cat([prefix_values, value_states], dim=2) - if attention_mask is not None: - if attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) - prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(attention_mask.device) - else: - prefix_mask = torch.ones(batch_size, 1, attention_mask.size(2), prefix_keys.size(2)).to( - attention_mask.device - ) - if invert_mask: - prefix_mask = 1.0 - prefix_mask - (prefix_mask,) = adjust_tensors_for_parallel(attention_mask, prefix_mask) - attention_mask = torch.cat([prefix_mask, attention_mask], dim=-1) - - return key_states, value_states, residual_input, attention_mask + def vslice(self, state: PrefixTuningState, slice_obj: slice) -> PrefixTuningState: + if state.idx_slice is None: + split_idx_slice = slice_obj + else: + split_idx_slice = slice( + state.idx_slice.start + slice_obj.start, + state.idx_slice.start + slice_obj.stop, + ) + return PrefixTuningState( + key_states=state.key_states[slice_obj], + value_states=state.value_states[slice_obj], + residual_input=state.residual_input[slice_obj], + attention_mask=state.attention_mask[slice_obj] if state.attention_mask is not None else None, + invert_mask=state.invert_mask, + idx_slice=split_idx_slice, + ) - def _pad_and_concat(self, max_prefix_length, outputs, invert_mask=True): - """Pads all key & value states to the lFongest prefix length in the current batch. + def pad_and_concat(self, states: List[PrefixTuningState]) -> PrefixTuningState: + """Pads all key & value states to the longest prefix length in the current batch. This is required e.g. for stacked prefix tunings. 
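Padding becomes necessary whenever prefixes of different lengths meet in one composition, e.g. when prefix tuning adapters are combined via `BatchSplit` or `Parallel`. A hedged usage sketch of such a case (`model`, `PrefixTuningConfig(prefix_length=...)` and `BatchSplit(..., batch_sizes=...)` follow the library's usual API and are assumptions here, not defined in this patch):

```python
import adapters.composition as ac
from adapters import PrefixTuningConfig

# Two prefix tuning adapters with different prefix lengths; when their key/value
# states are concatenated again after the BatchSplit, the shorter prefix is
# padded to the longer one by pad_and_concat().
model.add_adapter("p1", config=PrefixTuningConfig(prefix_length=10))
model.add_adapter("p2", config=PrefixTuningConfig(prefix_length=30))
model.active_adapters = ac.BatchSplit("p1", "p2", batch_sizes=[4, 4])
```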
""" + max_prefix_length = max([state.key_states.shape[-2] for state in states]) all_key_states, all_value_states, all_residual_input, all_attention_mask = [], [], [], [] - for key_states, value_states, residual_input, attention_mask in outputs: + for state in states: + key_states, value_states, residual_input, attention_mask = state[:4] # pad sizes pad_length = max_prefix_length - key_states.shape[-2] pad_size = (0, 0, pad_length, 0) @@ -445,7 +438,7 @@ def _pad_and_concat(self, max_prefix_length, outputs, invert_mask=True): attention_mask, (max_prefix_length - attention_mask.shape[-1], 0), "constant", - 1.0 if invert_mask else 0.0, + 1.0 if state.invert_mask else 0.0, ) all_key_states.append(key_states) @@ -458,294 +451,87 @@ def _pad_and_concat(self, max_prefix_length, outputs, invert_mask=True): all_residual_input = torch.cat(all_residual_input, dim=0) all_attention_mask = torch.cat(all_attention_mask, dim=0) if attention_mask is not None else None - return all_key_states, all_value_states, all_residual_input, all_attention_mask + return PrefixTuningState( + key_states=all_key_states, + value_states=all_value_states, + residual_input=all_residual_input, + attention_mask=all_attention_mask, + invert_mask=states[0].invert_mask, + idx_slice=states[0].idx_slice, + ) - def adapter_stack( - self, - adapter_setup: Stack, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - lvl=0, - ): - for adapter_stack_layer in adapter_setup: - # Break if setup is too deep - if isinstance(adapter_stack_layer, AdapterCompositionBlock) and lvl >= 1: - raise ValueError( - "Specified adapter setup is too deep. Cannot have {} at level {}".format( - adapter_stack_layer.__class__.__name__, lvl - ) - ) - # We have a nested parallel layer -> call parallel method - elif isinstance(adapter_stack_layer, Parallel): - key_states, value_states, residual_input, attention_mask = self.adapter_parallel( - adapter_stack_layer, - key_states, - value_states, - residual_input, - attention_mask, - invert_mask=invert_mask, - idx_range=idx_range, - lvl=lvl + 1, - ) - # We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_stack_layer, BatchSplit): - key_states, value_states, residual_input, attention_mask = self.adapter_batchsplit( - adapter_stack_layer, - key_states, - value_states, - residual_input, - attention_mask, - invert_mask=invert_mask, - idx_range=idx_range, - lvl=lvl + 1, - ) - # We have a single prefix tuning module part of this model -> forward pass - elif adapter_stack_layer in self.prefixes: - key_states, value_states, _, attention_mask = self.single_forward( - adapter_stack_layer, - key_states, - value_states, - residual_input, - attention_mask, - invert_mask, - idx_range=idx_range, - ) - # Nesting other composition blocks is invalid - elif isinstance(adapter_stack_layer, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_stack_layer.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # As all prefix tuning modules are centrally stored, fail if not found. + def repeat(self, state: PrefixTuningState, channels: int) -> PrefixTuningState: + if state.attention_mask is not None: + if state.attention_mask.dim() == 2: # e.g. 
for DistilBERT, attention_mask has shape (batch_size, seq_len) + attention_mask = state.attention_mask.repeat(channels, 1) else: - raise ValueError(f"Unknown prefix tuning name '{adapter_stack_layer}'.") - - return key_states, value_states, residual_input, attention_mask - - def adapter_parallel( - self, - adapter_setup: Parallel, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - lvl=0, - ): - """ - For parallel execution of the adapters on the same input. This means that the input is repeated N times before - feeding it to the adapters (where N is the number of adapters). - """ - - context = ForwardContext.get_context() - if not context.adapters_parallelized: - orig_batch_size = residual_input.shape[0] - residual_input = residual_input.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1, 1) - key_states = key_states.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1, 1) - value_states = value_states.repeat(self.adapters_config.active_setup.parallel_channels, 1, 1, 1) - if attention_mask is not None: - if attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) - attention_mask = attention_mask.repeat(self.adapters_config.active_setup.parallel_channels, 1) - else: - attention_mask = attention_mask.repeat( - self.adapters_config.active_setup.parallel_channels, 1, 1, 1 - ) - context.adapters_parallelized = True + attention_mask = state.attention_mask.repeat(channels, 1, 1, 1) else: - # The base model should handle replication of input. - # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. - if residual_input.shape[0] % adapter_setup.parallel_channels != 0: - raise ValueError( - "The total input batch size in a Parallel adapter block must be divisible by the number of" - " parallel channels." - ) - orig_batch_size = residual_input.shape[0] // adapter_setup.parallel_channels - - # sequentially feed different parts of the blown-up batch into different adapters - children_outputs = [] - # track which prefix is longest for padding in the end - max_prefix_length = 0 - for i, child in enumerate(adapter_setup): - # construct inputs to child modules - inputs = { - "key_states": key_states[i * orig_batch_size : (i + 1) * orig_batch_size], - "value_states": value_states[i * orig_batch_size : (i + 1) * orig_batch_size], - "residual_input": residual_input[i * orig_batch_size : (i + 1) * orig_batch_size], - "attention_mask": attention_mask[i * orig_batch_size : (i + 1) * orig_batch_size] - if attention_mask is not None - else None, - "invert_mask": invert_mask, - "idx_range": idx_range, - } + attention_mask = None + return PrefixTuningState( + key_states=state.key_states.repeat(channels, 1, 1, 1), + value_states=state.value_states.repeat(channels, 1, 1, 1), + residual_input=state.residual_input.repeat(channels, 1, 1), + attention_mask=attention_mask, + invert_mask=state.invert_mask, + idx_slice=state.idx_slice, + ) - # Case 1: We have a nested stack -> call stack method - if isinstance(child, Stack): - child_outputs = self.adapter_stack( - child, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 2. 
We have a nested batchsplit block -> call batchsplit method - elif isinstance(child, BatchSplit): - child_outputs = self.adapter_batchsplit( - child, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 3: We have a single adapter which is part of this module -> forward pass - elif child in self.prefixes: - child_outputs = self.single_forward( - child, - **inputs, - ) - children_outputs.append(child_outputs) - # Case 4: nesting other composition blocks is invalid - elif isinstance(child, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - child.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # As all prefix tuning modules are centrally stored, fail if not found. - else: - raise ValueError(f"Unknown prefix tuning name '{child}'.") + def mean(self, states: List[PrefixTuningState], weights: torch.Tensor) -> PrefixTuningState: + # TODO implement average composition + raise NotImplementedError() - # update max prefix length - current_prefix_length = child_outputs[0].shape[-2] - if current_prefix_length > max_prefix_length: - max_prefix_length = current_prefix_length + def compose_single(self, adapter_setup: str, state: PrefixTuningState, lvl: int = 0) -> PrefixTuningState: + prefix_id = self.prefixes[adapter_setup] + batch_size = state.key_states.size(0) - # concatenate all outputs and return - key_states, value_states, residual_input, attention_mask = self._pad_and_concat( - max_prefix_length, children_outputs, invert_mask=invert_mask - ) - return key_states, value_states, residual_input, attention_mask + # Retrieve pre-computed prefix states from context + context = ForwardContext.get_context() + # batch_size x n_heads x prefix_length x n_embd_per_head + prefix_keys, prefix_values = context.prefix_states[adapter_setup][self.location_key][prefix_id] + + # Select index range for batch split + # Ignore slices that go beyond the prefix states bsz + # (this is the case for slices produced by Parallel blocks which operate on replicated kv states) + if state.idx_slice is not None and state.idx_slice.start < prefix_keys.size(0): + prefix_keys = prefix_keys[state.idx_slice] + prefix_values = prefix_values[state.idx_slice] + + if adapter_setup in self.prefix_gates: + gate = self.prefix_gates[adapter_setup] + gate_output = torch.mean(torch.sigmoid(gate(state.residual_input)), dim=1) + self._store_gating_score(adapter_setup, gate_output) + gate_output_key = gate_output[:, 0].view(-1, 1, 1, 1) + gate_output_value = gate_output[:, -1].view(-1, 1, 1, 1) + prefix_keys = prefix_keys * gate_output_key + prefix_values = prefix_values * gate_output_value - def adapter_batchsplit( - self, - adapter_setup: BatchSplit, - key_states, - value_states, - residual_input, - attention_mask=None, - invert_mask=True, - idx_range=None, - lvl=0, - ): - if not sum(adapter_setup.batch_sizes) == key_states.shape[0]: - raise IndexError( - "The given batch has a size of {} which is not compatible with batch_sizes {}".format( - key_states.shape[0], adapter_setup.batch_sizes - ) - ) + # Replicate for Parallel block + prefix_keys, prefix_values = adjust_tensors_for_parallel(state.key_states, prefix_keys, prefix_values) - children_outputs = [] - # track which prefix is longest for padding in the end - max_prefix_length = 0 - for i, adapter_block in enumerate(adapter_setup): - # compute ids of sequences that should be passed to the ith adapter - if idx_range is None: - split_idx_range = range( - sum(adapter_setup.batch_sizes[:i]), - 
sum(adapter_setup.batch_sizes[: i + 1]), - ) + key_states = torch.cat([prefix_keys, state.key_states], dim=2) + value_states = torch.cat([prefix_values, state.value_states], dim=2) + if state.attention_mask is not None: + if state.attention_mask.dim() == 2: # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len) + prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(state.attention_mask.device) else: - split_idx_range = range( - idx_range.start + sum(adapter_setup.batch_sizes[:i]), - idx_range.start + sum(adapter_setup.batch_sizes[: i + 1]), - ) - inputs = { - "key_states": key_states[split_idx_range], - "value_states": value_states[split_idx_range], - "residual_input": residual_input[split_idx_range], - "attention_mask": attention_mask[split_idx_range] if attention_mask is not None else None, - "invert_mask": invert_mask, - "idx_range": split_idx_range, - } - # Case 1: We have a nested stack -> call stack method - if isinstance(adapter_block, Stack): - child_outputs = self.adapter_stack( - adapter_block, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 2: We have a nested batch split block -> call batchsplit method - elif isinstance(adapter_block, BatchSplit): - child_outputs = self.adapter_batchsplit( - adapter_block, - **inputs, - lvl=lvl + 1, - ) - children_outputs.append(child_outputs) - # Case 4: We have a single adapter which is part of this module -> forward pass - elif adapter_block in self.prefixes: - child_outputs = self.single_forward( - adapter_block, - **inputs, + prefix_mask = torch.ones(batch_size, 1, state.attention_mask.size(2), prefix_keys.size(2)).to( + state.attention_mask.device ) - children_outputs.append(child_outputs) - # Case 5: nesting other composition blocks is invalid - elif isinstance(adapter_block, AdapterCompositionBlock): - raise ValueError( - "Invalid adapter setup. Cannot nest {} in {}".format( - adapter_block.__class__.__name__, adapter_setup.__class__.__name__ - ) - ) - # As all prefix tuning modules are centrally stored, fail if not found. 
- else: - raise ValueError(f"Unknown prefix tuning name '{adapter_block}'.") - - # update max prefix length - current_prefix_length = child_outputs[0].shape[-2] - if current_prefix_length > max_prefix_length: - max_prefix_length = current_prefix_length + if state.invert_mask: + prefix_mask = 1.0 - prefix_mask + (prefix_mask,) = adjust_tensors_for_parallel(state.attention_mask, prefix_mask) + attention_mask = torch.cat([prefix_mask, state.attention_mask], dim=-1) + else: + attention_mask = None - # concatenate all outputs and return - key_states, value_states, residual_input, attention_mask = self._pad_and_concat( - max_prefix_length, children_outputs, invert_mask=invert_mask - ) - return key_states, value_states, residual_input, attention_mask + return state._replace(key_states=key_states, value_states=value_states, attention_mask=attention_mask) def forward(self, key_states, value_states, residual_input, attention_mask=None, invert_mask=True): - adapter_setup = self.get_active_setup(self.prefixes) + adapter_setup = self.get_active_setup() if adapter_setup is not None: - if isinstance(adapter_setup, Stack): - key_states, value_states, _, attention_mask = self.adapter_stack( - adapter_setup, - key_states, - value_states, - residual_input, - attention_mask=attention_mask, - invert_mask=invert_mask, - ) - elif isinstance(adapter_setup, Parallel): - key_states, value_states, _, attention_mask = self.adapter_parallel( - adapter_setup, - key_states, - value_states, - residual_input, - attention_mask=attention_mask, - invert_mask=invert_mask, - ) - elif isinstance(adapter_setup, BatchSplit): - key_states, value_states, _, attention_mask = self.adapter_batchsplit( - adapter_setup, - key_states, - value_states, - residual_input, - attention_mask=attention_mask, - invert_mask=invert_mask, - ) - else: - raise ValueError(f"Invalid adapter setup. 
Cannot use {adapter_setup} with prefix tuning.") + state = PrefixTuningState(key_states, value_states, residual_input, attention_mask, invert_mask) + state = self.compose(adapter_setup, state) + key_states, value_states, residual_input, attention_mask = state[:4] return key_states, value_states, attention_mask diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 33b57230da..be5e30cd93 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -16,11 +16,12 @@ from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin -from .layer import AdapterLayer, AdapterLayerBase from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader -from .lora import LoRALayer -from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters -from .prefix_tuning import PrefixTuningPool, PrefixTuningShim +from .methods.adapter_layer_base import AdapterLayerBase +from .methods.bottleneck import BottleneckLayer +from .methods.lora import LoRALayer +from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters +from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config @@ -367,7 +368,7 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) def _link_prefix_to_pool(self, layer): - if isinstance(layer, PrefixTuningShim): + if isinstance(layer, PrefixTuningLayer): layer.set_pool(self.base_model.prefix_tuning) @property @@ -933,7 +934,7 @@ def get_fusion_regularization_loss(self): target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) for i, layer in self.iter_layers(): for module in layer.modules(): - if isinstance(module, AdapterLayer): + if isinstance(module, BottleneckLayer): for _, layer_fusion in module.adapter_fusion_layer.items(): if hasattr(layer_fusion, "value") and layer_fusion.value.weight.requires_grad: layer_reg_loss = 0.01 * (target - layer_fusion.value.weight).pow(2).sum() diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py index fa2340687e..21534980af 100644 --- a/src/adapters/models/albert/mixin_albert.py +++ b/src/adapters/models/albert/mixin_albert.py @@ -3,10 +3,10 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel_ -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class AlbertAttentionAdaptersMixin: @@ -18,9 +18,9 @@ def init_adapters(self, model_config, adapters_config): self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v") - self.attention_adapters = AdapterLayer("mh_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + 
"_prefix" if self.location_key else None, model_config, adapters_config ) @@ -36,7 +36,7 @@ def init_adapters(self, model_config, adapters_config): # Set location keys for prefix tuning self.location_key = "output_adapter" - self.output_adapters = AdapterLayer("output_adapter") + self.output_adapters = BottleneckLayer("output_adapter") self.attention.location_key = "self" diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index e050d66940..5ef20aaa86 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -4,8 +4,9 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ( EmbeddingAdaptersMixin, EmbeddingAdaptersWrapperMixin, @@ -13,7 +14,6 @@ InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin, ) -from ...prefix_tuning import PrefixTuningShim class BartAttentionAdaptersMixin: @@ -25,7 +25,7 @@ def init_adapters(self, model_config, adapters_config): self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) @@ -40,8 +40,8 @@ def init_adapters(self, model_config, adapters_config): # Set attention layer location key for prefix tuning self.self_attn.location_key = "encoder" - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class BartDecoderLayerAdaptersMixin(BartEncoderLayerAdaptersMixin): @@ -52,7 +52,7 @@ def init_adapters(self, model_config, adapters_config): # Set attention layer location key for prefix tuning self.self_attn.location_key = "self" self.encoder_attn.location_key = "cross" - self.cross_attention_adapters = AdapterLayer("cross_adapter") + self.cross_attention_adapters = BottleneckLayer("cross_adapter") class BartEncoderAdaptersMixin(InvertibleAdaptersMixin): diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index a54611507d..2c129f085c 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class BeitSelfAttentionAdaptersMixin: @@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config): self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) @@ -38,8 +38,8 
@@ class BeitLayerAdaptersMixin: """Adds adapters to the BeitLayer module.""" def init_adapters(self, model_config, adapters_config): - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class BeitModelAdaptersMixin(ModelBaseAdaptersMixin): diff --git a/src/adapters/models/beit/modeling_beit.py b/src/adapters/models/beit/modeling_beit.py index e951390b5e..1ed5082beb 100644 --- a/src/adapters/models/beit/modeling_beit.py +++ b/src/adapters/models/beit/modeling_beit.py @@ -102,7 +102,7 @@ def forward( attention_output = self.lambda_1 * attention_output # first residual connection - hidden_states = self.attention_adapters.adapter_layer_forward( + hidden_states = self.attention_adapters.bottleneck_layer_forward( self.drop_path(attention_output), hidden_states, None ) @@ -116,7 +116,7 @@ def forward( layer_output = self.lambda_2 * layer_output # second residual connection - layer_output = self.output_adapters.adapter_layer_forward(self.drop_path(layer_output), hidden_states, None) + layer_output = self.output_adapters.bottleneck_layer_forward(self.drop_path(layer_output), hidden_states, None) outputs = (layer_output,) + outputs diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index 2d715fb993..e97c9dd988 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -4,10 +4,10 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel_ -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim logger = logging.getLogger(__name__) @@ -22,13 +22,13 @@ def init_adapters(self, model_config, adapters_config): self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) -# For backwards compatibility, BertSelfOutput inherits directly from AdapterLayer -class BertSelfOutputAdaptersMixin(AdapterLayer): +# For backwards compatibility, BertSelfOutput inherits directly from BottleneckLayer +class BertSelfOutputAdaptersMixin(BottleneckLayer): """Adds adapters to the BertSelfOutput module.""" def __init__(self): @@ -39,8 +39,8 @@ def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) -# For backwards compatibility, BertOutput inherits directly from AdapterLayer -class BertOutputAdaptersMixin(AdapterLayer): +# For backwards compatibility, BertOutput inherits directly from BottleneckLayer +class BertOutputAdaptersMixin(BottleneckLayer): """Adds adapters to the BertOutput module.""" def __init__(self): diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index 012f4ba317..539dc74ebf 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -141,7 +141,7 @@ class 
BertSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertSelfOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -149,5 +149,5 @@ class BertOutputWithAdapters(BertOutputAdaptersMixin, BertOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py index c21d2a3f4d..8f083fe295 100644 --- a/src/adapters/models/bert_generation/modeling_bert_generation.py +++ b/src/adapters/models/bert_generation/modeling_bert_generation.py @@ -36,7 +36,7 @@ class BertGenerationSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertGene def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -153,5 +153,5 @@ class BertGenerationOutputWithAdapters(BertOutputAdaptersMixin, BertGenerationOu def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 3eb8c8bbb0..36eae84b0f 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -3,8 +3,9 @@ import torch.nn as nn from ...composition import adjust_tensors_for_parallel_ -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ( EmbeddingAdaptersMixin, EmbeddingAdaptersWrapperMixin, @@ -12,7 +13,6 @@ InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin, ) -from ...prefix_tuning import PrefixTuningShim class CLIPAttentionAdaptersMixin: @@ -24,7 +24,9 @@ def init_adapters(self, model_config, adapters_config): self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim("self_prefix", model_config, adapters_config, add_model_type_to_key=True) + self.prefix_tuning = PrefixTuningLayer( + "self_prefix", model_config, adapters_config, add_model_type_to_key=True + ) class CLIPEncoderLayerAdaptersMixin: @@ -35,8 
+37,8 @@ def init_adapters(self, model_config, adapters_config): self.mlp.fc1 = LoRALinear.wrap(self.mlp.fc1, "intermediate", model_config, adapters_config) self.mlp.fc2 = LoRALinear.wrap(self.mlp.fc2, "output", model_config, adapters_config) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class CLIPEncoderAdaptersMixin: diff --git a/src/adapters/models/deberta/mixin_deberta.py b/src/adapters/models/deberta/mixin_deberta.py index 0407e59a82..cee8530f02 100644 --- a/src/adapters/models/deberta/mixin_deberta.py +++ b/src/adapters/models/deberta/mixin_deberta.py @@ -1,5 +1,5 @@ -from ...lora import MergedLinear as LoRAMergedLinear -from ...prefix_tuning import PrefixTuningShim +from ...methods.lora import MergedLinear as LoRAMergedLinear +from ...methods.prefix_tuning import PrefixTuningLayer class DebertaSelfAttentionAdaptersMixin: @@ -9,6 +9,6 @@ def init_adapters(self, model_config, adapters_config): # Wrap layers for LoRA self.in_proj = LoRAMergedLinear.wrap(self.in_proj, "selfattn", model_config, adapters_config) - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 462685a85d..8197c19fb6 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -33,7 +33,7 @@ class DebertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, DebertaSelfOutp def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -41,7 +41,7 @@ class DebertaOutputWithAdapters(BertOutputAdaptersMixin, DebertaOutput): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/deberta_v2/mixin_deberta_v2.py b/src/adapters/models/deberta_v2/mixin_deberta_v2.py index 3b4e01aa2f..f60e8788fb 100644 --- a/src/adapters/models/deberta_v2/mixin_deberta_v2.py +++ b/src/adapters/models/deberta_v2/mixin_deberta_v2.py @@ -1,5 +1,5 @@ -from ...lora import Linear as LoRALinear -from ...prefix_tuning import PrefixTuningShim +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer class DebertaV2SelfAttentionAdaptersMixin: @@ -11,6 +11,6 @@ def init_adapters(self, model_config, adapters_config): self.key_proj = LoRALinear.wrap(self.key_proj, "selfattn", model_config, adapters_config, attn_key="k") self.value_proj = LoRALinear.wrap(self.value_proj, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) diff --git 
a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 79cd4e6a34..082e77a721 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -34,7 +34,7 @@ class DebertaV2SelfOutputWithAdapters(BertSelfOutputAdaptersMixin, DebertaV2Self def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -43,7 +43,7 @@ class DebertaV2OutputWithAdapters(BertOutputAdaptersMixin, DebertaV2Output): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py index 582543b765..44bcbb0b16 100644 --- a/src/adapters/models/distilbert/mixin_distilbert.py +++ b/src/adapters/models/distilbert/mixin_distilbert.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class DistilBertMultiHeadSelfAttentionMixin: @@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config): self.k_lin = LoRALinear.wrap(self.k_lin, "selfattn", model_config, adapters_config, attn_key="k") self.v_lin = LoRALinear.wrap(self.v_lin, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim("self", model_config, adapters_config) + self.prefix_tuning = PrefixTuningLayer("self", model_config, adapters_config) class DistilBertTransfomerBlockAdaptersMixin: @@ -28,8 +28,8 @@ def init_adapters(self, model_config, adapters_config): self.ffn.lin1 = LoRALinear.wrap(self.ffn.lin1, "intermediate", model_config, adapters_config) self.ffn.lin2 = LoRALinear.wrap(self.ffn.lin2, "output", model_config, adapters_config) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class DistilBertTransformerAdaptersMixin: diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py index 0412b4dc10..35552782ce 100644 --- a/src/adapters/models/electra/modeling_electra.py +++ b/src/adapters/models/electra/modeling_electra.py @@ -122,7 +122,7 @@ class ElectraSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, ElectraSelfOutp def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = 
self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -130,5 +130,5 @@ class ElectraOutputWithAdapters(BertOutputAdaptersMixin, ElectraOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py index b3cbf12219..e86c2967a9 100644 --- a/src/adapters/models/gpt2/mixin_gpt2.py +++ b/src/adapters/models/gpt2/mixin_gpt2.py @@ -2,11 +2,11 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear -from ...lora import MergedLinear as LoRAMergedLinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.lora import MergedLinear as LoRAMergedLinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class GPT2AttentionAdaptersMixin: @@ -25,7 +25,7 @@ def init_adapters(self, model_config, adapters_config): ) location_key = "cross_prefix" if self.is_cross_attention else "self_prefix" - self.prefix_tuning = PrefixTuningShim(location_key, model_config, adapters_config) + self.prefix_tuning = PrefixTuningLayer(location_key, model_config, adapters_config) class GPT2DecoderBlockAdaptersMixin: @@ -50,8 +50,8 @@ def init_adapters(self, model_config, adapters_config): no_init_bias=True, ) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py index d05880fbbe..333c1b9358 100644 --- a/src/adapters/models/gptj/mixin_gptj.py +++ b/src/adapters/models/gptj/mixin_gptj.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class GPTJAttentionAdaptersMixin: @@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config): self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) @@ -33,8 +33,8 @@ class GPTJDecoderBlockAdaptersMixin: """Adds adapters to the TransformerBlock module of GPTJ.""" def init_adapters(self, model_config, adapters_config): - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + 
self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index 4627e02593..22223edaf4 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class LlamaAttentionMixin: @@ -14,7 +14,7 @@ def init_adapters(self, model_config, adapters_config): self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim("self_prefix", model_config, adapters_config) + self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config) class LlamaDecoderLayerMixin: @@ -23,8 +23,8 @@ def init_adapters(self, model_config, adapters_config): self.mlp.down_proj = LoRALinear.wrap(self.mlp.down_proj, "intermediate", model_config, adapters_config) self.mlp.up_proj = LoRALinear.wrap(self.mlp.up_proj, "output", model_config, adapters_config) - self.attention_adapters = AdapterLayer("mh_adapter") - self.output_adapters = AdapterLayer("output_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py index cf37a337ae..47a8ed35a9 100644 --- a/src/adapters/models/roberta/modeling_roberta.py +++ b/src/adapters/models/roberta/modeling_roberta.py @@ -142,7 +142,7 @@ class RobertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, RobertaSelfOutp def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -151,5 +151,5 @@ class RobertaOutputWithAdapters(BertOutputAdaptersMixin, RobertaOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py index 3917c78366..a5c39acaa6 100644 --- a/src/adapters/models/t5/mixin_t5.py +++ b/src/adapters/models/t5/mixin_t5.py @@ -2,8 +2,9 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer 
+from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ( EmbeddingAdaptersMixin, InvertibleAdaptersMixin, @@ -11,7 +12,6 @@ ModelBaseAdaptersMixin, ModelWithHeadsAdaptersMixin, ) -from ...prefix_tuning import PrefixTuningShim class T5AttentionAdaptersMixin: @@ -23,12 +23,12 @@ def init_adapters(self, model_config, adapters_config): self.k = LoRALinear.wrap(self.k, "selfattn", model_config, adapters_config, attn_key="k", bias=False) self.v = LoRALinear.wrap(self.v, "selfattn", model_config, adapters_config, attn_key="v", bias=False) - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) -class T5SelfAttentionLayerAdaptersMixin(AdapterLayer): +class T5SelfAttentionLayerAdaptersMixin(BottleneckLayer): def __init__(self): super().__init__("mh_adapter", None) @@ -37,7 +37,7 @@ def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) -class T5CrossAttentionLayerAdaptersMixin(AdapterLayer): +class T5CrossAttentionLayerAdaptersMixin(BottleneckLayer): def __init__(self): super().__init__("cross_adapter", None) @@ -47,7 +47,7 @@ def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) -class T5FFLayerAdaptersMixin(AdapterLayer): +class T5FFLayerAdaptersMixin(BottleneckLayer): def __init__(self): super().__init__("output_adapter", None) diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 7d7e467f0a..3440a4bb73 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -45,7 +45,7 @@ class T5LayerFFWithAdapters(T5FFLayerAdaptersMixin, T5LayerFF): def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = self.adapter_layer_forward( + hidden_states = self.bottleneck_layer_forward( hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None ) return hidden_states @@ -207,7 +207,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = self.adapter_layer_forward( + hidden_states = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None ) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -239,7 +239,7 @@ def forward( query_length=query_length, output_attentions=output_attentions, ) - layer_output = self.adapter_layer_forward( + layer_output = self.bottleneck_layer_forward( hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None ) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py index 5b540245bd..07598ad8ae 100644 --- a/src/adapters/models/vit/mixin_vit.py +++ b/src/adapters/models/vit/mixin_vit.py @@ -2,10 +2,10 @@ import torch.nn as nn -from ...layer import AdapterLayer -from ...lora import Linear as LoRALinear +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import Linear as LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ModelBaseAdaptersMixin -from ...prefix_tuning import PrefixTuningShim class 
ViTSelfAttentionAdaptersMixin: @@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config): self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v") - self.prefix_tuning = PrefixTuningShim( + self.prefix_tuning = PrefixTuningLayer( self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config ) @@ -32,7 +32,7 @@ class ViTOutputAdaptersMixin: """Adds adapters to the ViTOutput module.""" def init_adapters(self, model_config, adapters_config): - self.output_adapters = AdapterLayer("output_adapter") + self.output_adapters = BottleneckLayer("output_adapter") # Wrap layers for LoRA self.dense = LoRALinear.wrap(self.dense, "output", model_config, adapters_config) @@ -43,7 +43,7 @@ class ViTLayerAdaptersMixin: """Adds adapters to the ViTSelfOutput module.""" def init_adapters(self, model_config, adapters_config): - self.attention_adapters = AdapterLayer("mh_adapter") + self.attention_adapters = BottleneckLayer("mh_adapter") class ViTModelAdaptersMixin(ModelBaseAdaptersMixin): diff --git a/src/adapters/models/vit/modeling_vit.py b/src/adapters/models/vit/modeling_vit.py index 4ffb61d5f6..bb0fadd2ca 100644 --- a/src/adapters/models/vit/modeling_vit.py +++ b/src/adapters/models/vit/modeling_vit.py @@ -72,7 +72,7 @@ class ViTOutputWithAdapters(ViTOutputAdaptersMixin, ViTOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.output_adapters.adapter_layer_forward(hidden_states, input_tensor, None) + hidden_states = self.output_adapters.bottleneck_layer_forward(hidden_states, input_tensor, None) return hidden_states @@ -94,7 +94,7 @@ def forward( attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - hidden_states = self.attention_adapters.adapter_layer_forward(attention_output, hidden_states, None) + hidden_states = self.attention_adapters.bottleneck_layer_forward(attention_output, hidden_states, None) # in ViT, layernorm is also applied after self-attention layer_output = self.layernorm_after(hidden_states) diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index cd8dd9bf08..a8d22284b7 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -146,7 +146,7 @@ class XLMRobertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XLMRobertaSe def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -155,5 +155,5 @@ class XLMRobertaOutputWithAdapters(BertOutputAdaptersMixin, XLMRobertaOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) + hidden_states = self.bottleneck_layer_forward(hidden_states, 
input_tensor, self.LayerNorm) return hidden_states diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py index 3a3a38066f..b772321667 100644 --- a/src/adapters/models/xmod/modeling_xmod.py +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -140,7 +140,7 @@ class XmodSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XmodSelfOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, None) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, None) return hidden_states @@ -152,5 +152,5 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, lang_ layer_norm = self.adapter_layer_norm elif self.adapter_reuse_layer_norm: layer_norm = self.LayerNorm - hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, layer_norm) + hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, layer_norm) return hidden_states diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index ff30bd8a33..2670488cb9 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -21,10 +21,10 @@ def test_to_deep(self): def test_invalid_nesting_fusion(self): self.assertRaises(ValueError, lambda: parse_composition(Fuse(Fuse("a", "b"), "c"))) - self.assertRaises(ValueError, lambda: parse_composition(Fuse(Split("a", "b", 128), "c"))) + self.assertRaises(ValueError, lambda: parse_composition(Fuse(Split("a", "b", splits=128), "c"))) def test_invalid_nesting_split(self): - self.assertRaises(ValueError, lambda: parse_composition(Split("a", Fuse("b", "c"), 128))) + self.assertRaises(ValueError, lambda: parse_composition(Split("a", Fuse("b", "c"), splits=128))) @require_torch @@ -83,7 +83,7 @@ def test_simple_split(self): model = self.build_model() # pass over split setup - model.set_active_adapters(Split("a", "b", 64)) + model.set_active_adapters(Split("a", "b", splits=64)) self.training_pass(model) @@ -93,7 +93,7 @@ def test_stacked_split(self): model = self.build_model() # split into two stacks - model.set_active_adapters(Split(Stack("a", "b"), Stack("c", "d"), split_index=64)) + model.set_active_adapters(Split(Stack("a", "b"), Stack("c", "d"), splits=64)) self.training_pass(model) @@ -118,7 +118,7 @@ def test_mixed_stack(self): model.add_adapter_fusion(Fuse("a", "b")) model.to(torch_device) - model.set_active_adapters(Stack("a", Split("c", "d", split_index=64), Fuse("a", "b"))) + model.set_active_adapters(Stack("a", Split("c", "d", splits=64), Fuse("a", "b"))) self.training_pass(model) @@ -128,7 +128,7 @@ def test_nested_split(self): model = self.build_model() # split into two stacks - model.set_active_adapters(Split(Split("a", "b", split_index=32), "c", split_index=64)) + model.set_active_adapters(Split(Split("a", "b", splits=32), "c", splits=64)) self.training_pass(model)
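Taken together, the hunks above amount to three mechanical changes for code built on these mixins: imports move under `adapters.methods.*`, `AdapterLayer`/`PrefixTuningShim` become `BottleneckLayer`/`PrefixTuningLayer`, and `adapter_layer_forward` becomes `bottleneck_layer_forward`. Below is a minimal illustrative sketch mirroring those patterns; the mixin names and the wrapped attribute names (`q_proj`, `k_proj`, `v_proj`) are hypothetical placeholders for whatever the host module defines, not code from this patch.

```python
from adapters.methods.bottleneck import BottleneckLayer
from adapters.methods.lora import Linear as LoRALinear
from adapters.methods.prefix_tuning import PrefixTuningLayer


class MyAttentionAdaptersMixin:
    """Hypothetical attention mixin following the renamed classes and import paths above."""

    def init_adapters(self, model_config, adapters_config):
        # Wrap the attention projections for LoRA, as the model mixins in this patch do.
        # q_proj/k_proj/v_proj are assumed to be defined by the wrapped attention module.
        self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q")
        self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k")
        self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")

        # PrefixTuningShim -> PrefixTuningLayer
        self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config)


class MyBlockAdaptersMixin:
    """Hypothetical block mixin adding bottleneck adapters after attention and FFN output."""

    def init_adapters(self, model_config, adapters_config):
        # AdapterLayer -> BottleneckLayer
        self.attention_adapters = BottleneckLayer("mh_adapter")
        self.output_adapters = BottleneckLayer("output_adapter")

    def apply_output_adapters(self, hidden_states, input_tensor):
        # adapter_layer_forward -> bottleneck_layer_forward; passing None for the layer
        # norm skips it, as in the ViT and T5 hunks above.
        return self.output_adapters.bottleneck_layer_forward(hidden_states, input_tensor, None)
```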
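The test changes above also reflect the breaking change to `Split`: the former `split_index` keyword is replaced by `splits`, as exercised in `test_adapter_composition.py`. A minimal usage sketch, assuming `model` is an adapters-initialized model to which adapters "a" through "d" have already been added:

```python
from adapters.composition import Split, Stack

# `splits` replaces the former `split_index` keyword; `model` is assumed to already
# have adapters "a"-"d" added.
model.set_active_adapters(Split("a", "b", splits=64))

# Composition blocks still nest as before, e.g. splitting between two stacks.
model.set_active_adapters(Split(Stack("a", "b"), Stack("c", "d"), splits=64))
```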