diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 60ec557eba7fdc..efca985f67842e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -19,6 +19,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -40,12 +41,18 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_bart import BartConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "facebook/bart-base" @@ -71,6 +78,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. @@ -119,12 +139,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[BartConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -133,6 +156,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -263,14 +287,225 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class BartFlashAttention2(BartAttention): + """ + Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("BartFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +BART_ATTENTION_CLASSES = { + "default": BartAttention, + "flash_attention_2": BartFlashAttention2, +} + + class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BartAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = BART_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -336,22 +571,26 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BartAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + self.self_attn = BART_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BartAttention( + self.encoder_attn = BART_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -479,6 +718,7 @@ class BartPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std @@ -792,8 +1032,11 @@ def forward( # expand attention_mask if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + if getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = attention_mask if 0 in attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -995,16 +1238,24 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if 
encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index b9a84a869dac4f..222873ac852b51 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1174,7 +1174,7 @@ def forward( return outputs -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BigBirdPegasusDecoder +# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder class BigBirdPegasusDecoderAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1185,12 +1185,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[BigBirdPegasusConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -1199,6 +1202,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 2cbfa4ef0f82a6..221c97c86885c5 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -90,12 +90,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[BioGptConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -104,6 +107,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 985c660fa0b843..f49f90f794fc94 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -104,12 +104,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[BlenderbotConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads 
self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -118,6 +121,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -248,15 +252,21 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot +BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT class BlenderbotEncoderLayer(nn.Module): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BlenderbotAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -317,28 +327,32 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT class BlenderbotDecoderLayer(nn.Module): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = BlenderbotAttention( + self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BlenderbotAttention( + self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 3d51ee91284e71..292b5a8c6e8bf6 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -101,12 +101,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[BlenderbotSmallConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -115,6 +118,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, 
bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -245,15 +249,18 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallEncoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BlenderbotSmallAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -314,28 +321,35 @@ def forward( return outputs -# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall +BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BlenderbotSmallAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BlenderbotSmallAttention( + self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 47cf2d6245ef47..e393db64d045ea 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -330,12 +330,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[Data2VecAudioConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -344,6 +347,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 1232d24730c2be..7591ecb0b82afa 100644 --- 
a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -370,12 +370,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[GPTSanJapaneseConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -384,6 +387,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index ddb80f56723ecd..a45dcb2d11fe1f 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -396,12 +396,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[HubertConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -410,6 +413,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index ebb4b2821ced83..c0a5a205950285 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -287,12 +287,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[InformerConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -301,6 +304,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 4e5004fc98ffd6..c05948540f7865 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -172,12 +172,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[M2M100Config] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -186,6 +189,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, 
bias=bias) @@ -316,15 +320,18 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100 +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 class M2M100EncoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100Attention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = M2M100_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -385,28 +392,35 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100 +M2M100_ATTENTION_CLASSES = {"default": M2M100Attention} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 class M2M100DecoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = M2M100Attention( + self.self_attn = M2M100_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = M2M100Attention( + self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 10a6d5f342f1d1..cabf0c68f8b62b 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -119,12 +119,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[MarianConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -133,6 +136,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -263,15 +267,18 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN class MarianEncoderLayer(nn.Module): def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MarianAttention( + attn_type = 
"flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -332,28 +339,35 @@ def forward( return outputs -# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian +MARIAN_ATTENTION_CLASSES = {"default": MarianAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN class MarianDecoderLayer(nn.Module): def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MarianAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MarianAttention( + self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 96044ac1c2769b..97fdf9ed87998b 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -18,6 +18,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -39,12 +40,18 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_mbart import MBartConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25" @@ -59,6 +66,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): """ Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not @@ -113,12 +133,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[MBartConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = 
num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -127,6 +150,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -257,14 +281,226 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart +class MBartFlashAttention2(MBartAttention): + """ + MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MBartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("MBartFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +MBART_ATTENTION_CLASSES = { + "default": MBartAttention, + "flash_attention_2": MBartFlashAttention2, +} + + class MBartEncoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBartAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = MBART_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -329,23 +565,27 @@ class MBartDecoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = MBartAttention( + self.self_attn = MBART_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBartAttention( + self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -472,6 +712,7 @@ class MBartPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MBartDecoderLayer", "MBartAttention"] + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std @@ -766,7 +1007,11 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + if getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = attention_mask if 0 in attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -970,16 +1215,24 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if 
encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 5d96359999b47a..2a015fc0321f30 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -145,7 +145,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.bart.modeling_bart.BartAttention +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Musicgen class MusicgenAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -156,12 +156,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[MusicgenConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -170,6 +173,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 22708f1125224d..3dde07da66a84b 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -467,12 +467,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[NllbMoeConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -481,6 +484,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 35eb1ffc1b585f..18af4d518a899b 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -119,12 +119,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[PegasusConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // 
num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -133,6 +136,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -263,15 +267,21 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus +PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS class PegasusEncoderLayer(nn.Module): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PegasusAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -332,28 +342,32 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS class PegasusDecoderLayer(nn.Module): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = PegasusAttention( + self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PegasusAttention( + self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 89c4f29cc026f6..5af397be97b305 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -128,12 +128,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[PegasusXConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -142,6 +145,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/plbart/modeling_plbart.py 
b/src/transformers/models/plbart/modeling_plbart.py index 4d8fe161f806e5..ad298c6d389048 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -112,12 +112,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[PLBartConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -126,6 +129,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -256,15 +260,18 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART class PLBartEncoderLayer(nn.Module): def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PLBartAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = PLBART_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -325,28 +332,35 @@ def forward( return outputs -# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart +PLBART_ATTENTION_CLASSES = {"default": PLBartAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART class PLBartDecoderLayer(nn.Module): def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PLBartAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + self.self_attn = PLBART_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PLBartAttention( + self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -743,8 +757,11 @@ def forward( # expand attention_mask if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + if getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = attention_mask if 0 in attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = 
() if output_attentions else None @@ -947,16 +964,24 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 62d1f3e21f9ad0..0745663bc0fde9 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1101,12 +1101,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[SeamlessM4TConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -1115,6 +1118,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index b98e093f8cc3f3..a5ebb9b2bb4245 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -392,12 +392,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[SEWConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -406,6 +409,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index d3c48e6c91ee47..57c74c8c42e2a6 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ 
b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -178,12 +178,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[Speech2TextConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -192,6 +195,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -322,15 +326,21 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text +SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT class Speech2TextEncoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = Speech2TextAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -391,28 +401,32 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT class Speech2TextDecoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = Speech2TextAttention( + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = Speech2TextAttention( + self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index eebc9e57cfa2d2..9a1bd94dd7f5bd 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -125,12 +125,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[Speech2Text2Config] = None, ): super().__init__() self.embed_dim = embed_dim 
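For readers skimming the diff: the per-model `*_ATTENTION_CLASSES` dictionaries introduced above all follow the same dispatch pattern. The sketch below reproduces it with made-up `Demo*` names (none of these are classes from the library), assuming only that the config carries the private `_flash_attn_2_enabled` flag which `from_pretrained(..., use_flash_attention_2=True)` is expected to set.

from typing import Optional

from torch import nn


class DemoConfig:
    # Illustrative stand-in for the model configs touched in this diff.
    d_model = 64
    encoder_attention_heads = 4
    attention_dropout = 0.0
    _flash_attn_2_enabled = False  # flipped when loading with use_flash_attention_2=True


class DemoAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, is_decoder=False,
                 bias=True, is_causal=False, config: Optional[DemoConfig] = None):
        super().__init__()
        self.embed_dim, self.num_heads, self.dropout = embed_dim, num_heads, dropout
        self.is_decoder, self.is_causal, self.config = is_decoder, is_causal, config
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)


class DemoFlashAttention2(DemoAttention):
    """Same weights as DemoAttention; only the forward pass would differ."""


DEMO_ATTENTION_CLASSES = {"default": DemoAttention, "flash_attention_2": DemoFlashAttention2}


class DemoEncoderLayer(nn.Module):
    def __init__(self, config: DemoConfig):
        super().__init__()
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = DEMO_ATTENTION_CLASSES[attn_type](
            embed_dim=config.d_model,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )


layer = DemoEncoderLayer(DemoConfig())
print(type(layer.self_attn).__name__)  # DemoAttention (DemoFlashAttention2 once the flag is set)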
self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -139,6 +142,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 5f961bc8e1edf9..904c02b4f04308 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -281,12 +281,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[TimeSeriesTransformerConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -295,6 +298,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -425,15 +429,18 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerEncoderLayer(nn.Module): def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = TimeSeriesTransformerAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -494,28 +501,35 @@ def forward( return outputs -# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer +TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerDecoderLayer(nn.Module): def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = TimeSeriesTransformerAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - 
self.encoder_attn = TimeSeriesTransformerAttention( + self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 057a9579c12bf5..11965bdb50e978 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -432,12 +432,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[UniSpeechConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -446,6 +449,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index c2889299574f4d..10a05edc72b071 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -446,12 +446,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[UniSpeechSatConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -460,6 +463,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 9a2235cb2fdd04..3d97e7c73d3522 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -498,12 +498,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[Wav2Vec2Config] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -512,6 +515,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 48f47fe12df726..a107adf74e169a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -19,6 +19,7 @@ import numpy as np import torch +import 
torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -38,6 +39,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, replace_return_docstrings, ) @@ -45,6 +47,11 @@ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "WhisperConfig" @@ -57,6 +64,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: """Returns sinusoids for positional embedding""" if channels % 2 != 0: @@ -299,12 +319,15 @@ def __init__( dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + is_causal: bool = False, + config: Optional[WhisperConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -313,6 +336,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -445,15 +469,227 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper +class WhisperFlashAttention2(WhisperAttention): + """ + Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
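As an aside on the `_get_unpad_data` helper copied into Whisper above: the small worked example below (mask values are illustrative, not taken from the PR) shows the token indices, cumulative sequence lengths and maximum length it hands to the variable-length flash-attention path.

import torch
import torch.nn.functional as F

# Two right-padded sequences of lengths 3 and 2 (illustrative values).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)              # tensor([0, 1, 2, 4, 5]) -> positions of real tokens in the flattened batch
print(cu_seqlens)           # tensor([0, 3, 5], dtype=torch.int32) -> boundaries passed as cu_seqlens_q/k
print(max_seqlen_in_batch)  # 3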
+ """ + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # WhisperFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
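Stepping back from `_upad_input` for a moment: when the mask contains no padding, the wrapper above skips unpadding entirely and reduces to a single `flash_attn_func` call. The standalone sketch below shows that call with illustrative shapes; it assumes flash-attn is installed and a CUDA device with fp16 support, and it is not part of the diff itself.

import torch
from flash_attn import flash_attn_func

bsz, seq_len, num_heads, head_dim = 2, 16, 4, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True mirrors the decoder self-attention layers built with is_causal=True;
# encoder self-attention and cross-attention run with causal=False.
attn_output = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
print(attn_output.shape)  # torch.Size([2, 16, 4, 64])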
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +WHISPER_ATTENTION_CLASSES = { + "default": WhisperAttention, + "flash_attention_2": WhisperFlashAttention2, +} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" + + self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -514,28 +750,32 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER class WhisperDecoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" - self.self_attn = WhisperAttention( + self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WhisperAttention( + self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -638,6 +878,7 @@ class WhisperPreTrainedModel(PreTrainedModel): main_input_name = "input_features" supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std @@ -1070,9 +1311,14 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # embed positions if input_ids is not None: diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 60a2d3b93ea6d6..05d48786148e20 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ 
b/tests/models/whisper/test_modeling_whisper.py @@ -21,13 +21,16 @@ import unittest import numpy as np +from pytest import mark import transformers from transformers import WhisperConfig from transformers.testing_utils import ( is_pt_flax_cross_test, + require_flash_attn, require_torch, require_torch_fp16, + require_torch_gpu, require_torchaudio, slow, torch_device, @@ -795,6 +798,107 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_ use_cache=use_cache, ) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = outputs.decoder_hidden_states[-1] + logits_fa = outputs_fa.decoder_hidden_states[-1] + + # whisper FA2 needs very high tolerance + assert torch.allclose(logits_fa, logits, atol=4e-1) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + dummy_input = dummy_input.to(torch.float16) + + decoder_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=dummy_input.device, dtype=torch.long) + decoder_attention_mask = torch.tensor( + [[0, 0, 0, 1, 1, 1]], device=dummy_input.device, dtype=torch.long + ) + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = outputs.decoder_hidden_states[-1] + logits_fa = outputs_fa.decoder_hidden_states[-1] + + # whisper FA2 needs very high tolerance + assert torch.allclose(logits_fa, logits, atol=4e-1) + + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "output_hidden_states": True, + 
} + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = outputs.decoder_hidden_states[-1] + logits_fa = outputs_fa.decoder_hidden_states[-1] + + # whisper FA2 needs very high tolerance + assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1) + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: return diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 595c72cda6fd2f..f96812c36da817 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2856,7 +2856,7 @@ def test_flash_attn_2_inference(self): if not model_class._supports_flash_attn_2: return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: @@ -2871,25 +2871,76 @@ def test_flash_attn_2_inference(self): ) model.to(torch_device) - dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device) + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 - logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] - logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) - self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits_fa = output_fa.hidden_states[-1] + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits = output.hidden_states[-1] + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not 
model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) # check with inference + dropout model.train() - _ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + _ = model_fa(dummy_input, **other_inputs) @require_flash_attn @require_torch_gpu @@ -2902,7 +2953,7 @@ def test_flash_attn_2_inference_padding_right(self): if not model_class._supports_flash_attn_2: return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: @@ -2917,21 +2968,72 @@ def test_flash_attn_2_inference_padding_right(self): ) model.to(torch_device) - dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device) + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) - logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] - logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 - self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits_fa = output_fa.hidden_states[-1] + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) - output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits = output.hidden_states[-1] + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)) + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = 
model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) @require_flash_attn @require_torch_gpu @@ -2944,7 +3046,7 @@ def test_flash_attn_2_generate_left_padding(self): if not model_class._supports_flash_attn_2: return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True ).to(torch_device) - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 out = model.generate( dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False @@ -2981,7 +3089,7 @@ def test_flash_attn_2_generate_padding_right(self): if not model_class._supports_flash_attn_2: return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True ).to(torch_device) - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 out = model.generate( dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False @@ -3014,26 +3128,39 @@ def test_flash_attn_2_generate_padding_right(self): def test_flash_attn_2_generate_use_cache(self): import torch + max_new_tokens = 30 + for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + model = model_class(config)
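Before the `use_cache` test body continues below, here is the user-level flow these generate tests exercise, sketched with an illustrative checkpoint; it assumes a CUDA GPU, fp16 weights and an installed flash-attn package, and leaves the audio preprocessing and generation calls as comments.

import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

model_id = "openai/whisper-tiny"  # illustrative checkpoint
model = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,  # routes every attention layer through WhisperFlashAttention2
).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

# input_features = processor(waveform, sampling_rate=16_000, return_tensors="pt").input_features
# generated_ids = model.generate(input_features.to("cuda", torch.float16), max_new_tokens=30)
# print(processor.batch_decode(generated_ids, skip_special_tokens=True))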
with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + tmpdirname, + torch_dtype=torch.float16, + use_flash_attention_2=True, + low_cpu_mem_usage=True, ).to(torch_device) # Just test that a large cache works as expected _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False ) @require_flash_attn @@ -3048,14 +3175,18 @@ def test_flash_attn_2_fp32_ln(self): if not model_class._supports_flash_attn_2: return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + dummy_input = inputs_dict[model.main_input_name] + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + if model.config.is_encoder_decoder: + dummy_decoder_input_ids = inputs_dict["decoder_input_ids"] + dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"] model = model_class.from_pretrained( tmpdirname, @@ -3070,10 +3201,19 @@ def test_flash_attn_2_fp32_ln(self): if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): param.data = param.data.to(torch.float32) - _ = model(input_ids=dummy_input) - - # with attention mask - _ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + if model.config.is_encoder_decoder: + _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids) + # with attention mask + _ = model( + dummy_input, + attention_mask=dummy_attention_mask, + decoder_input_ids=dummy_decoder_input_ids, + decoder_attention_mask=dummy_decoder_attention_mask, + ) + else: + _ = model(dummy_input) + # with attention mask + _ = model(dummy_input, attention_mask=dummy_attention_mask) global_rng = random.Random()
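Finally, the `test_flash_attn_2_fp32_ln` changes above cover a PEFT-style setup where layer norms are upcast to fp32 while the attention weights stay in fp16. The minimal sketch below (module sizes are arbitrary) shows why the flash-attention forward has to cast hidden states back to the projection dtype before calling the kernels, which only accept fp16/bf16.

import torch
from torch import nn

embed_dim = 64
q_proj = nn.Linear(embed_dim, embed_dim).half()  # attention weights kept in fp16
layer_norm = nn.LayerNorm(embed_dim).float()     # upcast to fp32 for training stability

hidden_states = torch.randn(2, 8, embed_dim, dtype=torch.float16)
hidden_states = layer_norm(hidden_states.float())  # activations silently promoted to fp32

# Same fallback the new forward uses when no _pre_quantization_dtype is recorded:
target_dtype = q_proj.weight.dtype
query_states = q_proj(hidden_states.to(target_dtype))
print(query_states.dtype)  # torch.float16 -> safe to hand to flash-attn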