From de2cf0440c18e3efd93eb95eb925cc08b3225762 Mon Sep 17 00:00:00 2001
From: Dong Hande <45357817+DongHande@users.noreply.github.com>
Date: Wed, 27 Sep 2023 03:08:28 -0500
Subject: [PATCH 1/4] Update attention.py

Modify the code for gpt_bigcode. This modification makes the KV cache
work well with multiple new tokens.
---
 optimum/bettertransformer/models/attention.py | 33 ++++++++-----------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 702aca3257b..ba62516e466 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -721,30 +721,25 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         key = key.expand(-1, self.num_heads, -1, -1)
         value = value.expand(-1, self.num_heads, -1, -1)
 
-    if batch_size == 1 or self.training:
-        if query_length > 1:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
-            )
-        else:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
-            )
-    else:
-        if attention_mask is not None:
-            mask_value = self._get_mask_value(query.device, query.dtype)
+    if attention_mask is not None:
+        mask_value = self._get_mask_value(query.device, query.dtype)
 
-            # gpt_bigcode has the bad taste to use a causal mask a
-            # [batch_size, target_length, 1, source_length] which is different from
-            # **all** other architectures and not compatible with SDPA.
-            # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
-            # but it is probably not worth it.
-            attention_mask = attention_mask.transpose(1, 2)
-            attention_mask = torch.where(attention_mask, 0.0, mask_value)
+        # gpt_bigcode has the bad taste to use a causal mask a
+        # [batch_size, target_length, 1, source_length] which is different from
+        # **all** other architectures and not compatible with SDPA.
+        # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
+        # but it is probably not worth it.
+        attention_mask = attention_mask.transpose(1, 2)
+        attention_mask = torch.where(attention_mask, 0.0, mask_value)
 
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
         )
+
+    else:
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=True
+        )
 
     if self.multi_query:
         # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)

From 7376b42168d7143ee63c7ef5679d07612a545196 Mon Sep 17 00:00:00 2001
From: Dong Hande <45357817+DongHande@users.noreply.github.com>
Date: Mon, 9 Oct 2023 04:47:21 -0500
Subject: [PATCH 2/4] consider batch size = 1

---
 optimum/bettertransformer/models/attention.py | 29 ++++++++++++-------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index ba62516e466..0e17875aa2f 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -721,7 +721,17 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         key = key.expand(-1, self.num_heads, -1, -1)
         value = value.expand(-1, self.num_heads, -1, -1)
 
-    if attention_mask is not None:
+    # We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
+    if self.training:
+        is_causal = True
+        attn_mask = None
+    elif batch_size == 1 and query_length == 1:
+        is_causal = False
+        attn_mask = None
+    elif batch_size == 1 and kv_seq_len == query_length:
+        is_causal = True
+        attn_mask = None
+    elif attention_mask is not None:
         mask_value = self._get_mask_value(query.device, query.dtype)
 
         # gpt_bigcode has the bad taste to use a causal mask a
@@ -730,16 +740,15 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
         # but it is probably not worth it.
         attention_mask = attention_mask.transpose(1, 2)
-        attention_mask = torch.where(attention_mask, 0.0, mask_value)
+        attn_mask = torch.where(attention_mask, 0.0, mask_value)
+        is_causal = False
+    else:
+        attn_mask = None
+        is_causal = True
 
-        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
-        )
-
-    else:
-        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=True
-        )
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=is_causal
+    )
 
     if self.multi_query:
         # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)

From 34b7ab11f5238ee1a62d49ee2d7db63b2bb87488 Mon Sep 17 00:00:00 2001
From: Dong Hande <45357817+DongHande@users.noreply.github.com>
Date: Mon, 9 Oct 2023 05:47:38 -0500
Subject: [PATCH 3/4] Update attention.py

---
 optimum/bettertransformer/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 0e17875aa2f..69ff9ad8592 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -747,7 +747,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         is_causal = True
 
     sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=is_causal
+        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
     )
 
     if self.multi_query:

From ae7323227f274aa3d4f983b467799ef797e51f15 Mon Sep 17 00:00:00 2001
From: Dong Hande <45357817+DongHande@users.noreply.github.com>
Date: Mon, 9 Oct 2023 07:12:03 -0500
Subject: [PATCH 4/4] def kv_seq_len

---
 optimum/bettertransformer/models/attention.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 69ff9ad8592..7089a55c627 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -695,6 +695,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
     # MHA models: (batch_size, num_heads, query_length, head_dim)
     query_shape = query.shape
     batch_size = query_shape[0]
+    kv_seq_len = key.shape[-2]
 
     if self.multi_query:
         query_length = query_shape[1]
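
Note (not part of the patch series): the snippet below is a standalone sketch of the dispatch rule the series converges on after PATCH 4/4 — decide is_causal / attn_mask first, then make a single call to torch.nn.functional.scaled_dot_product_attention. The helper name dispatch_sdpa, the toy tensor shapes, and the torch.finfo-based mask value are illustrative assumptions; the library code instead uses self._get_mask_value and a gpt_bigcode-specific transpose of the incoming mask, both omitted here.

# Standalone sketch (assumed names and shapes, not library code).
import torch


def dispatch_sdpa(query, key, value, attention_mask=None, training=False, dropout_p=0.0):
    # query: (batch, heads, query_length, head_dim); key/value: (batch, heads, kv_seq_len, head_dim)
    batch_size, _, query_length, _ = query.shape
    kv_seq_len = key.shape[-2]

    if training:
        # Training: plain causal attention, no explicit mask, so Flash Attention can be used.
        is_causal, attn_mask = True, None
    elif batch_size == 1 and query_length == 1:
        # Single-token decode: the one new query attends to the whole KV cache.
        is_causal, attn_mask = False, None
    elif batch_size == 1 and kv_seq_len == query_length:
        # Prefill without padding: causal attention is sufficient.
        is_causal, attn_mask = True, None
    elif attention_mask is not None:
        # General case, e.g. a KV cache plus several new tokens, or a padded batch:
        # turn the boolean mask into an additive float mask (0 = keep, large negative = drop).
        # The mask is assumed to already broadcast to (batch, heads, query_length, kv_seq_len).
        mask_value = torch.full((), torch.finfo(query.dtype).min, dtype=query.dtype)
        attn_mask = torch.where(attention_mask, 0.0, mask_value)
        is_causal = False
    else:
        is_causal, attn_mask = True, None

    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )


# Example: batch of 1, 3 new tokens attending to a 10-position cache (7 past + 3 new) —
# exactly the "KV cache with multiple new tokens" case the series fixes.
q = torch.rand(1, 4, 3, 8)
k = torch.rand(1, 4, 10, 8)
v = torch.rand(1, 4, 10, 8)
keep = torch.ones(1, 1, 3, 10, dtype=torch.bool).tril(diagonal=10 - 3)  # True = may attend
out = dispatch_sdpa(q, k, v, attention_mask=keep)  # takes the masked branch, is_causal=False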