Commit

[shardformer] update whisper model (hpcaitech#5529)
wangbluo committed Apr 11, 2024
1 parent ad7d81d commit 9fd1a3e
Showing 2 changed files with 22 additions and 27 deletions.
42 changes: 22 additions & 20 deletions colossalai/shardformer/modeling/whisper.py
@@ -21,7 +21,7 @@
     shift_tokens_right,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
@@ -539,18 +539,12 @@ def whisper_encoder_forward(
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         None,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -701,6 +695,20 @@ def whisper_decoder_forward(

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and head_mask is None and not output_attentions:
+            # output_attentions=True & head_mask can not be supported when using SDPA.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
         # embed positions
         if input_ids is not None:
@@ -756,23 +764,17 @@ def whisper_decoder_forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     None,  # encoder attention mask
                     head_mask[idx] if head_mask is not None else None,
                     (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                     None,  # past_key_value
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
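Note on the two checkpointing hunks above: transformers >= 4.36 attaches a _gradient_checkpointing_func helper to every PreTrainedModel, so the old pattern of building a create_custom_forward closure and handing it to torch.utils.checkpoint.checkpoint is replaced by passing the layer's __call__ together with all of its arguments, including flags such as output_attentions and use_cache. The sketch below is not part of the commit; it only illustrates the change in calling convention, assuming _gradient_checkpointing_func behaves like a functools.partial of torch.utils.checkpoint.checkpoint, and TinyLayer is a made-up stand-in for a Whisper encoder/decoder layer.

# Minimal sketch (assumption-laden, not from the commit): old vs. new call style.
from functools import partial

import torch
import torch.utils.checkpoint
import torch.nn as nn


class TinyLayer(nn.Module):
    """Hypothetical stand-in for a Whisper encoder/decoder layer."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        out = torch.relu(self.proj(hidden_states))
        return (out, None) if output_attentions else (out,)


layer = TinyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)

# Old style (removed above): wrap the layer so checkpoint() only receives the closure.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=True)

    return custom_forward

old_out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), hidden_states, None, use_reentrant=False
)

# New style (added above): pass the layer's __call__ and every argument directly,
# mirroring self._gradient_checkpointing_func(encoder_layer.__call__, ...).
gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
new_out = gradient_checkpointing_func(layer.__call__, hidden_states, None, True)

assert torch.allclose(old_out[0], new_out[0])

The decoder hunk that rebuilds attention_mask relies on the mask utilities imported at the top of the file: a 2-D padding mask is kept as-is for flash-attention-2, while the SDPA and eager paths expand it to the 4-D causal mask the attention layers expect. A quick shape check, again only a sketch that assumes transformers >= 4.36 is installed:

# Expands a (batch, seq_len) padding mask into a (batch, 1, seq_len, seq_len) causal mask.
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

attention_mask = torch.ones(2, 5, dtype=torch.long)  # (batch, seq_len)
inputs_embeds = torch.randn(2, 5, 16)
mask_4d = _prepare_4d_causal_attention_mask(attention_mask, (2, 5), inputs_embeds, 0)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])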
7 changes: 0 additions & 7 deletions colossalai/shardformer/policies/whisper.py
@@ -29,13 +29,6 @@
 class WhisperPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        # TODO: remove this version check when transformers>=4.36.0
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Whisper model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
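The policy change removes the hard ceiling on the transformers version, which the deleted TODO already anticipated: with the modeling code above rewritten against the newer attention-mask and gradient-checkpointing APIs, pinning transformers to <= 4.33.0 no longer makes sense. If a guard were still wanted, one option (my assumption, not something this commit adds) would be the same assertion inverted into a minimum-version check:

# Hypothetical guard, not part of this commit: require the newer transformers API.
import transformers
from packaging.version import Version

assert Version(transformers.__version__) >= Version(
    "4.36.0"
), "The shardformer Whisper policy expects transformers >= 4.36.0."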
