diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 0e17875aa2f..69ff9ad8592 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -747,7 +747,7 @@ def gpt_bigcode_wrapped_scaled_dot_product( is_causal = True sdpa_result = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=is_causal + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal ) if self.multi_query: