-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add flash attention for gpt_bigcode
#26479
Changes from 1 commit
447703e
7baa248
bddd8e6
7f38f86
50506fa
28ddca3
542c275
72b353b
f43ec5a
b2aa0d9
4792b33
22a64cb
ba0de16
d577b4f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from typing import List, Optional, Tuple, Union | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import torch.utils.checkpoint | ||
from torch import nn | ||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||
|
@@ -32,11 +33,17 @@ | |
add_code_sample_docstrings, | ||
add_start_docstrings, | ||
add_start_docstrings_to_model_forward, | ||
is_flash_attn_available, | ||
logging, | ||
) | ||
from .configuration_gpt_bigcode import GPTBigCodeConfig | ||
|
||
|
||
if is_flash_attn_available(): | ||
from flash_attn import flash_attn_func, flash_attn_varlen_func | ||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
||
_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder" | ||
|
@@ -78,6 +85,19 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor | |
return x | ||
|
||
|
||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data | ||
def _get_unpad_data(padding_mask): | ||
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) | ||
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() | ||
max_seqlen_in_batch = seqlens_in_batch.max().item() | ||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) | ||
return ( | ||
indices, | ||
cu_seqlens, | ||
max_seqlen_in_batch, | ||
) | ||
|
||
|
||
class GPTBigCodeAttention(nn.Module): | ||
def __init__(self, config, is_cross_attention=False, layer_idx=None): | ||
super().__init__() | ||
|
@@ -211,6 +231,8 @@ def forward( | |
encoder_hidden_states: Optional[torch.Tensor] = None, | ||
encoder_attention_mask: Optional[torch.Tensor] = None, | ||
use_cache: Optional[bool] = False, | ||
padding_mask: Optional[torch.LongTensor] = None, | ||
encoder_padding_mask: Optional[torch.LongTensor] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should wait again for the padding mask refactor and not pass padding mask! #26792 |
||
output_attentions: Optional[bool] = False, | ||
) -> Union[ | ||
Tuple[torch.Tensor, Optional[torch.Tensor]], | ||
|
@@ -262,6 +284,206 @@ def forward( | |
return outputs # a, present, (attentions) | ||
|
||
|
||
class GPTBigCodeFlashAttention2(GPTBigCodeAttention): | ||
""" | ||
GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module | ||
stays untouched. The only required change would be on the forward pass where it needs to correctly call the public | ||
API of flash attention and deal with padding tokens in case the input contains any of them. | ||
""" | ||
|
||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
layer_past: Optional[torch.Tensor] = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
head_mask: Optional[torch.Tensor] = None, | ||
encoder_hidden_states: Optional[torch.Tensor] = None, | ||
encoder_attention_mask: Optional[torch.Tensor] = None, | ||
use_cache: Optional[bool] = False, | ||
padding_mask: Optional[torch.LongTensor] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be no more There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hey I am currently working to fix this. |
||
encoder_padding_mask: Optional[torch.LongTensor] = None, | ||
output_attentions: Optional[bool] = False, | ||
) -> Union[ | ||
Tuple[torch.Tensor, Optional[torch.Tensor]], | ||
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], | ||
]: | ||
if encoder_hidden_states is not None: | ||
if not hasattr(self, "q_attn") or not self.is_cross_attention: | ||
raise ValueError( | ||
"If class is used as cross attention, the weights `q_attn` have to be defined. " | ||
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." | ||
) | ||
|
||
query = self.q_attn(hidden_states) | ||
key_value = self.c_attn(encoder_hidden_states) | ||
padding_mask = encoder_padding_mask | ||
elif self.multi_query: | ||
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) | ||
else: | ||
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), | ||
# i.e., the memory layout is not the same as GPT2. | ||
# This makes the concatenation with past_key_value more efficient. | ||
query, key_value = ( | ||
self.c_attn(hidden_states) | ||
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) | ||
.transpose(1, 2) | ||
.split((self.head_dim, 2 * self.head_dim), dim=3) | ||
) | ||
|
||
if layer_past is not None: | ||
key_value = torch.cat((layer_past, key_value), dim=-2) | ||
present = key_value if use_cache else None | ||
|
||
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) | ||
|
||
# Flash attention requires the input to have the shape | ||
# batch_size x seq_length x head_dim x hidden_dim | ||
if self.multi_query: | ||
batch_size, query_length, _ = query.shape | ||
query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) | ||
key = key.unsqueeze(2) | ||
value = value.unsqueeze(2) | ||
else: | ||
query_length = query.shape[2] | ||
batch_size, _, tgt, _ = key.shape | ||
query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) | ||
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) | ||
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) | ||
|
||
attn_dropout = self.dropout if self.training else 0.0 | ||
susnato marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype | ||
upcast = query.dtype != softmax_dtype | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 | ||
softmax_scale = softmax_scale**-1 | ||
if self.scale_attn_weights: | ||
softmax_scale /= self.head_dim**0.5 | ||
|
||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons | ||
# therefore the input hidden states gets silently casted in float32. Hence, we need | ||
# cast them back in float16 just to be sure everything works as expected. | ||
input_dtype = query.dtype | ||
if input_dtype == torch.float32: | ||
logger.warning_once( | ||
"The input hidden states seems to be silently casted in float32, this might be related to" | ||
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" | ||
" float16." | ||
) | ||
query = query.to(torch.float16) | ||
key = key.to(torch.float16) | ||
value = value.to(torch.float16) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be fixed in the global fix I want to apply in #26451 as a follow up PR that I will take care There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should I then remove this block? or are we keeping this block for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say we can keep it for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's wait until your PR is merge @younesbelkada 😉 |
||
|
||
attn_output = self._flash_attention_forward( | ||
query, key, value, padding_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale | ||
) | ||
|
||
attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) | ||
attn_output = self.c_proj(attn_weights_reshaped) | ||
attn_output = self.resid_dropout(attn_output) | ||
|
||
outputs = (attn_output, present) | ||
|
||
if output_attentions: | ||
if self.multi_query: | ||
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) | ||
attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) | ||
else: | ||
attn_weights_reshaped = None | ||
|
||
outputs += (attn_weights_reshaped,) | ||
|
||
return outputs # a, present, (attentions) | ||
|
||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward | ||
def _flash_attention_forward( | ||
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None | ||
): | ||
""" | ||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token | ||
first unpad the input, then computes the attention scores and pad the final attention scores. | ||
|
||
Args: | ||
query_states (`torch.Tensor`): | ||
Input query states to be passed to Flash Attention API | ||
key_states (`torch.Tensor`): | ||
Input key states to be passed to Flash Attention API | ||
value_states (`torch.Tensor`): | ||
Input value states to be passed to Flash Attention API | ||
padding_mask (`torch.Tensor`): | ||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the | ||
position of padding tokens and 1 for the position of non-padding tokens. | ||
dropout (`int`, *optional*): | ||
Attention dropout | ||
softmax_scale (`float`, *optional*): | ||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) | ||
""" | ||
# Contains at least one padding token in the sequence | ||
if padding_mask is not None: | ||
batch_size = query_states.shape[0] | ||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( | ||
query_states, key_states, value_states, padding_mask, query_length | ||
) | ||
|
||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens | ||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens | ||
|
||
attn_output_unpad = flash_attn_varlen_func( | ||
query_states, | ||
key_states, | ||
value_states, | ||
cu_seqlens_q=cu_seqlens_q, | ||
cu_seqlens_k=cu_seqlens_k, | ||
max_seqlen_q=max_seqlen_in_batch_q, | ||
max_seqlen_k=max_seqlen_in_batch_k, | ||
dropout_p=dropout, | ||
softmax_scale=softmax_scale, | ||
causal=True, | ||
) | ||
|
||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) | ||
else: | ||
attn_output = flash_attn_func( | ||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True | ||
) | ||
|
||
return attn_output | ||
|
||
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): | ||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) | ||
batch_size, kv_seq_len, kv_num_heads, head_dim = key_layer.shape | ||
query_num_heads = query_layer.shape[2] | ||
|
||
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, kv_num_heads, head_dim), indices_k) | ||
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, kv_num_heads, head_dim), indices_k) | ||
if query_length == kv_seq_len: | ||
query_layer = index_first_axis( | ||
query_layer.reshape(batch_size * kv_seq_len, query_num_heads, head_dim), indices_k | ||
) | ||
cu_seqlens_q = cu_seqlens_k | ||
max_seqlen_in_batch_q = max_seqlen_in_batch_k | ||
indices_q = indices_k | ||
elif query_length == 1: | ||
max_seqlen_in_batch_q = 1 | ||
cu_seqlens_q = torch.arange( | ||
batch_size + 1, dtype=torch.int32, device=query_layer.device | ||
) # There is a memcpy here, that is very bad. | ||
indices_q = cu_seqlens_q[:-1] | ||
query_layer = query_layer.squeeze(1) | ||
else: | ||
# The -q_len: slice assumes left padding. | ||
padding_mask = padding_mask[:, -query_length:] | ||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) | ||
|
||
return ( | ||
query_layer, | ||
key_layer, | ||
value_layer, | ||
indices_q, | ||
(cu_seqlens_q, cu_seqlens_k), | ||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k), | ||
) | ||
|
||
|
||
class GPTBigCodeMLP(nn.Module): | ||
def __init__(self, intermediate_size, config): | ||
super().__init__() | ||
|
@@ -287,13 +509,21 @@ def __init__(self, config, layer_idx=None): | |
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size | ||
|
||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) | ||
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) | ||
self.attn = ( | ||
GPTBigCodeAttention(config, layer_idx=layer_idx) | ||
if not getattr(config, "_flash_attn_2_enabled", False) | ||
else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) | ||
) | ||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) | ||
|
||
if config.add_cross_attention: | ||
if config.multi_query: | ||
raise NotImplementedError("Cross-attention not implemented for MQA") | ||
self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) | ||
self.crossattention = ( | ||
GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) | ||
if not getattr(config, "_flash_attn_2_enabled", False) | ||
else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) | ||
) | ||
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) | ||
|
||
self.mlp = GPTBigCodeMLP(self.inner_dim, config) | ||
|
@@ -307,6 +537,8 @@ def forward( | |
encoder_hidden_states: Optional[torch.Tensor] = None, | ||
encoder_attention_mask: Optional[torch.Tensor] = None, | ||
use_cache: Optional[bool] = False, | ||
padding_mask: Optional[torch.LongTensor] = None, | ||
encoder_padding_mask: Optional[torch.LongTensor] = None, | ||
output_attentions: Optional[bool] = False, | ||
) -> Union[ | ||
Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] | ||
|
@@ -320,6 +552,8 @@ def forward( | |
head_mask=head_mask, | ||
use_cache=use_cache, | ||
output_attentions=output_attentions, | ||
padding_mask=padding_mask, | ||
encoder_padding_mask=encoder_padding_mask, | ||
) | ||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions) | ||
outputs = attn_outputs[1:] | ||
|
@@ -342,6 +576,8 @@ def forward( | |
encoder_hidden_states=encoder_hidden_states, | ||
encoder_attention_mask=encoder_attention_mask, | ||
output_attentions=output_attentions, | ||
padding_mask=padding_mask, | ||
encoder_padding_mask=encoder_padding_mask, | ||
) | ||
attn_output = cross_attn_outputs[0] | ||
# residual connection | ||
|
@@ -373,6 +609,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): | |
supports_gradient_checkpointing = True | ||
_no_split_modules = ["GPTBigCodeBlock"] | ||
_skip_keys_device_placement = "past_key_values" | ||
_supports_flash_attn_2 = True | ||
|
||
def __init__(self, *inputs, **kwargs): | ||
super().__init__(*inputs, **kwargs) | ||
|
@@ -586,6 +823,13 @@ def forward( | |
else: | ||
past_length = past_key_values[0].size(-2) | ||
|
||
padding_mask = None | ||
if attention_mask is not None and 0 in attention_mask: | ||
padding_mask = attention_mask | ||
encoder_padding_mask = None | ||
if encoder_attention_mask is not None and 0 in encoder_attention_mask: | ||
encoder_padding_mask = encoder_attention_mask | ||
|
||
if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: | ||
# create position_ids on the fly for batch generation | ||
position_ids = attention_mask.long().cumsum(-1) - 1 | ||
|
@@ -656,7 +900,7 @@ def forward( | |
def create_custom_forward(module): | ||
def custom_forward(*inputs): | ||
# None for past_key_value | ||
return module(*inputs, use_cache, output_attentions) | ||
return module(*inputs, use_cache, output_attentions, padding_mask, encoder_padding_mask) | ||
|
||
return custom_forward | ||
|
||
|
@@ -679,6 +923,8 @@ def custom_forward(*inputs): | |
encoder_attention_mask=encoder_attention_mask, | ||
use_cache=use_cache, | ||
output_attentions=output_attentions, | ||
padding_mask=padding_mask, | ||
encoder_padding_mask=encoder_padding_mask, | ||
) | ||
|
||
hidden_states = outputs[0] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done