From b5d98194bd7380168e9e819d3b0edf73c7911d31 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 19:16:39 +0100 Subject: [PATCH] revert sdpa check --- src/transformers/models/gemma2/modeling_gemma2.py | 14 ++++++++++++++ src/transformers/models/gemma2/modular_gemma2.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e0cf7c5edfb..952f44b5c6d 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -518,6 +518,20 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + """ + Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. + SDPA reduces the model performance on Gemma2 because of the logits softcapping. + """ + config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "sdpa": + config._attn_implementation = "eager" + + return config + GEMMA2_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 3f22ee0e782..c3942603cb2 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -520,6 +520,20 @@ def forward( class Gemma2PreTrainedModel(GemmaPreTrainedModel): _supports_quantized_cache = False + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + """ + Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. + SDPA reduces the model performance on Gemma2 because of the logits softcapping. + """ + config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "sdpa": + config._attn_implementation = "eager" + + return config + class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): def __init__(self, config: Gemma2Config):