From a045db02146751548fec57a5d3f31382ce4e5959 Mon Sep 17 00:00:00 2001
From: Casper
Date: Mon, 16 Oct 2023 21:13:46 +0200
Subject: [PATCH] Mistral: Sliding Window Attention with Flash Attention and
 Sample Packing (#732)

* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian
---
 setup.py                                      |   2 +-
 .../monkeypatch/mistral_attn_hijack_flash.py  | 109 +++++++++++++++++-
 2 files changed, 105 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index ada5fcb289..e3ee54350b 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@ def parse_requirements():
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn>=2.2.1",
+            "flash-attn>=2.3.0",
         ],
         "deepspeed": [
             "deepspeed",
diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index 21a6ee0842..26b511d06e 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -14,6 +14,9 @@
     flash_attn_varlen_qkvpacked_func,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention as OriginalMistralAttention,
+)
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
 )
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
     )
 
 
+@torch.jit.script
+def _make_sliding_window_causal_mask(
+    bsz: int,
+    tgt_len: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+    sliding_window: int = 4096,
+):
+    """
+    Make causal mask used for sliding window attention
+    """
+    tensor = torch.full(
+        (tgt_len, tgt_len),
+        fill_value=1,
+        device=device,
+    )
+    mask = torch.tril(tensor, diagonal=0)
+    # make the mask banded to account for sliding window
+    # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
+    mask = torch.triu(mask, diagonal=-sliding_window + 1)
+    mask = torch.log(mask).to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )
+
+
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
     sliding_window,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
+    if attention_mask is None:
+        return attention_mask
+
+    # NOTE: the attention mask and the sliding-window mask are only broadcastable in certain scenarios.
+    # Without attention_mask.shape[0] == 1, an error is triggered after the eval loss step, but only when wandb is enabled.
+    if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
+        sliding_window_mask = _make_sliding_window_causal_mask(
+            bsz=input_shape[0],
+            tgt_len=input_shape[1],
+            dtype=inputs_embeds.dtype,
+            device=inputs_embeds.device,
+            past_key_values_length=past_key_values_length,
+            sliding_window=sliding_window,
+        )
+        attention_mask = attention_mask + sliding_window_mask
+    else:
+        LOG.info("skipping sliding window mask, not broadcastable with attention mask")
+
     return attention_mask
 
 
 def flashattn_forward(
-    self,
+    self: OriginalMistralAttention,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
@@ -91,10 +150,41 @@ def flashattn_forward(
         query_states, key_states, cos, sin, position_ids
     )
 
+    use_sliding_windows = (
+        getattr(self.config, "sliding_window", None) is not None
+        and kv_seq_len > self.config.sliding_window
+    )
+
+    if use_sliding_windows:
+        window_size = (self.config.sliding_window, self.config.sliding_window)
+    else:
+        window_size = (-1, -1)
+
     if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        # Activate slicing cache only if the config has a `sliding_window` attribute
+        if (
+            hasattr(self.config, "sliding_window")
+            and kv_seq_len > self.config.sliding_window
+        ):
+            slicing_tokens = kv_seq_len - self.config.sliding_window
+
+            past_key = past_key_value[0]
+            past_value = past_key_value[1]
+
+            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+            past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+            if past_key.shape[-2] != self.config.sliding_window - 1:
+                raise ValueError(
+                    f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                    f" {past_key.shape}"
+                )
+
+            past_key_value = (past_key, past_value) if use_cache else None
+
+    if past_key_value is not None:
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
     past_key_value = (key_states, value_states) if use_cache else None
 
@@ -120,7 +210,13 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            0.0,
+            softmax_scale=None,
+            causal=True,
+            window_size=window_size,
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -146,6 +242,7 @@ def flashattn_forward(
             0.0,
             softmax_scale=None,
             causal=is_causal,
+            window_size=window_size,
         )
         output = output_pad_fn(output_unpad)
     else:
@@ -157,6 +254,7 @@ def flashattn_forward(
                 query_states,
                 torch.stack([key_states, value_states], 2),
                 causal=is_causal,
+                window_size=window_size,
             )
         else:
             (  # pylint: disable=unbalanced-tuple-unpacking
@@ -191,6 +289,7 @@ def flashattn_forward(
                 0.0,
                 softmax_scale=None,
                 causal=is_causal,
+                window_size=window_size,
             )
             output = output_pad_fn(output_unpad)
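
The core of the additive-mask path above is the banded causal mask built by `_make_sliding_window_causal_mask`: `tril` keeps the causal half, the second `triu` drops anything more than `sliding_window - 1` positions in the past, and `log` maps allowed positions to 0 and blocked positions to -inf so the mask can be added directly to attention scores. The short sketch below reproduces that pattern with plain PyTorch on CPU; the helper name and the toy sizes are illustrative only, not part of the patched module.

import torch


def sliding_window_causal_mask(tgt_len: int, sliding_window: int, dtype=torch.float32):
    # causal part: ones on and below the main diagonal
    mask = torch.tril(torch.ones(tgt_len, tgt_len))
    # banded part: zero out positions more than (sliding_window - 1) steps back
    mask = torch.triu(mask, diagonal=-sliding_window + 1)
    # log(1) = 0 (attend), log(0) = -inf (blocked); additive with attention scores
    return torch.log(mask).to(dtype)


print(sliding_window_causal_mask(tgt_len=5, sliding_window=3))
# each row i has 0.0 in columns i-2..i and -inf everywhere else

In the flash-attn >= 2.3.0 code paths, the sliding window is instead enforced inside the kernel by passing window_size=(self.config.sliding_window, self.config.sliding_window), with (-1, -1) leaving windowing disabled.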