diff --git a/README.md b/README.md index f5f848a446..0bac1c4d5d 100644 --- a/README.md +++ b/README.md @@ -834,7 +834,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: - +# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf +s2_attention: # Resume from a specific checkpoint dir resume_from_checkpoint: # If resume_from_checkpoint isn't set and you simply want it to start where it left off. diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index fc43ad14e2..9c0df0afae 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index c2f1d5ce15..a137d54e70 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index 630c8da6fc..217b2a635d 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index afb7dcd06f..abe1c1de0a 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml index 4fbb634f92..b83b2db4e4 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/openllama-3b/lora.yml @@ -52,6 +52,7 @@ logging_steps: 1 xformers_attention: flash_attention: true gptq_groupsize: +s2_attention: gptq_model_v1: warmup_steps: 20 evals_per_epoch: 4 diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index f380c3f2ae..4bded9b027 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -70,11 +70,20 @@ def replace_llama_attn_with_flash_attn( packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, rms_norm: Optional[bool] = False, + use_shifted_sparse_attn: Optional[bool] = False, ): transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask ) - transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward + if use_shifted_sparse_attn: + transformers.models.llama.modeling_llama.LlamaAttention.forward = ( + flashattn_forward_with_s2attn + ) + else: + transformers.models.llama.modeling_llama.LlamaAttention.forward = ( + flashattn_forward + ) + if packed: transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaModel.forward = ( @@ -213,6 +222,136 @@ def _prepare_decoder_attention_mask( return attention_mask +GROUP_SIZE_RATIO = 1 / 4 + + +def flashattn_forward_with_s2attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument + cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py + + attention_mask: [bsz, q_len] + + `cu_seqlens` will be ignored if provided + `max_seqlen` will be ignored if provided + """ + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + # pylint: disable=duplicate-code + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + # Past Key value support + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + + key_padding_mask = attention_mask.repeat(2, 1) + nheads = qkv.shape[-2] + # shift + + group_size = int(q_len * GROUP_SIZE_RATIO) + if q_len % group_size > 0: + raise ValueError( + f"q_len {q_len} should be divisible by group size {group_size}." + ) + + qkv = ( + qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim) + .permute(0, 3, 1, 2, 4, 5) + .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim) + ) + x = rearrange( # pylint: disable=invalid-name + qkv, "b s three h d -> b s (three h d)" + ) + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + cu_q_len_tmp = torch.arange( + 0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype + ) + cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat( + bsz, 1 + ) + cu_q_lens[:-1].unsqueeze(-1) + cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1) + + x_unpad = rearrange( + x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2 + ) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len + ), + "b s (h d) -> b s h d", + h=nheads // 2, + ) + output = ( + output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim) + .transpose(1, 2) + .reshape(bsz, q_len, nheads, self.head_dim) + ) + return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value + + def flashattn_forward( self, hidden_states: torch.Tensor, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 55721f8207..71cdcc69c5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -256,31 +256,55 @@ def load_model( replace_stablelm_attn_with_flash_attn(cfg.base_model) - if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: - if cfg.device not in ["mps", "cpu"] and not inference: + if cfg.sample_packing and cfg.s2_attention: + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." + ) + + # Modify all llama derived models in one block + if cfg.is_llama_derived_model: + if cfg.flash_attention: from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - LOG.info("patching with flash attention for sample packing") - replace_llama_attn_with_flash_attn( - packed=cfg.sample_packing, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + if cfg.sample_packing: + if cfg.device not in ["mps", "cpu"] and not inference: + LOG.info("patching with flash attention for sample packing") + replace_llama_attn_with_flash_attn( + packed=True, + cross_entropy=cfg.flash_attn_cross_entropy, + rms_norm=cfg.flash_attn_rms_norm, + ) + elif cfg.s2_attention: + LOG.info("patching w/ flash-enabled, shifted-sparse attention") + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=cfg.flash_attn_cross_entropy, + rms_norm=cfg.flash_attn_rms_norm, + use_shifted_sparse_attn=True, + ) + elif cfg.xformers_attention: + from axolotl.monkeypatch.llama_attn_hijack_xformers import ( + hijack_llama_attention, ) - elif cfg.is_llama_derived_model and cfg.xformers_attention: - from axolotl.monkeypatch.llama_attn_hijack_xformers import ( - hijack_llama_attention, - ) - LOG.info("patching with xformers attention") - hijack_llama_attention() - elif cfg.is_llama_derived_model and cfg.sdp_attention: - from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention + LOG.info("patching with xformers attention") + hijack_llama_attention() + elif cfg.sdp_attention: + from axolotl.monkeypatch.llama_attn_hijack_sdp import ( + hijack_llama_sdp_attention, + ) - LOG.info("patching with sdp attention") - hijack_llama_sdp_attention() + LOG.info("patching with sdp attention") + hijack_llama_sdp_attention() + elif cfg.s2_attention: + raise NotImplementedError( + "Shifted-sparse attention not currently implemented without flash attention." + ) + # Modify mistral derived models if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( replace_mistral_attn_with_flash_attn, @@ -387,9 +411,12 @@ def load_model( model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) + # sample packing uses custom FA2 patch if cfg.flash_attention: if not cfg.sample_packing: + if cfg.s2_attention: + pass if ( cfg.is_llama_derived_model or cfg.is_falcon_derived_model diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py new file mode 100644 index 0000000000..f1d37eb3ca --- /dev/null +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -0,0 +1,111 @@ +""" +E2E tests for llama w/ S2 attn +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestLlamaShiftedSparseAttention(unittest.TestCase): + """ + Test case for Llama models using S2 Attn + """ + + @with_temp_dir + def test_lora_s2_attn(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 16384, + "sample_packing": False, + "flash_attention": True, + "s2_attention": True, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "Yukang/LongAlpaca-12k", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 10, + "save_steps": 5, + "eval_steps": 5, + "bf16": "auto", + } + ) + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_fft_s2_attn(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 16384, + "sample_packing": False, + "flash_attention": True, + "s2_attention": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "Yukang/LongAlpaca-12k", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 10, + "save_steps": 5, + "eval_steps": 5, + "bf16": "auto", + } + ) + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py new file mode 100644 index 0000000000..e06bb6c250 --- /dev/null +++ b/tests/utils/test_models.py @@ -0,0 +1,37 @@ +"""Module for testing models utils file.""" + + +import unittest +from unittest.mock import patch + +import pytest + +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model + + +class ModelsUtilsTest(unittest.TestCase): + """Testing module for models utils.""" + + def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): + cfg = DictDefault( + { + "s2_attention": True, + "sample_packing": True, + "base_model": "", + "model_type": "LlamaForCausalLM", + } + ) + + # Mock out call to HF hub + with patch( + "axolotl.utils.models.load_model_config" + ) as mocked_load_model_config: + mocked_load_model_config.return_value = {} + with pytest.raises(ValueError) as exc: + # Should error before hitting tokenizer, so we pass in an empty str + load_model(cfg, tokenizer="") + assert ( + "shifted-sparse attention does not currently support sample packing" + in str(exc.value) + )