Skip to content

Commit

Permalink
revert sdpa check
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Nov 4, 2024
1 parent 21edaed commit b5d9819
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,20 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
"""
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
"""
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)

# if using the default path -> swap sdpa by eager
if not hard_check_only and config._attn_implementation == "sdpa":
config._attn_implementation = "eager"

return config


GEMMA2_INPUTS_DOCSTRING = r"""
Args:
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,20 @@ def forward(
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
_supports_quantized_cache = False

@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
"""
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
"""
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)

# if using the default path -> swap sdpa by eager
if not hard_check_only and config._attn_implementation == "sdpa":
config._attn_implementation = "eager"

return config


class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
def __init__(self, config: Gemma2Config):
Expand Down

0 comments on commit b5d9819

Please sign in to comment.