Skip to content

Commit

Permalink
Mistral: Sliding Window Attention with Flash Attention and Sample Pac…
Browse files Browse the repository at this point in the history
…king (axolotl-ai-cloud#732)

* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
casper-hansen and winglian authored Oct 16, 2023
1 parent dc664ca commit fd8c5cf
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def parse_requirements():
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn>=2.2.1",
"flash-attn>=2.3.0",
],
"deepspeed": [
"deepspeed",
Expand Down
109 changes: 104 additions & 5 deletions src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
Expand Down Expand Up @@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
)


@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)

if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
Expand All @@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None:
return attention_mask

# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")

return attention_mask


def flashattn_forward(
self,
self: OriginalMistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
Expand Down Expand Up @@ -91,10 +150,41 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids
)

use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)

if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window

past_key = past_key_value[0]
past_value = past_key_value[1]

past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()

if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)

past_key_value = (past_key, past_value) if use_cache else None

if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

Expand All @@ -120,7 +210,13 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")

output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
qkv,
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
Expand All @@ -146,6 +242,7 @@ def flashattn_forward(
0.0,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
else:
Expand All @@ -157,6 +254,7 @@ def flashattn_forward(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
window_size=window_size,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
Expand Down Expand Up @@ -191,6 +289,7 @@ def flashattn_forward(
0.0,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)

Expand Down

0 comments on commit fd8c5cf

Please sign in to comment.