diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 3393514b0e1263..c91863060b088f 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-import math
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache
@@ -34,13 +34,13 @@
     logging,
 )
 from ..gemma.modeling_gemma import (
-    GemmaRotaryEmbedding,
     GemmaForCausalLM,
     GemmaForSequenceClassification,
     GemmaForTokenClassification,
     GemmaModel,
     GemmaPreTrainedModel,
     GemmaRMSNorm,
+    GemmaRotaryEmbedding,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -231,7 +231,6 @@ def eager_attention_forward(config, query, key, value, mask):
     return attn_output
 
 
-
 def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16):
     if mask is not None:
         seq_len = mask.shape[1]
@@ -329,9 +328,11 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa
     "sdpa": sdpa_attention_forward,
 }
 
+
 class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass
 
+
 class Gemma2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -356,7 +357,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
         self.attention_type = config._attn_implementation
         self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]
 
-
         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
@@ -450,7 +450,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-
         self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.sliding_window = config.sliding_window
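
For context on the GEMMA2_ATTENTION_FUNCTION lookup that several hunks above touch: the configured attention implementation name is used as a key into a plain dict of forward functions. Below is a minimal, self-contained sketch of that dispatch pattern; the toy signatures, shapes, and ATTENTION_FUNCTION name are illustrative only, not the transformers API.

# Sketch of name-to-function attention dispatch (toy code, not modular_gemma2.py).
import torch

def eager_attention_forward(query, key, value):
    # Toy stand-in: plain scaled dot-product attention in eager PyTorch.
    scores = query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ value

def sdpa_attention_forward(query, key, value):
    # Toy stand-in: delegate to PyTorch's fused SDPA kernel.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

ATTENTION_FUNCTION = {
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}

q = k = v = torch.randn(1, 2, 4, 8)   # (batch, heads, seq_len, head_dim)
attn = ATTENTION_FUNCTION["sdpa"]     # select the backend by its configured name
out = attn(q, k, v)                   # (1, 2, 4, 8)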