From 447703e7f6497472249c1c7992ed13667dfb12de Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Fri, 29 Sep 2023 02:04:28 +0530 Subject: [PATCH 01/11] added flash attention of gpt_bigcode --- docs/source/en/perf_infer_gpu_one.md | 1 + .../gpt_bigcode/modeling_gpt_bigcode.py | 252 +++++++++++++++++- 2 files changed, 250 insertions(+), 3 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index f0c0bf0b107154..69593bc87f8454 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models: - Llama - Falcon +- GPTBigCode You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.* diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 1c34f28a5c887c..6858d815e9daf6 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -32,11 +33,17 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_available, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder" @@ -78,6 +85,19 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -211,6 +231,8 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + encoder_padding_mask: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], @@ -262,6 +284,206 @@ def forward( return outputs # a, present, (attentions) +class GPTBigCodeFlashAttention2(GPTBigCodeAttention): + """ + GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module + stays untouched. 
The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + encoder_padding_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + padding_mask = encoder_padding_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + if self.multi_query: + batch_size, query_length, _ = query.shape + query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.unsqueeze(2) + value = value.unsqueeze(2) + else: + query_length = query.shape[2] + batch_size, _, tgt, _ = key.shape + query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype + upcast = query.dtype != softmax_dtype + softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + softmax_scale = softmax_scale**-1 + if self.scale_attn_weights: + softmax_scale /= self.head_dim**0.5 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + query = query.to(torch.float16) + key = key.to(torch.float16) + value = value.to(torch.float16) + + attn_output = self._flash_attention_forward( + query, key, value, padding_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale + ) + + attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) + else: + attn_weights_reshaped = None + + outputs += (attn_weights_reshaped,) + + return outputs # a, present, (attentions) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, kv_num_heads, head_dim = key_layer.shape + query_num_heads = query_layer.shape[2] + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, kv_num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, kv_num_heads, head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, query_num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() @@ -287,13 +509,21 @@ def __init__(self, config, layer_idx=None): self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) + self.attn = ( + GPTBigCodeAttention(config, layer_idx=layer_idx) + if not getattr(config, "_flash_attn_2_enabled", False) + else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx) + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = ( + GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) + if not getattr(config, "_flash_attn_2_enabled", False) + else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx) + ) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -307,6 +537,8 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + encoder_padding_mask: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] @@ -320,6 +552,8 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + padding_mask=padding_mask, + encoder_padding_mask=encoder_padding_mask, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -342,6 +576,8 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, + padding_mask=padding_mask, + encoder_padding_mask=encoder_padding_mask, ) attn_output = cross_attn_outputs[0] # residual connection @@ -373,6 +609,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -586,6 +823,13 @@ def forward( else: past_length = past_key_values[0].size(-2) + padding_mask = None + if attention_mask is not None and 0 in attention_mask: + padding_mask = attention_mask + encoder_padding_mask = None + if encoder_attention_mask is not None and 0 in encoder_attention_mask: + encoder_padding_mask = encoder_attention_mask + if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -656,7 +900,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return 
module(*inputs, use_cache, output_attentions) + return module(*inputs, use_cache, output_attentions, padding_mask, encoder_padding_mask) return custom_forward @@ -679,6 +923,8 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + padding_mask=padding_mask, + encoder_padding_mask=encoder_padding_mask, ) hidden_states = outputs[0] From 7baa248afc2abed72e13bf2beb2a7ef114c9ca5d Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Fri, 29 Sep 2023 13:11:00 +0530 Subject: [PATCH 02/11] changed docs --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 69593bc87f8454..ec1127ff4305f6 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -33,7 +33,7 @@ We natively support Flash Attention 2 for the following models: - Llama - Falcon -- GPTBigCode +- GPTBigCode (Starcoder) You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.* From bddd8e6805796438f2cb074eeb6ab45c90ffc865 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 3 Oct 2023 14:11:02 +0200 Subject: [PATCH 03/11] Update src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 6858d815e9daf6..c49ffcd4c0660d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -352,8 +352,6 @@ def forward( attn_dropout = self.dropout if self.training else 0.0 - softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype - upcast = query.dtype != softmax_dtype softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 softmax_scale = softmax_scale**-1 if self.scale_attn_weights: From 7f38f86b050c66f48ac096224e50387699c6989d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 14:16:45 +0200 Subject: [PATCH 04/11] add FA-2 docs --- docs/source/en/model_doc/gpt_bigcode.md | 39 +++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/docs/source/en/model_doc/gpt_bigcode.md b/docs/source/en/model_doc/gpt_bigcode.md index 6965d5837d8e74..7d0cf487e531ba 100644 --- a/docs/source/en/model_doc/gpt_bigcode.md +++ b/docs/source/en/model_doc/gpt_bigcode.md @@ -42,6 +42,45 @@ The main differences compared to GPT2. You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575) +## Combining Starcoder and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. 
`torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto

model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", torch_dtype=torch.float16, use_flash_attention_2=True)
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")

prompt = "def hello_world():"

model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)

generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

### Expected speedups

Below is an expected speedup diagram that compares the pure inference time between the native implementation in transformers using the `bigcode/starcoder` checkpoint and the Flash Attention 2 version of the model, for two different sequence lengths.
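Such a comparison can be reproduced in spirit with a simple timing loop like the sketch below. This is illustrative only: it assumes a CUDA GPU with `flash-attn` installed, uses greedy decoding, and times end-to-end `generate` calls; the smaller `bigcode/gpt_bigcode-santacoder` checkpoint is used as a stand-in so the script runs on a single GPU, and the numbers in the diagram come from a dedicated benchmark rather than this loop:

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/gpt_bigcode-santacoder"  # smaller stand-in for bigcode/starcoder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(["def fibonacci(n):"], return_tensors="pt").to("cuda")


def time_generate(use_flash_attention_2: bool) -> float:
    # Load the model with or without the Flash Attention 2 code path enabled.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, use_flash_attention_2=use_flash_attention_2
    ).to("cuda")
    model.generate(**inputs, max_new_tokens=16, do_sample=False)  # warmup
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start


print(f"eager attention   : {time_generate(False):.2f}s")
print(f"flash attention 2 : {time_generate(True):.2f}s")
```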
+ + ## GPTBigCodeConfig [[autodoc]] GPTBigCodeConfig From 28ddca3948a5b07bf915bf90c95bc9d64ad9f2af Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 12:20:59 +0000 Subject: [PATCH 05/11] oops --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index c49ffcd4c0660d..6858d815e9daf6 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -352,6 +352,8 @@ def forward( attn_dropout = self.dropout if self.training else 0.0 + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype + upcast = query.dtype != softmax_dtype softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 softmax_scale = softmax_scale**-1 if self.scale_attn_weights: From 542c2759d378507e953b7bc2ea77c3de4571e0ac Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Fri, 13 Oct 2023 11:37:37 +0530 Subject: [PATCH 06/11] Update docs/source/en/perf_infer_gpu_one.md Last Nit Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index ec1127ff4305f6..509e0d4060cdbc 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -33,7 +33,7 @@ We natively support Flash Attention 2 for the following models: - Llama - Falcon -- GPTBigCode (Starcoder) +- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode#) You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.* From f43ec5a49aad58232fe5d50ae278edc028d5ef24 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 18 Oct 2023 20:13:54 +0200 Subject: [PATCH 07/11] fix --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 946c66175fef6e..68373a09827682 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -364,14 +364,17 @@ def forward( # cast them back in float16 just to be sure everything works as expected. input_dtype = query.dtype if input_dtype == torch.float32: + # Handle the case where the model is quantized + target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype) + logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to" - " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - " float16." + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
) - query = query.to(torch.float16) - key = key.to(torch.float16) - value = value.to(torch.float16) + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) attn_output = self._flash_attention_forward( query, key, value, padding_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale From b2aa0d9bb62be4b88bd9c487331780e061d65be2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 18 Oct 2023 20:16:12 +0200 Subject: [PATCH 08/11] oops --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 68373a09827682..519e8441b2fb57 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -33,13 +33,13 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_available, + is_flash_attn_2_available, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig -if is_flash_attn_available(): +if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa From 22a64cbb1ebb69c118bb4214d94e4dc95f0df687 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Mon, 30 Oct 2023 17:00:39 +0530 Subject: [PATCH 09/11] remove padding_mask --- .../gpt_bigcode/modeling_gpt_bigcode.py | 136 ++++++++++-------- 1 file changed, 75 insertions(+), 61 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0fa8c3b97b49da..7047ea96b70493 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""PyTorch GPTBigCode model.""" import math +import warnings from typing import List, Optional, Tuple, Union import torch @@ -86,9 +87,9 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -101,8 +102,9 @@ def _get_unpad_data(padding_mask): class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() - self.mask_value = None + self.config = config + self.mask_value = None self.multi_query = config.multi_query self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -110,6 +112,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.kv_heads = 1 if self.multi_query else self.num_heads self.kv_dim = self.kv_heads * self.head_dim self.split_size = self.embed_dim + self.is_causal = True + if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" @@ -231,13 +235,17 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - encoder_padding_mask: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -300,13 +308,20 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - encoder_padding_mask: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -316,7 +331,7 @@ def forward( query = self.q_attn(hidden_states) key_value = self.c_attn(encoder_hidden_states) - padding_mask = encoder_padding_mask + attention_mask = encoder_attention_mask elif self.multi_query: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) else: @@ -365,7 +380,7 @@ def forward( input_dtype = query.dtype if input_dtype == torch.float32: # Handle the case where the model is quantized - target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype) + target_dtype = getattr(self.config, "_pre_quantization_dtype", self.c_attn.weight.dtype) logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" @@ -377,7 +392,7 @@ def forward( value = value.to(target_dtype) attn_output = self._flash_attention_forward( - query, key, value, padding_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale + query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale ) attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) @@ -399,7 +414,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -412,7 +427,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -421,10 +436,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -440,27 +455,31 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal ) return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) - batch_size, kv_seq_len, kv_num_heads, head_dim = key_layer.shape - query_num_heads = query_layer.shape[2] + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, kv_num_heads, head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, kv_num_heads, head_dim), indices_k) + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, query_num_heads, head_dim), indices_k + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k @@ -474,8 +493,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. 
- padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -540,8 +559,6 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - encoder_padding_mask: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] @@ -555,8 +572,6 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - padding_mask=padding_mask, - encoder_padding_mask=encoder_padding_mask, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -579,8 +594,6 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, - padding_mask=padding_mask, - encoder_padding_mask=encoder_padding_mask, ) attn_output = cross_attn_outputs[0] # residual connection @@ -819,13 +832,6 @@ def forward( else: past_length = past_key_values[0].size(-2) - padding_mask = None - if attention_mask is not None and 0 in attention_mask: - padding_mask = attention_mask - encoder_padding_mask = None - if encoder_attention_mask is not None and 0 in encoder_attention_mask: - encoder_padding_mask = encoder_attention_mask - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -841,28 +847,38 @@ def forward( key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None + encoder_attention_mask = ( + encoder_attention_mask.bool() + if (encoder_attention_mask is not None and 0 in encoder_attention_mask) + else None ) - - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if ( - self.config.add_cross_attention - and encoder_hidden_states is not None - and encoder_attention_mask is not None - ): - if encoder_attention_mask.dim() == 2: - encoder_attention_mask.unsqueeze(1) - assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) else: - encoder_attention_mask = None + # 4d mask is passed through the layers + if attention_mask is not None: + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + dtype=torch.bool, 
device=self_attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if ( + self.config.add_cross_attention + and encoder_hidden_states is not None + and encoder_attention_mask is not None + ): + if encoder_attention_mask.dim() == 2: + encoder_attention_mask.unsqueeze(1) + assert encoder_attention_mask.dim() == 3 + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + else: + encoder_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -913,8 +929,6 @@ def forward( encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, - padding_mask=padding_mask, - encoder_padding_mask=encoder_padding_mask, ) hidden_states = outputs[0] From ba0de1667d1ab551aa7607600c22c7a9b842ff65 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Mon, 30 Oct 2023 17:41:06 +0530 Subject: [PATCH 10/11] change getattr->hasattr logic --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7047ea96b70493..49aa478c1c92eb 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -380,7 +380,10 @@ def forward( input_dtype = query.dtype if input_dtype == torch.float32: # Handle the case where the model is quantized - target_dtype = getattr(self.config, "_pre_quantization_dtype", self.c_attn.weight.dtype) + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_attn.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" From d577b4f9dead9438bee5c43f9f4e84a79d8f682a Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Tue, 31 Oct 2023 03:03:21 +0530 Subject: [PATCH 11/11] changed .md file --- docs/source/en/model_doc/gpt_bigcode.md | 22 +++++++++---------- .../gpt_bigcode/modeling_gpt_bigcode.py | 5 ++--- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/docs/source/en/model_doc/gpt_bigcode.md b/docs/source/en/model_doc/gpt_bigcode.md index 7d0cf487e531ba..8cc77a825de75c 100644 --- a/docs/source/en/model_doc/gpt_bigcode.md +++ b/docs/source/en/model_doc/gpt_bigcode.md @@ -55,21 +55,21 @@ Make also sure that you have a hardware that is compatible with Flash-Attention To load and run a model using Flash Attention 2, refer to the snippet below: ```python -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -device = "cuda" # the device to load the model onto +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer +>>> device = "cuda" # the device to load the model onto -model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", torch_dtype=torch.float16, use_flash_attention_2=True) -tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder") +>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, 
use_flash_attention_2=True) +>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder") -prompt = "def hello_world():" +>>> prompt = "def hello_world():" -model_inputs = tokenizer([prompt], return_tensors="pt").to(device) -model.to(device) +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device) +>>> model.to(device) -generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) -tokenizer.batch_decode(generated_ids)[0] -"The expected outupt" +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False) +>>> tokenizer.batch_decode(generated_ids)[0] +'def hello_world():\n print("hello world")\n\nif __name__ == "__main__":\n print("hello world")\n<|endoftext|>' ``` ### Expected speedups diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 49aa478c1c92eb..fcbbfca5cedac7 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -13,7 +13,6 @@ # limitations under the License. """PyTorch GPTBigCode model.""" import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -242,7 +241,7 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: if "padding_mask" in kwargs: - warnings.warn( + logger.warning_once( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) @@ -315,7 +314,7 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: if "padding_mask" in kwargs: - warnings.warn( + logger.warning_once( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" )
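A quick way to sanity-check the padding handling these patches introduce (the unpad/`pad_input` round trip in the Flash Attention 2 path) is to compare Flash Attention 2 logits against the default attention implementation on a padded batch. The following is a minimal sketch, not part of the patches: it assumes a CUDA GPU with `flash-attn` installed and reuses the `bigcode/gpt_bigcode-santacoder` checkpoint referenced above; the printed difference is expected to be small but non-zero due to fp16 numerics.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/gpt_bigcode-santacoder"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token  # santacoder ships without a pad token

# Two prompts of different lengths so the batch actually contains padding tokens.
prompts = ["def hello_world():", "def fibonacci(n):\n    if n < 2:\n        return n"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

model_fa2 = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, use_flash_attention_2=True
).to("cuda")
model_ref = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")

with torch.no_grad():
    logits_fa2 = model_fa2(**inputs).logits
    logits_ref = model_ref(**inputs).logits

# Compare only non-padded positions; outputs at padded positions are not meaningful.
valid = inputs.attention_mask.bool()
max_diff = (logits_fa2[valid] - logits_ref[valid]).abs().max().item()
print(f"max |logit difference| on non-padded tokens: {max_diff:.4f}")
```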