Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Cyril Vallez <[email protected]>
  • Loading branch information
ArthurZucker and Cyrilvallez authored Nov 6, 2024
1 parent b5d9819 commit 5a3dade
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):


GEMMA2_ATTENTION_FUNCTION = {
"flash_attention": flash_attention_forward,
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
Expand Down Expand Up @@ -427,19 +427,19 @@ def forward(
class Gemma2FlashAttention2(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["flash_attention"]
self.config._attn_implementation = "flash_attention_2"
logger.warning_once(
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`"
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
)


class Gemma2SdpaAttention(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["sdpa"]
self.config._attn_implementation = "sdpa"
logger.warning_once(
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`"
"The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
)

Expand Down

0 comments on commit 5a3dade

Please sign in to comment.