diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 164de60f35f..7fd350308f6 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -234,7 +234,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
         if mask is not None:
-            return score + mask[b][h]
+            return score + mask[b][h][q_idx][kv_idx]
         return score

     attn_output = flex_attention(
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 828f61a1be4..edd6ad99b60 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -277,7 +277,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
         if mask is not None:
-            return score + mask[b][h]
+            return score + mask[b][h][q_idx][kv_idx]
         return score

     attn_output = flex_attention(
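
The score_mod contract is the reason for this fix: flex_attention calls the callback once per (batch, head, q_idx, kv_idx) element with scalar index tensors and expects a scalar score back, so the 2D slice mask[b][h] has to be narrowed to the single element mask[b][h][q_idx][kv_idx]. The following is a minimal, self-contained sketch of that behavior, not the transformers code itself; the tensor shapes, the mask contents, and the soft_cap value are assumptions made for the example.

import torch
from torch.nn.attention.flex_attention import flex_attention

batch, heads, seq_len, head_dim = 1, 2, 128, 64
soft_cap = 50.0  # stands in for config.attn_logit_softcapping; value assumed for this sketch

query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# Additive attention mask shaped [batch, heads, q_len, kv_len]:
# 0 where attention is allowed, -inf where it is blocked.
mask = torch.zeros(batch, heads, seq_len, seq_len)
mask[..., seq_len // 2 :] = float("-inf")  # example: hide the second half of the keys

def tanh_softcap(score, b, h, q_idx, kv_idx):
    # score, b, h, q_idx and kv_idx all arrive as scalar tensors, so the mask
    # must be indexed down to a single element before it is added to the score.
    score = soft_cap * torch.tanh(score / soft_cap)
    if mask is not None:
        return score + mask[b][h][q_idx][kv_idx]
    return score

attn_output = flex_attention(query, key, value, score_mod=tanh_softcap)
print(attn_output.shape)  # torch.Size([1, 2, 128, 64])

Returning mask[b][h] instead hands flex_attention a [q_len, kv_len] tensor where a scalar is expected, which is what the one-line change in both files corrects.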