diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index cb8b45ae7d01..8a7e9cd0cf5d 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -19,7 +19,7 @@
     WhisperModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -369,18 +369,12 @@ def whisper_encoder_forward(
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         None,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -528,6 +522,20 @@ def whisper_decoder_forward(
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and head_mask is None and not output_attentions:
+            # output_attentions=True & head_mask can not be supported when using SDPA.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
 
         # embed positions
         if input_ids is not None:
@@ -535,10 +543,6 @@ def whisper_decoder_forward(
         else:
             positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, input_shape, inputs_embeds, past_key_values_length
-        )
-
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
@@ -575,16 +579,8 @@ def whisper_decoder_forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -592,6 +588,8 @@ def custom_forward(*inputs):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,  # past_key_value
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index b5b5db79d9de..df6194a6d2d0 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -28,12 +28,6 @@ class WhisperPolicy(Policy):
 
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Whisper model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
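For context, a minimal, self-contained sketch of the gradient-checkpointing pattern this diff adopts (this is not the ColossalAI or transformers code; `ToyLayer` is a hypothetical stand-in for a Whisper encoder/decoder layer): instead of wrapping the layer in a `create_custom_forward` closure that captures flags like `output_attentions` and `use_cache`, the layer's `__call__` is checkpointed directly and every argument is passed positionally, which is what the bound `self._gradient_checkpointing_func(...)` call resolves to.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a Whisper encoder/decoder layer."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        out = torch.relu(self.proj(hidden_states))
        # transformers layers return a tuple: (hidden_states, attn_weights?, ...)
        return (out, None) if output_attentions else (out,)


layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)


# Old pattern (removed by the diff): a closure bakes the extra flags into the
# function handed to torch.utils.checkpoint.checkpoint.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)

    return custom_forward


old_out = checkpoint(create_custom_forward(layer), hidden_states, None, use_reentrant=False)

# New pattern (adopted by the diff): checkpoint the layer's __call__ directly and
# pass all arguments positionally, mirroring
# `self._gradient_checkpointing_func(decoder_layer.__call__, hidden_states, ...)`.
new_out = checkpoint(layer.__call__, hidden_states, None, False, use_reentrant=False)

assert torch.allclose(old_out[0], new_out[0])
```

The same idea applies to the decoder hunk, where `output_attentions` and `use_cache` now appear as explicit trailing positional arguments rather than captured closure variables.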