From 4c0d735c03ebf2e15adfa7a1de688843283b0691 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Tue, 31 Oct 2023 16:51:02 +0530 Subject: [PATCH] Add flash attention for `gpt_bigcode` (#26479) * added flash attention of gpt_bigcode * changed docs * Update src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py * add FA-2 docs * oops * Update docs/source/en/perf_infer_gpu_one.md Last Nit Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix * oops * remove padding_mask * change getattr->hasattr logic * changed .md file --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: younesbelkada Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/model_doc/gpt_bigcode.md | 39 +++ docs/source/en/perf_infer_gpu_one.md | 1 + .../gpt_bigcode/modeling_gpt_bigcode.py | 311 ++++++++++++++++-- 3 files changed, 328 insertions(+), 23 deletions(-) diff --git a/docs/source/en/model_doc/gpt_bigcode.md b/docs/source/en/model_doc/gpt_bigcode.md index 6965d5837d8e74..8cc77a825de75c 100644 --- a/docs/source/en/model_doc/gpt_bigcode.md +++ b/docs/source/en/model_doc/gpt_bigcode.md @@ -42,6 +42,45 @@ The main differences compared to GPT2. You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575) +## Combining Starcoder and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer +>>> device = "cuda" # the device to load the model onto + +>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder") + +>>> prompt = "def hello_world():" + +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device) +>>> model.to(device) + +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False) +>>> tokenizer.batch_decode(generated_ids)[0] +'def hello_world():\n print("hello world")\n\nif __name__ == "__main__":\n print("hello world")\n<|endoftext|>' +``` + +### Expected speedups + +Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `bigcode/starcoder` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths. + +
+ +
+ + ## GPTBigCodeConfig [[autodoc]] GPTBigCodeConfig diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index d24299012e9fe1..39f2ca22b1f040 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models: - Llama - Mistral - Falcon +- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode#) You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.* diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index f8e52b6510a0bd..fcbbfca5cedac7 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -32,11 +33,17 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder" @@ -78,11 +85,25 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() - self.mask_value = None + self.config = config + self.mask_value = None self.multi_query = config.multi_query self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -90,6 +111,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.kv_heads = 1 if self.multi_query else self.num_heads self.kv_dim = self.kv_heads * self.head_dim self.split_size = self.embed_dim + self.is_causal = True + if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" @@ -212,10 +235,16 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + **kwargs, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + if "padding_mask" in kwargs: + logger.warning_once( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -262,6 +291,223 @@ def forward( return outputs # a, present, (attentions) +class GPTBigCodeFlashAttention2(GPTBigCodeAttention): + """ + GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if "padding_mask" in kwargs: + logger.warning_once( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + if self.multi_query: + batch_size, query_length, _ = query.shape + query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.unsqueeze(2) + value = value.unsqueeze(2) + else: + query_length = query.shape[2] + batch_size, _, tgt, _ = key.shape + query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype + upcast = query.dtype != softmax_dtype + softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + softmax_scale = softmax_scale**-1 + if self.scale_attn_weights: + softmax_scale /= self.head_dim**0.5 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_attn.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale + ) + + attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) + else: + attn_weights_reshaped = None + + outputs += (attn_weights_reshaped,) + + return outputs # a, present, (attentions) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() @@ -287,13 +533,21 @@ def __init__(self, config, layer_idx=None): self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) + self.attn = ( + GPTBigCodeAttention(config, layer_idx=layer_idx) + if not getattr(config, "_flash_attn_2_enabled", False) + else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = ( + GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) + if not getattr(config, "_flash_attn_2_enabled", False) + else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) + ) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -373,6 +627,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -594,28 +849,38 @@ def forward( key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None + encoder_attention_mask = ( + encoder_attention_mask.bool() + if (encoder_attention_mask is not None and 0 in encoder_attention_mask) + else None ) - - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if ( - self.config.add_cross_attention - and encoder_hidden_states is not None - and encoder_attention_mask is not None - ): - if encoder_attention_mask.dim() == 2: - encoder_attention_mask.unsqueeze(1) - assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) else: - encoder_attention_mask = None + # 4d mask is passed through the layers + if attention_mask is not None: + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + dtype=torch.bool, device=self_attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if ( + self.config.add_cross_attention + and encoder_hidden_states is not None + and encoder_attention_mask is not None + ): + if encoder_attention_mask.dim() == 2: + encoder_attention_mask.unsqueeze(1) + assert encoder_attention_mask.dim() == 3 + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + else: + encoder_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head