[Whisper, Bart, MBart] Add Flash Attention 2 (#27203)
* add whisper fa2

* correct

* change all

* correct

* correct

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* fix more

* fix more

* fix more

* fix more

* fix more

---------

Co-authored-by: amyeroberts <[email protected]>
patrickvonplaten and amyeroberts authored Nov 1, 2023
1 parent 3520e37 commit af3de8d
Showing 28 changed files with 1,300 additions and 123 deletions.
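
For context, a minimal usage sketch of the feature this commit adds. This is hedged: the `use_flash_attention_2` loading flag is assumed from the transformers release this commit targets; it sets the internal `config._flash_attn_2_enabled` switch that the diffs below read.

# Hypothetical usage sketch (not part of the commit): load Whisper with
# Flash Attention 2 enabled. Requires a CUDA device, the flash-attn 2.x
# package, and fp16/bf16 weights.
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    torch_dtype=torch.float16,   # FA2 kernels only support fp16/bf16
    use_flash_attention_2=True,  # assumed flag; routes layers to the new FA2 attention classes
).to("cuda")
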
275 changes: 263 additions & 12 deletions src/transformers/models/bart/modeling_bart.py

Large diffs are not rendered by default.

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1174,7 +1174,7 @@ def forward(
return outputs


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BigBirdPegasusDecoder
# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder
class BigBirdPegasusDecoderAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1185,12 +1185,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BigBirdPegasusConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -1199,6 +1202,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
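
The hunks above (and their copies in the files below) thread two new constructor kwargs, `is_causal` and `config`, through the attention classes. A hedged sketch of why `is_causal` matters: the flash-attn kernel expresses causality as a boolean argument rather than an additive attention-mask tensor. The `flash_attn_func` call below is assumed from the flash-attn 2.x package, and the `fa2_attention` helper is hypothetical; tensor layout is (batch, seq_len, heads, head_dim).

# Hedged sketch (not the commit's code): how a FlashAttention2 forward pass
# might consume the new `is_causal` flag. Assumes flash-attn 2.x.
import torch
from flash_attn import flash_attn_func

def fa2_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  dropout_p: float, is_causal: bool) -> torch.Tensor:
    # The kernel applies the causal mask internally when causal=True, which is
    # why each attention module now stores `self.is_causal` at construction
    # time instead of deriving causality from the attention_mask alone.
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=is_causal)
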
4 changes: 4 additions & 0 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -90,12 +90,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BioGptConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -104,6 +107,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
24 changes: 19 additions & 5 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -104,12 +104,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BlenderbotConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -118,6 +121,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -248,15 +252,21 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot
BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention}


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BlenderbotAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -317,28 +327,32 @@ def forward(
return outputs


# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BlenderbotAttention(
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BlenderbotAttention(
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
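
The Blenderbot hunks above illustrate the dispatch pattern this commit rolls out: a module-level `*_ATTENTION_CLASSES` mapping, indexed by an `attn_type` key derived from `config._flash_attn_2_enabled`, picks the attention implementation once at construction time. In the diff shown, the Blenderbot mapping only registers "default"; in the models named in the commit title (whose large diffs are not rendered above), a "flash_attention_2" entry is presumably registered as well. A stripped-down sketch of the pattern, with placeholder class and helper names rather than the ones in the commit:

# Illustrative sketch of the attention-class dispatch used above; names are
# placeholders for BlenderbotAttention and its FlashAttention2 counterpart.
from typing import Optional

import torch.nn as nn


class EagerAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, is_decoder=False,
                 bias=True, is_causal=False, config: Optional[object] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.is_decoder = is_decoder
        self.is_causal = is_causal  # new kwarg threaded through by this commit
        self.config = config        # new: kept so subclasses can read the model config
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)


class FlashAttention2(EagerAttention):
    """Same weights and constructor; would override forward() with FA2 kernels."""


ATTENTION_CLASSES = {"default": EagerAttention, "flash_attention_2": FlashAttention2}


def build_self_attn(config):
    # Same selection logic as the diff: one dict lookup per layer, at init time.
    attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
    return ATTENTION_CLASSES[attn_type](
        embed_dim=config.d_model,
        num_heads=config.decoder_attention_heads,
        dropout=config.attention_dropout,
        is_decoder=True,
        is_causal=True,
        config=config,
    )
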
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -101,12 +101,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BlenderbotSmallConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -115,6 +118,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -245,15 +249,18 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BlenderbotSmallAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -314,28 +321,35 @@ def forward(
return outputs


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall
BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model

self.self_attn = BlenderbotSmallAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BlenderbotSmallAttention(
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
4 changes: 4 additions & 0 deletions src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -330,12 +330,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[Data2VecAudioConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -344,6 +347,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -370,12 +370,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[GPTSanJapaneseConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -384,6 +387,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
4 changes: 4 additions & 0 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -396,12 +396,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[HubertConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -410,6 +413,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
4 changes: 4 additions & 0 deletions src/transformers/models/informer/modeling_informer.py
@@ -287,12 +287,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[InformerConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -301,6 +304,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
(Diffs for the remaining changed files are not shown.)
