From b6ab8aad62f5b9f75e3456a4f60fa0e2040cf34b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 27 Sep 2023 18:41:00 -0400 Subject: [PATCH] Mistral flash attn packing (#646) * add mistral monkeypatch * add arg for decoder attention masl * fix lint for duplicate code * make sure to update transformers too * tweak install for e2e * move mistral patch to conditional --- .github/workflows/tests.yml | 5 +- requirements.txt | 2 +- .../monkeypatch/mistral_attn_hijack_flash.py | 401 ++++++++++++++++++ src/axolotl/utils/models.py | 8 + 4 files changed, 412 insertions(+), 4 deletions(-) create mode 100644 src/axolotl/monkeypatch/mistral_attn_hijack_flash.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4da10a6c35..b4637fd67f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,7 @@ jobs: - name: Install dependencies run: | - pip3 install -e . + pip3 install -U -e . pip3 install -r requirements-tests.txt - name: Run tests @@ -69,8 +69,7 @@ jobs: - name: Install dependencies run: | - pip3 install -e . - pip3 install flash-attn + pip3 install -U -e .[flash-attn] pip3 install -r requirements-tests.txt - name: Run e2e tests diff --git a/requirements.txt b/requirements.txt index 7616d1fb0a..18659daec5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ torch==2.0.1 auto-gptq packaging peft @ git+https://github.com/huggingface/peft.git -transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb +transformers @ git+https://github.com/huggingface/transformers.git@78dd120 bitsandbytes>=0.41.1 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 deepspeed diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py new file mode 100644 index 0000000000..f53d5d0071 --- /dev/null +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -0,0 +1,401 @@ +"""Flash attention monkey patch for mistral model""" +# pylint: disable=duplicate-code + +import logging +import math +from typing import List, Optional, Tuple, Union + +import torch +import transformers +from einops import rearrange +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer as OriginalMistralDecoderLayer, +) +from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv + +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids + +try: + from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports + flash_attn_varlen_qkvpacked_func, + ) +except ImportError: + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, + ) + + +LOG = logging.getLogger("axolotl.monkeypatch.mistral") + + +def replace_mistral_attn_with_flash_attn( + packed: Optional[bool] = False, +): + transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access + _prepare_decoder_attention_mask + ) + transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( + flashattn_forward + ) + if packed: + transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( + MistralDecoderLayer + ) + transformers.models.mistral.modeling_mistral.MistralModel.forward = ( + mistral_model_forward + ) + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + sliding_window, +): # pylint: disable=unused-argument + # [bsz, seq_len] + return attention_mask + + +def flashattn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: + # special handling using sample packing + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + qkv = rearrange(qkv, "b s ... -> (b s) ...") + + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True + ) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + attn_output = output + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = rearrange(attn_output, "b s h d -> b s (h d)") + attn_weights = None + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def mistral_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + cu_seqlens = None + max_seqlen = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) + cu_seqlens = cu_seqlens.squeeze() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = ( + self._prepare_decoder_attention_mask( # pylint: disable=protected-access + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + transformers.logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + None, + cu_seqlens, + max_seqlen, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MistralDecoderLayer(OriginalMistralDecoderLayer): + """ + patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 07cdc4d6ed..63e34293ee 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -150,6 +150,14 @@ def load_model( # Note: This might overwrite previous additional_special_tokens tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) + if cfg.is_mistral_derived_model and cfg.flash_attention: + from axolotl.monkeypatch.mistral_attn_hijack_flash import ( + replace_mistral_attn_with_flash_attn, + ) + + LOG.info("patching with flash attention") + replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) + if cfg.is_llama_derived_model and cfg.xpos_rope: from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( replace_llama_rope_with_xpos_rope,