Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detr bettertransformer #1424

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/bettertransformer/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ The list of supported model below:
- [CLIP](https://arxiv.org/abs/2103.00020)
- [CodeGen](https://arxiv.org/abs/2203.13474)
- [Data2VecText](https://arxiv.org/abs/2202.03555)
- [DistilBert](https://arxiv.org/abs/1910.01108)
- [DeiT](https://arxiv.org/abs/2012.12877)
- [DERT] (https://arxiv.org/abs/2005.12872)
- [DistilBert](https://arxiv.org/abs/1910.01108)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [FSMT](https://arxiv.org/abs/1907.06616)
Expand Down
10 changes: 6 additions & 4 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
CLIPLayerBetterTransformer,
DetrEncoderLayerBetterTransformer,
DistilBertLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
Expand Down Expand Up @@ -67,21 +68,22 @@ class BetterTransformerManager:
"bert": {"BertLayer": BertLayerBetterTransformer},
"bert-generation": {"BertGenerationLayer": BertLayerBetterTransformer},
"blenderbot": {"BlenderbotAttention": BlenderbotAttentionLayerBetterTransformer},
"blip-2": {"T5Attention": T5AttentionLayerBetterTransformer},
"bloom": {"BloomAttention": BloomAttentionLayerBetterTransformer},
"camembert": {"CamembertLayer": BertLayerBetterTransformer},
"blip-2": {"T5Attention": T5AttentionLayerBetterTransformer},
"clip": {"CLIPEncoderLayer": CLIPLayerBetterTransformer},
"codegen": {"CodeGenAttention": CodegenAttentionLayerBetterTransformer},
"data2vec-text": {"Data2VecTextLayer": BertLayerBetterTransformer},
"deit": {"DeiTLayer": ViTLayerBetterTransformer},
"detr": {"DetrEncoderLayer": DetrEncoderLayerBetterTransformer},
"distilbert": {"TransformerBlock": DistilBertLayerBetterTransformer},
"electra": {"ElectraLayer": BertLayerBetterTransformer},
"ernie": {"ErnieLayer": BertLayerBetterTransformer},
"fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
"falcon": {"FalconAttention": FalconAttentionLayerBetterTransformer},
"fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
"gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
"gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
"gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
"gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
"gpt_neo": {"GPTNeoSelfAttention": GPTNeoAttentionLayerBetterTransformer},
"gpt_neox": {"GPTNeoXAttention": GPTNeoXAttentionLayerBetterTransformer},
"hubert": {"HubertEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer},
Expand All @@ -99,8 +101,8 @@ class BetterTransformerManager:
"mbart": {"MBartEncoderLayer": MBartEncoderLayerBetterTransformer},
"opt": {"OPTAttention": OPTAttentionLayerBetterTransformer},
"pegasus": {"PegasusAttention": PegasusAttentionLayerBetterTransformer},
"rembert": {"RemBertLayer": BertLayerBetterTransformer},
"prophetnet": {"ProphetNetEncoderLayer": ProphetNetEncoderLayerBetterTransformer},
"rembert": {"RemBertLayer": BertLayerBetterTransformer},
"roberta": {"RobertaLayer": BertLayerBetterTransformer},
"roc_bert": {"RoCBertLayer": BertLayerBetterTransformer},
"roformer": {"RoFormerLayer": BertLayerBetterTransformer},
Expand Down
123 changes: 118 additions & 5 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from transformers.activations import ACT2FN

from .base import BetterTransformerBaseLayer


if TYPE_CHECKING:
from transformers import PretrainedConfig


class AlbertLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, albert_layer, config):
r"""
Expand Down Expand Up @@ -1189,6 +1185,123 @@ def forward(self, hidden_states, output_attentions: bool, *_, **__):
return (hidden_states,)


class DetrEncoderLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, detr_layer, config):
r"""
A simple conversion of the DetrEncoderLayer to its `BetterTransformer` implementation.

Args:
detr_layer (`torch.nn.Module`):
The original `DetrEncoderLayer` where the weights needs to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
detr_layer.self_attn.q_proj.weight,
detr_layer.self_attn.k_proj.weight,
detr_layer.self_attn.v_proj.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
detr_layer.self_attn.q_proj.bias,
detr_layer.self_attn.k_proj.bias,
detr_layer.self_attn.v_proj.bias,
]
)
)

# Out proj layer
self.out_proj_weight = detr_layer.self_attn.out_proj.weight
self.out_proj_bias = detr_layer.self_attn.out_proj.bias

# Linear layer 1
self.linear1_weight = detr_layer.fc1.weight
self.linear1_bias = detr_layer.fc1.bias

# Linear layer 2
self.linear2_weight = detr_layer.fc2.weight
self.linear2_bias = detr_layer.fc2.bias

# Layer norm 1
self.norm1_eps = detr_layer.self_attn_layer_norm.eps
self.norm1_weight = detr_layer.self_attn_layer_norm.weight
self.norm1_bias = detr_layer.self_attn_layer_norm.bias

# Layer norm 2
self.norm2_eps = detr_layer.final_layer_norm.eps
self.norm2_weight = detr_layer.final_layer_norm.weight
self.norm2_bias = detr_layer.final_layer_norm.bias

# Model hyper parameters
self.num_heads = detr_layer.self_attn.num_heads
self.embed_dim = detr_layer.self_attn.embed_dim

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False
self.norm_first = True

self.original_layers_mapping = {
"in_proj_weight": ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"],
"in_proj_bias": ["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"],
"out_proj_weight": "self_attn.out_proj.weight",
"out_proj_bias": "self_attn.out_proj.bias",
"linear1_weight": "fc1.weight",
"linear1_bias": "fc1.bias",
"linear2_weight": "fc2.weight",
"linear2_bias": "fc2.bias",
"norm1_weight": "self_attn_layer_norm.weight",
"norm1_bias": "self_attn_layer_norm.bias",
"norm2_weight": "final_layer_norm.weight",
"norm2_bias": "final_layer_norm.bias",
}

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)

if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)

else:
raise NotImplementedError(
"Training and Autocast are not implemented for BetterTransformer + Detr. Please open an issue."
)

return (hidden_states,)


class ViltLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, vilt_layer, config):
r"""
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"torchaudio",
"einops",
"invisible-watermark",
"timm"
]

QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"]
Expand Down
21 changes: 20 additions & 1 deletion tests/bettertransformer/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,18 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
"""
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
SUPPORTED_ARCH = [
"blip-2",
"clip",
"clip_text_model",
"deit",
"detr",
"vilt",
"vit",
"vit_mae",
"vit_msn",
"yolos",
]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
if model_type == "vilt":
Expand Down Expand Up @@ -57,6 +68,14 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc
if model_type == "blip-2":
inputs["decoder_input_ids"] = inputs["input_ids"]

elif model_type == "detr":
# Assuming detr just needs an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-DetrModel")
inputs = feature_extractor(images=image, return_tensors="pt")

else:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
Expand Down
1 change: 1 addition & 0 deletions tests/bettertransformer/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
"data2vec-text": "hf-internal-testing/tiny-random-Data2VecTextModel",
"deit": "hf-internal-testing/tiny-random-deit",
"detr": "hf-internal-testing/tiny-random-DetrModel",
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"ernie": "hf-internal-testing/tiny-random-ErnieModel",
Expand Down
Loading