diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 346759aa2b2517..7d880cd6558e3d 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -79,6 +79,7 @@ FlashAttention-2 is currently supported for the following architectures: * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) +* [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) @@ -251,6 +252,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) * [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) +* [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index bdd532fa25e82a..6fc59034617077 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -23,7 +23,12 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -34,12 +39,17 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_speech_to_text import Speech2TextConfig +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Speech2TextConfig" @@ -144,7 +154,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): ) # expand embeddings if needed - max_pos = self.padding_idx + 1 + seq_len + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length if max_pos > self.weights.size(0): self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) @@ -326,7 +336,246 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention} +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Speech2Text +class Speech2TextFlashAttention2(Speech2TextAttention): + """ + Speech2Text flash attention module. This module inherits from `Speech2TextAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Speech2TextFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("Speech2TextFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Speech2Text +class Speech2TextSdpaAttention(Speech2TextAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Speech2TextModel is using Speech2TextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +SPEECH_TO_TEXT_ATTENTION_CLASSES = { + "eager": Speech2TextAttention, + "sdpa": Speech2TextSdpaAttention, + "flash_attention_2": Speech2TextFlashAttention2, +} # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -526,6 +775,8 @@ class Speech2TextPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std @@ -698,7 +949,8 @@ def __init__(self, config: Speech2TextConfig): ) self.layers = nn.ModuleList([Speech2TextEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) - + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -766,8 +1018,16 @@ def forward( # expand attention_mask if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + if self._use_flash_attention_2: + attention_mask = attention_mask if 0 in attention_mask else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -847,9 +1107,9 @@ def __init__(self, config: Speech2TextConfig): config.d_model, self.padding_idx, ) - self.layers = nn.ModuleList([Speech2TextDecoderLayer(config) for _ in range(config.decoder_layers)]) - + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -965,17 +1225,42 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 790e6a74a47135..8e84dcd96db963 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -328,7 +328,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): ) # expand embeddings if needed - max_pos = self.padding_idx + 1 + seq_len + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length if max_pos > self.weights.size(0): self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index cef2a6781775a9..2ed30ce4131318 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -20,13 +20,17 @@ import tempfile import unittest +import pytest + from transformers import Speech2TextConfig from transformers.testing_utils import ( is_torch_available, + require_flash_attn, require_sentencepiece, require_tokenizers, require_torch, require_torch_fp16, + require_torch_gpu, require_torchaudio, slow, torch_device, @@ -106,6 +110,7 @@ def __init__( eos_token_id=2, pad_token_id=1, bos_token_id=0, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -131,6 +136,7 @@ def __init__( self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id + self.attn_implementation = attn_implementation def prepare_config_and_inputs(self): input_features = floats_tensor( @@ -171,6 +177,7 @@ def get_config(self): eos_token_id=self.eos_token_id, bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, + attn_implementation=self.attn_implementation, ) def prepare_config_and_inputs_for_common(self): @@ -750,6 +757,59 @@ def _create_and_check_torchscript(self, config, inputs_dict): self.assertTrue(models_equal) + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_reuse_cache(self): + max_new_tokens = 2 + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name][..., :24] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # run generate once to get filled cache + output = model.generate( + dummy_input, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + ) + past_key_values = output.past_key_values + + # Try to continue generation from where we left, given that we have more than 1 new token to process + # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model + _ = model.generate( + dummy_input, + decoder_input_ids=output.sequences, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + past_key_values=past_key_values, + ) + def test_pt_tf_model_equivalence(self, allow_missing_keys=True): # Allow missing keys since TF doesn't cache the sinusoidal embeddings in an attribute super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys) diff --git a/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py index 26fee7d93c3e39..ecf773bf927883 100644 --- a/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py @@ -94,6 +94,7 @@ def __init__( pad_token_id=1, bos_token_id=0, scale_embedding=False, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -120,6 +121,7 @@ def __init__( self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.scale_embedding = scale_embedding + self.attn_implementation = attn_implementation def prepare_config_and_inputs(self): input_features = floats_tensor( @@ -161,6 +163,7 @@ def get_config(self): bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, scale_embedding=self.scale_embedding, + attn_implementation=self.attn_implementation, ) def prepare_config_and_inputs_for_common(self):