Commit 86701f2

fix query_pre_attn_scalar different of num_heads in default gemma2 config (#34540)

* fix query_pre_attn_scalar different of num_heads in default config

* propagate modular changes

* fix copies

* fix modular copies

* fix copies?

* correct copies fix
molbap authored Nov 1, 2024
1 parent 4cc0813 commit 86701f2
Showing 2 changed files with 24 additions and 24 deletions.
24 changes: 12 additions & 12 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -36,15 +36,15 @@ class Gemma2Config(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
             `inputs_ids` passed when calling [`Gemma2Model`]
-        hidden_size (`int`, *optional*, defaults to 3072):
+        hidden_size (`int`, *optional*, defaults to 2304):
             Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 24576):
+        intermediate_size (`int`, *optional*, defaults to 9216):
             Dimension of the MLP representations.
-        num_hidden_layers (`int`, *optional*, defaults to 28):
+        num_hidden_layers (`int`, *optional*, defaults to 26):
             Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 16):
+        num_attention_heads (`int`, *optional*, defaults to 8):
             Number of attention heads for each attention layer in the Transformer decoder.
-        num_key_value_heads (`int`, *optional*, defaults to 16):
+        num_key_value_heads (`int`, *optional*, defaults to 4):
             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
@@ -80,7 +80,7 @@ class Gemma2Config(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
@@ -103,11 +103,11 @@ class Gemma2Config(PretrainedConfig):
     def __init__(
         self,
         vocab_size=256000,
-        hidden_size=3072,
-        intermediate_size=24576,
-        num_hidden_layers=28,
-        num_attention_heads=16,
-        num_key_value_heads=16,
+        hidden_size=2304,
+        intermediate_size=9216,
+        num_hidden_layers=26,
+        num_attention_heads=8,
+        num_key_value_heads=4,
         head_dim=256,
         hidden_activation="gelu_pytorch_tanh",
         max_position_embeddings=8192,
@@ -121,7 +121,7 @@ def __init__(
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
-        query_pre_attn_scalar=224,
+        query_pre_attn_scalar=256,
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
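
For context, here is a minimal sketch of how a `query_pre_attn_scalar` value is consumed: Gemma2-style attention scales the query by `query_pre_attn_scalar**-0.5` instead of the usual `head_dim**-0.5`, so with the corrected default (256, equal to `head_dim`) the two coincide. The function name and tensor shapes below are illustrative, not part of this diff.

```python
import torch

def scaled_attention_scores(query, key, query_pre_attn_scalar=256):
    """Illustrative only: attention scores scaled by query_pre_attn_scalar**-0.5.

    query, key: (batch, num_heads, seq_len, head_dim)
    """
    scaling = query_pre_attn_scalar ** -0.5
    return torch.matmul(query * scaling, key.transpose(-2, -1))

# Shapes matching the new defaults: 8 query heads, 4 KV heads, head_dim 256.
q = torch.randn(1, 8, 16, 256)
k = torch.randn(1, 4, 16, 256).repeat_interleave(2, dim=1)  # expand KV heads for GQA
scores = scaled_attention_scores(q, k)
print(scores.shape)  # torch.Size([1, 8, 16, 16])
```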
24 changes: 12 additions & 12 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -67,15 +67,15 @@ class Gemma2Config(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
             `inputs_ids` passed when calling [`Gemma2Model`]
-        hidden_size (`int`, *optional*, defaults to 3072):
+        hidden_size (`int`, *optional*, defaults to 2304):
             Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 24576):
+        intermediate_size (`int`, *optional*, defaults to 9216):
             Dimension of the MLP representations.
-        num_hidden_layers (`int`, *optional*, defaults to 28):
+        num_hidden_layers (`int`, *optional*, defaults to 26):
             Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 16):
+        num_attention_heads (`int`, *optional*, defaults to 8):
             Number of attention heads for each attention layer in the Transformer decoder.
-        num_key_value_heads (`int`, *optional*, defaults to 16):
+        num_key_value_heads (`int`, *optional*, defaults to 4):
             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
@@ -111,7 +111,7 @@ class Gemma2Config(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
@@ -134,11 +134,11 @@ class Gemma2Config(PretrainedConfig):
     def __init__(
         self,
         vocab_size=256000,
-        hidden_size=3072,
-        intermediate_size=24576,
-        num_hidden_layers=28,
-        num_attention_heads=16,
-        num_key_value_heads=16,
+        hidden_size=2304,
+        intermediate_size=9216,
+        num_hidden_layers=26,
+        num_attention_heads=8,
+        num_key_value_heads=4,
         head_dim=256,
         hidden_activation="gelu_pytorch_tanh",
         max_position_embeddings=8192,
@@ -152,7 +152,7 @@ def __init__(
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
-        query_pre_attn_scalar=224,
+        query_pre_attn_scalar=256,
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
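
As a quick sanity check of the new defaults, a sketch assuming a `transformers` version that includes this commit:

```python
from transformers import Gemma2Config

config = Gemma2Config()  # library defaults after this fix
print(config.hidden_size, config.intermediate_size, config.num_hidden_layers)
# 2304 9216 26
print(config.num_attention_heads, config.num_key_value_heads)
# 8 4
assert config.query_pre_attn_scalar == config.head_dim == 256
```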
