Commit 2ec366a: integrate attention refacto
eustlb committed Dec 20, 2024
1 parent 338c7c0

Showing 11 changed files with 228 additions and 982 deletions.

docs/source/en/_toctree.yml (4 changes: 2 additions & 2 deletions)
@@ -500,10 +500,10 @@
       title: mLUKE
     - local: model_doc/mobilebert
       title: MobileBERT
-    - local: model_doc/moonshine
-      title: moonshine
     - local: model_doc/modernbert
       title: ModernBert
+    - local: model_doc/moonshine
+      title: moonshine
     - local: model_doc/mpnet
       title: MPNet
     - local: model_doc/mpt

src/transformers/__init__.py (12 changes: 6 additions & 6 deletions)
@@ -5584,8 +5584,8 @@
     from .models.mobilevitv2 import (
         MobileViTV2Config,
     )
-    from .models.moonshine import MoonshineConfig
     from .models.modernbert import ModernBertConfig
+    from .models.moonshine import MoonshineConfig
     from .models.moshi import (
         MoshiConfig,
         MoshiDepthConfig,
@@ -7578,18 +7578,18 @@
         MobileViTV2Model,
         MobileViTV2PreTrainedModel,
     )
-    from .models.moonshine import (
-        MoonshineForConditionalGeneration,
-        MoonshineModel,
-        MoonshinePreTrainedModel,
-    )
     from .models.modernbert import (
         ModernBertForMaskedLM,
         ModernBertForSequenceClassification,
         ModernBertForTokenClassification,
         ModernBertModel,
         ModernBertPreTrainedModel,
     )
+    from .models.moonshine import (
+        MoonshineForConditionalGeneration,
+        MoonshineModel,
+        MoonshinePreTrainedModel,
+    )
     from .models.moshi import (
         MoshiForCausalLM,
         MoshiForConditionalGeneration,

src/transformers/models/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -167,8 +167,8 @@
     mobilenet_v2,
     mobilevit,
     mobilevitv2,
-    moonshine,
     modernbert,
+    moonshine,
     moshi,
     mpnet,
     mpt,

src/transformers/models/auto/configuration_auto.py (4 changes: 2 additions & 2 deletions)
@@ -187,8 +187,8 @@
("mobilenet_v2", "MobileNetV2Config"),
("mobilevit", "MobileViTConfig"),
("mobilevitv2", "MobileViTV2Config"),
("moonshine", "MoonshineConfig"),
("modernbert", "ModernBertConfig"),
("moonshine", "MoonshineConfig"),
("moshi", "MoshiConfig"),
("mpnet", "MPNetConfig"),
("mpt", "MptConfig"),
Expand Down Expand Up @@ -512,8 +512,8 @@
("mobilenet_v2", "MobileNetV2"),
("mobilevit", "MobileViT"),
("mobilevitv2", "MobileViTV2"),
("moonshine", "Moonshine"),
("modernbert", "ModernBERT"),
("moonshine", "Moonshine"),
("moshi", "Moshi"),
("mpnet", "MPNet"),
("mpt", "MPT"),

src/transformers/models/auto/modeling_auto.py (2 changes: 1 addition & 1 deletion)
@@ -176,8 +176,8 @@
("mobilenet_v2", "MobileNetV2Model"),
("mobilevit", "MobileViTModel"),
("mobilevitv2", "MobileViTV2Model"),
("moonshine", "MoonshineModel"),
("modernbert", "ModernBertModel"),
("moonshine", "MoonshineModel"),
("moshi", "MoshiModel"),
("mpnet", "MPNetModel"),
("mpt", "MptModel"),

src/transformers/models/auto/tokenization_auto.py (2 changes: 1 addition & 1 deletion)
@@ -313,8 +313,8 @@
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
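
The three auto-mapping changes above only re-order the "moonshine" entries alphabetically; resolution of the model type is unchanged. A minimal sanity check of what these mappings drive, assuming a transformers build that includes this branch:

from transformers import AutoConfig

# configuration_auto.py maps the "moonshine" model type to MoonshineConfig.
config = AutoConfig.for_model("moonshine")
print(type(config).__name__)  # MoonshineConfig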

src/transformers/models/moonshine/configuration_moonshine.py (8 changes: 4 additions & 4 deletions)
@@ -41,6 +41,8 @@ class MoonshineConfig(PretrainedConfig):
             The non-linear activation function (function or string) in the encoder.
         decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (`float`, *optional*, defaults to 1e-05):
@@ -61,8 +63,6 @@
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        qk_layernorm (`bool`, *optional*, defaults to `False`):
-            Whether or not to normalize the Queries and Keys after projecting the hidden states.
         ff_mult (`int`, *optional*, defaults to 4):
             Factor by which to scale the intermediate size.
         bos_token_id (`int`, *optional*, defaults to 1):
@@ -133,6 +133,7 @@ def __init__(
         num_key_value_heads=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
+        max_position_embeddings=2048,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         decoder_start_token_id=1,
@@ -142,7 +143,6 @@
         min_rotary_ndims=32,
         attention_bias=False,
         attention_dropout=0.0,
-        qk_layernorm=False,
         ff_mult=4,
         bos_token_id=1,
         eos_token_id=2,
@@ -167,6 +167,7 @@
         self.num_key_value_heads = num_key_value_heads
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
+        self.max_position_embeddings = max_position_embeddings
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.decoder_start_token_id = decoder_start_token_id
@@ -176,7 +177,6 @@
         self.min_rotary_ndims = min_rotary_ndims
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.qk_layernorm = qk_layernorm
         self.ff_mult = ff_mult
 
         # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
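
Net effect of the configuration changes above: MoonshineConfig now accepts and stores max_position_embeddings, and qk_layernorm is gone. A minimal sketch, assuming a transformers install built from this branch (the value shown is just the documented default):

from transformers import MoonshineConfig

# max_position_embeddings is new in this commit; 2048 is its documented default.
config = MoonshineConfig(max_position_embeddings=2048)
print(config.max_position_embeddings)  # 2048

# qk_layernorm was dropped from the signature, so the config no longer defines it.
print(hasattr(config, "qk_layernorm"))  # False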

@@ -75,15 +75,15 @@ def _convert_layer_names(name, gated_mlp=False):
name = re.sub(r"mha_with_rope\.key_dense", "self_attn.k_proj", name)
name = re.sub(r"mha_with_rope\.query_dense", "self_attn.q_proj", name)
name = re.sub(r"mha_with_rope\.value_dense", "self_attn.v_proj", name)
name = re.sub(r"mha_with_rope\.output_dense", "self_attn.dense", name)
name = re.sub(r"mha_with_rope\.output_dense", "self_attn.o_proj", name)
name = re.sub(r"mha_precomputed_kv\.key_dense", "encoder_attn.k_proj", name)
name = re.sub(r"mha_precomputed_kv\.query_dense", "encoder_attn.q_proj", name)
name = re.sub(r"mha_precomputed_kv\.value_dense", "encoder_attn.v_proj", name)
name = re.sub(r"mha_precomputed_kv\.output_dense", "encoder_attn.dense", name)
name = re.sub(r"mha_precomputed_kv\.output_dense", "encoder_attn.o_proj", name)
name = re.sub(r"mha_causal_with_rope\.key_dense", "self_attn.k_proj", name)
name = re.sub(r"mha_causal_with_rope\.query_dense", "self_attn.q_proj", name)
name = re.sub(r"mha_causal_with_rope\.value_dense", "self_attn.v_proj", name)
name = re.sub(r"mha_causal_with_rope\.output_dense", "self_attn.dense", name)
name = re.sub(r"mha_causal_with_rope\.output_dense", "self_attn.o_proj", name)
name = re.sub(r"layer_normalization\.", "input_layernorm.", name)
name = re.sub(r"layer_normalization_1\.", "post_attention_layernorm.", name)
name = re.sub(r"layer_normalization_2\.", "final_layernorm.", name)
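
These renames line the converted checkpoint keys up with the refactored attention module, which exposes its output projection as o_proj instead of dense. A small sketch of the substitution on a hypothetical original weight name (the key shown is illustrative, not taken from a real checkpoint):

import re

# Hypothetical Keras-style weight name from the original Moonshine export.
name = "decoder.mha_causal_with_rope.output_dense.weight"

# Same substitution as in the diff: output_dense now maps to self_attn.o_proj.
name = re.sub(r"mha_causal_with_rope\.output_dense", "self_attn.o_proj", name)
print(name)  # decoder.self_attn.o_proj.weight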