
Commit

fixed t5_forward for real, because it's also used by blip-2 as well
IlyasMoutawwakil committed Oct 23, 2024
1 parent d25cd97 commit 29a5dbc
Showing 5 changed files with 223 additions and 114 deletions.
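As the commit message notes, `t5_forward` is shared with BLIP-2, whose language model can be a Flan-T5. A rough usage sketch of the path this fix affects follows; the checkpoint name and dtype are illustrative, not something this commit prescribes.

# Illustrative only: any BLIP-2 checkpoint backed by a T5 language model should hit
# the patched t5_forward once the model is transformed.
import torch
from transformers import Blip2ForConditionalGeneration
from optimum.bettertransformer import BetterTransformer

model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
)
model = BetterTransformer.transform(model, keep_original_model=False)
# model.generate(...) then works as usual with pixel_values/input_ids from Blip2Processor.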
326 changes: 216 additions & 110 deletions optimum/bettertransformer/models/attention.py
@@ -387,137 +387,243 @@ def opt_forward(


# Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
if check_if_transformers_greater("4.45.99"):

    def t5_forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=query_states.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=position_bias_masked,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        return outputs

else:

    def t5_forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        raise_on_head_mask(layer_head_mask)

        if output_attentions is True:
            raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
        if len(self.pruned_heads) > 0:
            raise ValueError(
                f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.pruned_heads}."
            )

        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        dropout_p = self.dropout if self.training else 0.0
        query_states = self.scale * query_states
        if position_bias is None and not self.has_relative_attention_bias:
            if mask is None:
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
                )
            elif mask is not None:
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_states, key_states, value_states, attn_mask=mask, dropout_p=dropout_p, is_causal=False
                )

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length),
                    device=value_states.device,
                    dtype=value_states.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=value_states.device)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.has_relative_attention_bias:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=position_bias,
                dropout_p=dropout_p,
                is_causal=False,
            )
        else:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
            )

        attn_output = unshape(attn_output)  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        return outputs


# Adapted from transformers.models.bart.modeling_bart.BartAttention.forward
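The `check_if_transformers_greater("4.45.99")` branch above assumes the Cache-based T5 API, in which `past_key_value` is an `EncoderDecoderCache` wrapping two `DynamicCache` objects rather than a `(key, value)` tuple. A minimal sketch of that object, assuming the attribute names used above (`self_attention_cache`, `cross_attention_cache`, `is_updated`) are stable in transformers >= 4.46:

# Sketch of the cache object the new branch consumes (assumption: transformers >= 4.46
# exposes EncoderDecoderCache / DynamicCache with the attributes referenced above).
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_value = EncoderDecoderCache(DynamicCache(), DynamicCache())

layer_idx = 0
self_cache = past_key_value.self_attention_cache    # decoder self-attention K/V, per layer
cross_cache = past_key_value.cross_attention_cache  # encoder-decoder cross-attention K/V
already_cached = past_key_value.is_updated.get(layer_idx, False)  # True once cross K/V were stored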
4 changes: 2 additions & 2 deletions optimum/bettertransformer/models/decoder_models.py
@@ -327,9 +327,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
setattr(self, "relative_attention_bias", layer.relative_attention_bias)
self.original_layers_mapping["relative_attention_bias"] = "relative_attention_bias"

self.layer_idx = getattr(layer, "layer_idx", None)
self.is_decoder = layer.is_decoder
self.module_mapping = None

def forward(self, *args, **kwargs):
return t5_forward(self, *args, **kwargs)
2 changes: 1 addition & 1 deletion setup.py
@@ -77,7 +77,7 @@
"h5py",
"numpy<1.24.0",
"datasets<=2.16",
# "transformers[sentencepiece]>=4.26,<4.38",
"transformers>=4.26,<4.38",
],
"diffusers": ["diffusers"],
"intel": "optimum-intel>=1.18.0",
4 changes: 3 additions & 1 deletion tests/bettertransformer/test_decoder.py
@@ -224,7 +224,9 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina
@require_torch_gpu
@require_accelerate
def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None):
hf_model = AutoModelForCausalLM.from_pretrained(
    "gpt2", device_map="auto", max_memory=max_memory, attn_implementation="eager"
).eval()
bt_model = BetterTransformer.transform(
hf_model, keep_original_model=keep_original_model, max_memory=max_memory
)
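The GPT-2 model in this test is now loaded with `attn_implementation="eager"`, presumably so BetterTransformer replaces the eager attention modules rather than an already-SDPA implementation on recent transformers. A minimal sketch of that loading pattern:

# Sketch: force the eager attention implementation before transforming the model.
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

hf_model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager").eval()
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)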
1 change: 1 addition & 0 deletions tests/bettertransformer/test_encoder_decoder.py
@@ -45,6 +45,7 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest
"mbart",
"pegasus",
"prophetnet",
"t5",
]

FULL_GRID = {
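With "t5" added to the supported architectures, the encoder-decoder test suite now runs its standard checks against T5. A standalone sketch of the kind of logits-parity check those tests perform; the model id and tolerance are illustrative, not what the suite pins.

# Illustrative parity check between the original and BetterTransformer-transformed T5.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
hf_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()  # may need attn_implementation="eager" on recent transformers
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

inputs = tokenizer("translate English to German: Hello world", return_tensors="pt")
decoder_input_ids = torch.tensor([[hf_model.config.decoder_start_token_id]])

with torch.no_grad():
    hf_logits = hf_model(**inputs, decoder_input_ids=decoder_input_ids).logits
    bt_logits = bt_model(**inputs, decoder_input_ids=decoder_input_ids).logits

assert torch.allclose(hf_logits, bt_logits, atol=1e-4)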
