diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 702aca3257b..7089a55c627 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -695,6 +695,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
     # MHA models: (batch_size, num_heads, query_length, head_dim)
     query_shape = query.shape
     batch_size = query_shape[0]
+    kv_seq_len = key.shape[-2]
 
     if self.multi_query:
         query_length = query_shape[1]
@@ -721,30 +722,34 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         key = key.expand(-1, self.num_heads, -1, -1)
         value = value.expand(-1, self.num_heads, -1, -1)
 
-    if batch_size == 1 or self.training:
-        if query_length > 1:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
-            )
-        else:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
-            )
+    # We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
+    if self.training:
+        is_causal = True
+        attn_mask = None
+    elif batch_size == 1 and query_length == 1:
+        is_causal = False
+        attn_mask = None
+    elif batch_size == 1 and kv_seq_len == query_length:
+        is_causal = True
+        attn_mask = None
+    elif attention_mask is not None:
+        mask_value = self._get_mask_value(query.device, query.dtype)
+
+        # gpt_bigcode has the bad taste to use a causal mask a
+        # [batch_size, target_length, 1, source_length] which is different from
+        # **all** other architectures and not compatible with SDPA.
+        # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
+        # but it is probably not worth it.
+        attention_mask = attention_mask.transpose(1, 2)
+        attn_mask = torch.where(attention_mask, 0.0, mask_value)
+        is_causal = False
     else:
-        if attention_mask is not None:
-            mask_value = self._get_mask_value(query.device, query.dtype)
+        attn_mask = None
+        is_causal = True
 
-            # gpt_bigcode has the bad taste to use a causal mask a
-            # [batch_size, target_length, 1, source_length] which is different from
-            # **all** other architectures and not compatible with SDPA.
-            # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
-            # but it is probably not worth it.
-            attention_mask = attention_mask.transpose(1, 2)
-            attention_mask = torch.where(attention_mask, 0.0, mask_value)
-
-        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
-        )
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+    )
 
     if self.multi_query:
         # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
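
For reference, below is a minimal, self-contained sketch of the masked SDPA branch introduced above (the `elif attention_mask is not None` case). The tensor shapes are illustrative assumptions, and `mask_value` is built inline with `torch.full`/`torch.finfo` to stand in for what `self._get_mask_value` returns; nothing here is copied from the PR's tests.

```python
import torch

# Illustrative shapes only (assumptions, not taken from the PR).
batch_size, num_heads, query_length, kv_seq_len, head_dim = 2, 4, 3, 5, 8

query = torch.randn(batch_size, num_heads, query_length, head_dim)
key = torch.randn(batch_size, num_heads, kv_seq_len, head_dim)
value = torch.randn(batch_size, num_heads, kv_seq_len, head_dim)

# GPTBigCode-style boolean mask of shape (batch_size, target_length, 1, source_length),
# True where attention is allowed.
attention_mask = torch.ones(batch_size, query_length, 1, kv_seq_len, dtype=torch.bool)

# Same transformation as in the diff: swap dims 1 and 2 so the mask broadcasts over
# heads as (batch_size, 1, target_length, source_length), then turn the boolean mask
# into an additive float mask (0.0 where visible, a large negative value where masked).
# The 0-dim mask_value tensor mimics what self._get_mask_value builds (assumption).
mask_value = torch.full([], torch.finfo(query.dtype).min, dtype=query.dtype)
attn_mask = torch.where(attention_mask.transpose(1, 2), 0.0, mask_value)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
print(sdpa_result.shape)  # torch.Size([2, 4, 3, 8])
```

As the new inline comment notes, keeping `attn_mask=None` in the training and causal branches is what lets SDPA dispatch to the Flash Attention kernel; once a dense `attn_mask` is passed, PyTorch falls back to the memory-efficient or math backends.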