
Commit 607c45d

update
ArthurZucker committed Nov 1, 2024
1 parent 4c6d299 commit 607c45d
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-import math
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache
@@ -34,13 +34,13 @@
     logging,
 )
 from ..gemma.modeling_gemma import (
-    GemmaRotaryEmbedding,
     GemmaForCausalLM,
     GemmaForSequenceClassification,
     GemmaForTokenClassification,
     GemmaModel,
     GemmaPreTrainedModel,
     GemmaRMSNorm,
+    GemmaRotaryEmbedding,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -231,7 +231,6 @@ def eager_attention_forward(config, query, key, value, mask):
     return attn_output
 
 
-
 def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16):
     if mask is not None:
         seq_len = mask.shape[1]
@@ -329,9 +328,11 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa
     "sdpa": sdpa_attention_forward,
 }
 
+
 class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass
 
+
 class Gemma2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
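Note: the hunk above only adds blank lines around the attention dispatch table and the class stubs that follow it. As a minimal sketch of that dispatch pattern (the dict name and function signatures mirror the diff; the key strings "eager" and "flash_attention_2" and the placeholder bodies are assumptions, not the file's real code):

# Minimal sketch of the attention-dispatch pattern; placeholder bodies only.
import torch

def eager_attention_forward(config, query, key, value, mask):
    ...  # plain attention in PyTorch would go here

def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16):
    ...  # FlashAttention-2 path would go here

def sdpa_attention_forward(config, query, key, value, mask, output_attentions=False):
    ...  # torch.nn.functional.scaled_dot_product_attention path would go here

GEMMA2_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,  # key strings assumed
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}

# Gemma2Attention.__init__ (next hunk) then resolves the callable once:
#     self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]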
@@ -356,7 +357,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
         self.attention_type = config._attn_implementation
         self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]
 
-
         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
Expand Down Expand Up @@ -450,7 +450,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
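Note: the four Gemma2RMSNorm modules kept above implement Gemma 2's "sandwich" normalization around the attention and feed-forward sublayers. A self-contained sketch of that pattern, with nn.LayerNorm standing in for Gemma2RMSNorm and nn.Identity standing in for the real sublayers (illustrative only, not the file's Gemma2DecoderLayer):

import torch
import torch.nn as nn

class SandwichNormBlock(nn.Module):
    # Stand-in showing how the four norms bracket the two sublayers.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.pre_feedforward_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.post_feedforward_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.self_attn = nn.Identity()  # placeholder for Gemma2Attention
        self.mlp = nn.Identity()        # placeholder for Gemma2MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sublayer: norm before, norm after, then the residual add.
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + self.post_attention_layernorm(hidden_states)
        # Feed-forward sublayer: same pre/post-norm sandwich.
        residual = hidden_states
        hidden_states = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = residual + self.post_feedforward_layernorm(hidden_states)
        return hidden_states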

