diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index fe03ad368e5a61..4b49314c6e1bca 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -311,8 +311,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
-        self.attention_type = config._attn_implementation
-        self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]
         self.attn_logit_softcapping = config.attn_logit_softcapping

         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
@@ -363,11 +361,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]:
+        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
             attention_type = "flex_attention"
         else:
-            attention_type = self.attention_type
+            attention_type = self.config._attn_implementation

         attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
             self, query_states, key_states, value_states, attention_mask
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 64fd540d9250ab..a28b71564ff39c 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -354,8 +354,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
-        self.attention_type = config._attn_implementation
-        self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]
         self.attn_logit_softcapping = config.attn_logit_softcapping

         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
@@ -406,11 +404,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]:
+        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
             attention_type = "flex_attention"
         else:
-            attention_type = self.attention_type
+            attention_type = self.config._attn_implementation

         attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
             self, query_states, key_states, value_states, attention_mask
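
For context, the change in both files moves the attention-backend lookup from __init__ into forward, so the layer always follows the current value of config._attn_implementation instead of a copy cached at construction time (e.g. if the implementation is switched after the model is built, or when output_attentions forces the flex_attention fallback). The following is a minimal, self-contained sketch of that call-time dispatch pattern; the ToyConfig/ToyAttention classes and the eager/sdpa stubs are hypothetical stand-ins, not the actual transformers code:

from typing import Callable, Dict

# Hypothetical registry mirroring GEMMA2_ATTENTION_FUNCTION:
# maps an implementation name to the callable that computes attention.
def eager_attention(q, k, v):
    return f"eager({q}, {k}, {v})"

def sdpa_attention(q, k, v):
    return f"sdpa({q}, {k}, {v})"

ATTENTION_FUNCTION: Dict[str, Callable] = {
    "eager": eager_attention,
    "sdpa": sdpa_attention,
}

class ToyConfig:
    _attn_implementation = "sdpa"

class ToyAttention:
    def __init__(self, config: ToyConfig):
        # The backend is *not* cached here; only the config is kept,
        # mirroring the removal of self.attention_type in the diff.
        self.config = config

    def forward(self, q, k, v, output_attentions: bool = False):
        # Resolve the implementation at call time, so later changes to
        # config._attn_implementation (or an output_attentions fallback)
        # take effect without rebuilding the module.
        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
            attention_type = "eager"  # stand-in for the "flex_attention" fallback in the diff
        else:
            attention_type = self.config._attn_implementation
        return ATTENTION_FUNCTION[attention_type](q, k, v)

layer = ToyAttention(ToyConfig())
print(layer.forward("q", "k", "v"))   # dispatches to sdpa
layer.config._attn_implementation = "eager"
print(layer.forward("q", "k", "v"))   # now dispatches to eager, no re-init needed

Under the old code, self.attention_function would have kept pointing at the sdpa callable after the config change; reading the config inside forward avoids that stale-cache behaviour.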