Commit

update
ArthurZucker committed Nov 4, 2024
1 parent c06b530 commit 89e6f85
Showing 2 changed files with 4 additions and 8 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -311,8 +311,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):

         self.scaling = config.query_pre_attn_scalar**-0.5
         self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
-        self.attention_type = config._attn_implementation
-        self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]
         self.attn_logit_softcapping = config.attn_logit_softcapping
         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
@@ -363,11 +361,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]:
+        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
             attention_type = "flex_attention"
         else:
-            attention_type = self.attention_type
+            attention_type = self.config._attn_implementation

         attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
             self, query_states, key_states, value_states, attention_mask
6 changes: 2 additions & 4 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -354,8 +354,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):

         self.scaling = config.query_pre_attn_scalar**-0.5
         self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
-        self.attention_type = config._attn_implementation
-        self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation]
         self.attn_logit_softcapping = config.attn_logit_softcapping
         if self.hidden_size % self.num_heads != 0:
             raise ValueError(
@@ -406,11 +404,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]:
+        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
             logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
             attention_type = "flex_attention"
         else:
-            attention_type = self.attention_type
+            attention_type = self.config._attn_implementation

         attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
             self, query_states, key_states, value_states, attention_mask
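The change is identical in both files: the constructor no longer caches `self.attention_type` and `self.attention_function`; instead the forward pass reads `config._attn_implementation` directly and uses it to index `GEMMA2_ATTENTION_FUNCTION`, falling back to `flex_attention` when attention weights are requested from a backend that cannot return them. A minimal, self-contained sketch of that dispatch pattern, using dummy stand-in backends and a dummy config rather than the real Transformers classes:

from typing import Callable, Dict


def eager_attention(q, k, v):
    # Stand-in backend that can return attention weights.
    return q, "attn_weights"


def sdpa_attention(q, k, v):
    # Stand-in backend that cannot return attention weights.
    return q, None


def flex_attention(q, k, v):
    # Stand-in fallback backend used when weights are requested.
    return q, "attn_weights"


ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_attention,
    "sdpa": sdpa_attention,
    "flex_attention": flex_attention,
}


class DummyConfig:
    _attn_implementation = "sdpa"


class DummyAttention:
    def __init__(self, config: DummyConfig):
        # As in the commit, no attention type or function is cached here.
        self.config = config

    def forward(self, q, k, v, output_attentions: bool = False):
        # Resolve the backend from the config at call time.
        attention_type = self.config._attn_implementation
        if output_attentions and attention_type in ["sdpa", "flash_attention_2"]:
            # These backends do not expose attention weights, so fall back.
            attention_type = "flex_attention"
        return ATTENTION_FUNCTIONS[attention_type](q, k, v)


attn = DummyAttention(DummyConfig())
print(attn.forward(1, 2, 3, output_attentions=True))  # dispatches to flex_attention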

0 comments on commit 89e6f85
