From 29a5dbc41f5a0b56d43b7b45c6557817e7a58c47 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 23 Oct 2024 17:20:59 +0200 Subject: [PATCH] fixed t5_forward for real, because it's also used by blip-2 as well --- optimum/bettertransformer/models/attention.py | 326 ++++++++++++------ .../models/decoder_models.py | 4 +- setup.py | 2 +- tests/bettertransformer/test_decoder.py | 4 +- .../bettertransformer/test_encoder_decoder.py | 1 + 5 files changed, 223 insertions(+), 114 deletions(-) diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 22b8faf1c21..c8c91a04e4e 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -387,137 +387,243 @@ def opt_forward( # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward -def t5_forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - **kwargs, -): - raise_on_head_mask(layer_head_mask) +if check_if_transformers_greater("4.45.99"): - if output_attentions is True: - raise ValueError("output_attentions=True can not be supported with BetterTransformer.") - if len(self.pruned_heads) > 0: - raise ValueError(f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.pruned_heads}.") - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + def t5_forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=query_states.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + return outputs + +else: + + def t5_forward( + self, + hidden_states, + mask=None, + 
key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + **kwargs, + ): + raise_on_head_mask(layer_head_mask) + + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if len(self.pruned_heads) > 0: + raise ValueError( + f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.pruned_heads}." + ) + + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" if key_value_states is None: # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: # cross-attn # (batch_size, n_heads, seq_length, dim_per_head) hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) - dropout_p = self.dropout if self.training else 0.0 - query_states = self.scale * query_states 
- if position_bias is None and not self.has_relative_attention_bias: - if mask is None: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False - ) - elif mask is not None: + dropout_p = self.dropout if self.training else 0.0 + query_states = self.scale * query_states + if position_bias is None and not self.has_relative_attention_bias: attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=mask, dropout_p=dropout_p, is_causal=False ) - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=value_states.device, - dtype=value_states.dtype, - ) - if self.gradient_checkpointing and self.training: - position_bias.requires_grad = True + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=value_states.device, + dtype=value_states.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=value_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.has_relative_attention_bias: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias, + dropout_p=dropout_p, + is_causal=False, + ) else: - position_bias = self.compute_bias(real_seq_length, key_length, device=value_states.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) - - if self.has_relative_attention_bias: attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False ) - else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False - ) - attn_output = unshape(attn_output) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) + attn_output = unshape(attn_output) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - return outputs + return outputs # Adapted from transformers.models.bart.modeling_bart.BartAttention.forward diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 52d28d076d3..e8045e695c1 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ 
b/optimum/bettertransformer/models/decoder_models.py @@ -327,9 +327,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): setattr(self, "relative_attention_bias", layer.relative_attention_bias) self.original_layers_mapping["relative_attention_bias"] = "relative_attention_bias" - self.module_mapping = None - + self.layer_idx = getattr(layer, "layer_idx", None) self.is_decoder = layer.is_decoder + self.module_mapping = None def forward(self, *args, **kwargs): return t5_forward(self, *args, **kwargs) diff --git a/setup.py b/setup.py index 243bb46699e..62dd6ee8fa5 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ "h5py", "numpy<1.24.0", "datasets<=2.16", - # "transformers[sentencepiece]>=4.26,<4.38", + "transformers>=4.26,<4.38", ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index f5958ceb1d2..e2bc6ddc2fb 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -224,7 +224,9 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina @require_torch_gpu @require_accelerate def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None): - hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory, attn_implementation="eager").eval() + hf_model = AutoModelForCausalLM.from_pretrained( + "gpt2", device_map="auto", max_memory=max_memory, attn_implementation="eager" + ).eval() bt_model = BetterTransformer.transform( hf_model, keep_original_model=keep_original_model, max_memory=max_memory ) diff --git a/tests/bettertransformer/test_encoder_decoder.py b/tests/bettertransformer/test_encoder_decoder.py index b64f66fa1a3..5ce4d62b12c 100644 --- a/tests/bettertransformer/test_encoder_decoder.py +++ b/tests/bettertransformer/test_encoder_decoder.py @@ -45,6 +45,7 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest "mbart", "pegasus", "prophetnet", + "t5", ] FULL_GRID = {
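
As an illustrative aside, below is a minimal usage sketch of the converted T5 path this patch fixes (the same t5_forward also serves BLIP-2's T5 language model, per the commit subject). It assumes optimum and transformers are installed locally; "t5-small", the prompt, and max_new_tokens=20 are only example choices, not anything mandated by the patch.

    # Sketch only: convert a T5 model with BetterTransformer and run generation,
    # which drives the scaled_dot_product_attention-based t5_forward above.
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from optimum.bettertransformer import BetterTransformer

    # attn_implementation="eager" mirrors the updated test_decoder.py, so the
    # conversion replaces the eager attention rather than a native SDPA path.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", attn_implementation="eager").eval()

    # keep_original_model=True keeps the unconverted model usable for comparison,
    # as in the BetterTransformer test suite.
    bt_model = BetterTransformer.transform(model, keep_original_model=True)

    inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
    with torch.no_grad():
        generated = bt_model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))

With transformers newer than 4.45.99, generation goes through the new branch added above, which reads and updates past_key_value.self_attention_cache / past_key_value.cross_attention_cache via the layer_idx now exposed on the wrapper; older transformers versions fall back to the legacy tuple-based cache branch.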