From 2b2fd52b014ca90874922fbaaeeffa19272ad04c Mon Sep 17 00:00:00 2001 From: joecummings Date: Sun, 17 Dec 2023 14:11:23 -0800 Subject: [PATCH 01/17] Add s2_attn to hijack flash code --- .../monkeypatch/llama_attn_hijack_flash.py | 122 +++++++++++++++++- 1 file changed, 121 insertions(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index f380c3f2ae..0638cd9b61 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -70,11 +70,16 @@ def replace_llama_attn_with_flash_attn( packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, rms_norm: Optional[bool] = False, + use_shifted_sparse_attn: Optional[bool] = False, ): transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask ) - transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward + if use_shifted_sparse_attn: + transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward_with_s2attn + else: + transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward + if packed: transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaModel.forward = ( @@ -213,6 +218,121 @@ def _prepare_decoder_attention_mask( return attention_mask +group_size_ratio = 1/4 +def flashattn_forward_with_s2attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py + + attention_mask: [bsz, q_len] + + `cu_seqlens` will be ignored if provided + `max_seqlen` will be ignored if provided + """ + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
+ ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + # Past Key value support + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + + key_padding_mask = attention_mask.repeat(2, 1) + nheads = qkv.shape[-2] + # shift + + group_size = int(q_len * group_size_ratio) + if q_len % group_size > 0: + raise ValueError("q_len %d should be divisible by group size %d." 
% (q_len, group_size)) + + qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2, + q_len, 3, + self.num_heads // 2, + self.head_dim) + x = rearrange(qkv, "b s three h d -> b s (three h d)") + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype) + cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1) + cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1) + + x_unpad = rearrange( + x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2 + ) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len + ), + "b s (h d) -> b s h d", + h=nheads // 2, + ) + output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads, + self.head_dim) + return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value + + def flashattn_forward( self, hidden_states: torch.Tensor, From 1450af904614a7b6b7aa8150e9b9abfe8a7c755d Mon Sep 17 00:00:00 2001 From: joecummings Date: Sun, 17 Dec 2023 14:11:46 -0800 Subject: [PATCH 02/17] Refactor code to account for s2_attn --- src/axolotl/utils/models.py | 72 ++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 66f6e16acf..7e30145961 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -256,31 +256,68 @@ def load_model( replace_stablelm_attn_with_flash_attn(cfg.base_model) - if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: - if cfg.device not in ["mps", "cpu"] and not inference: + if cfg.sample_packing and cfg.s2_attention: + raise ValueError("Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not current support sample packing.") + + # Modify all llama derived models in one block + if cfg.is_llama_derived_model: + if cfg.flash_attention: from axolotl.monkeypatch.llama_attn_hijack_flash import ( - replace_llama_attn_with_flash_attn, + replace_llama_attn_with_flash_attn, + ) + if cfg.sample_packing: + if cfg.device not in ["mps", "cpu"] and not inference: + LOG.info("patching with flash attention for sample packing") + replace_llama_attn_with_flash_attn( + packed=True, + cross_entropy=cfg.flash_attn_cross_entropy, + rms_norm=cfg.flash_attn_rms_norm, + ) + elif cfg.s2_attention: + LOG.info("patching w/ flash-enabled, shifted-sparse attention") + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=cfg.flash_attn_cross_entropy, + rms_norm=cfg.flash_attn_rms_norm, + use_shifted_sparse_attn=True + ) + elif cfg.xformers_attention: + from axolotl.monkeypatch.llama_attn_hijack_xformers import ( + hijack_llama_attention, ) + LOG.info("patching with xformers attention") + hijack_llama_attention() + elif cfg.sdp_attention: + from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention + LOG.info("patching with sdp attention") + hijack_llama_sdp_attention() + elif cfg.landmark_attention: + from axolotl.monkeypatch.llama_landmark_attn import ( + MEM_TOKEN, + patch_llama_with_landmark_attn, + ) + + 
LOG.info("patching with landmark attention") + patch_llama_with_landmark_attn() + + # Note: This might overwrite previous additional_special_tokens + tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) + elif cfg.s2_attention: + raise NotImplementedError("Shifted-sparse attention not currently implemented without flash attention.") - LOG.info("patching with flash attention for sample packing") - replace_llama_attn_with_flash_attn( - packed=cfg.sample_packing, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + if cfg.xpos_rope: + from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( + replace_llama_rope_with_xpos_rope, ) - elif cfg.is_llama_derived_model and cfg.xformers_attention: - from axolotl.monkeypatch.llama_attn_hijack_xformers import ( - hijack_llama_attention, - ) - LOG.info("patching with xformers attention") - hijack_llama_attention() - elif cfg.is_llama_derived_model and cfg.sdp_attention: - from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention + LOG.info("patching with xpos rope") + replace_llama_rope_with_xpos_rope() LOG.info("patching with sdp attention") hijack_llama_sdp_attention() + # Modify mistral derived models if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( replace_mistral_attn_with_flash_attn, @@ -387,9 +424,12 @@ def load_model( model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) + # sample packing uses custom FA2 patch if cfg.flash_attention: if not cfg.sample_packing: + if cfg.s2_attention: + pass if ( cfg.is_llama_derived_model or cfg.is_falcon_derived_model From 60126bf93f32ebc7c8b8a06f2f1d8b17291dde01 Mon Sep 17 00:00:00 2001 From: joecummings Date: Mon, 18 Dec 2023 19:50:21 -0800 Subject: [PATCH 03/17] Add test for models utils --- tests/utils/test_models.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/utils/test_models.py diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py new file mode 100644 index 0000000000..0515a2e57a --- /dev/null +++ b/tests/utils/test_models.py @@ -0,0 +1,33 @@ +"""Module for testing models utils file.""" + + +import unittest +from unittest.mock import patch + +import pytest + +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model + + +class ModelsUtilsTest(unittest.TestCase): + """Testing module for models utils.""" + + def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): + cfg = DictDefault( + { + "s2_attention": True, + "sample_packing": True, + "base_model": "", + "model_type": "LlamaForCausalLM", + } + ) + + # Mock out call to HF hub + with patch( + "axolotl.utils.models.load_model_config" + ) as mocked_load_model_config: + mocked_load_model_config.return_value = {} + with pytest.raises(ValueError): + # Should error before hitting tokenizer, so we pass in an empty str + load_model(cfg, tokenizer="") From 5e66cb44cbeb3857a64e438192f18696387d3bfc Mon Sep 17 00:00:00 2001 From: joecummings Date: Sun, 17 Dec 2023 14:14:49 -0800 Subject: [PATCH 04/17] Add ``s2_attention`` option to llama configs --- examples/code-llama/13b/lora.yml | 1 + examples/code-llama/34b/lora.yml | 1 + examples/code-llama/7b/lora.yml | 1 + examples/llama-2/lora.yml | 1 + examples/openllama-3b/lora.yml | 1 + 5 files changed, 5 insertions(+) diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index 
fc43ad14e2..9c0df0afae 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index c2f1d5ce15..a137d54e70 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index 630c8da6fc..217b2a635d 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index afb7dcd06f..abe1c1de0a 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -52,6 +52,7 @@ local_rank: logging_steps: 1 xformers_attention: flash_attention: true +s2_attention: warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml index 4fbb634f92..b83b2db4e4 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/openllama-3b/lora.yml @@ -52,6 +52,7 @@ logging_steps: 1 xformers_attention: flash_attention: true gptq_groupsize: +s2_attention: gptq_model_v1: warmup_steps: 20 evals_per_epoch: 4 From 0f57f30ff915a29953b0c8fe9cc0fc823ef154e0 Mon Sep 17 00:00:00 2001 From: joecummings Date: Sun, 17 Dec 2023 14:15:31 -0800 Subject: [PATCH 05/17] Add ``s2_attention`` option to README config --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 2bd3d52814..8812f24e69 100644 --- a/README.md +++ b/README.md @@ -828,6 +828,14 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: +# Landmark attention (only llama) +landmark_attention: +# Shifted-sparse attention (only llama) +s2_attention: + +# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py +# LLaMA only +xpos_rope: # Resume from a specific checkpoint dir resume_from_checkpoint: From cb335d8f3aa4782443f3a72f9ac5ce54cad0db27 Mon Sep 17 00:00:00 2001 From: joecummings Date: Mon, 18 Dec 2023 19:47:35 -0800 Subject: [PATCH 06/17] Format code to appease linter --- .../monkeypatch/llama_attn_hijack_flash.py | 42 +++++++++++++------ src/axolotl/utils/models.py | 23 ++++++---- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 0638cd9b61..41a241463a 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -76,9 +76,13 @@ def replace_llama_attn_with_flash_attn( _prepare_decoder_attention_mask ) if use_shifted_sparse_attn: - transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward_with_s2attn + transformers.models.llama.modeling_llama.LlamaAttention.forward = ( + flashattn_forward_with_s2attn + ) else: - transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward + 
transformers.models.llama.modeling_llama.LlamaAttention.forward = ( + flashattn_forward + ) if packed: transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer @@ -218,7 +222,9 @@ def _prepare_decoder_attention_mask( return attention_mask -group_size_ratio = 1/4 +GROUP_SIZE_RATIO = 1 / 4 + + def flashattn_forward_with_s2attn( self, hidden_states: torch.Tensor, @@ -301,18 +307,25 @@ def flashattn_forward_with_s2attn( nheads = qkv.shape[-2] # shift - group_size = int(q_len * group_size_ratio) + group_size = int(q_len * GROUP_SIZE_RATIO) if q_len % group_size > 0: - raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) + raise ValueError( + f"q_len {q_len} should be divisible by group size {group_size}." + ) - qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2, - q_len, 3, - self.num_heads // 2, - self.head_dim) + qkv = ( + qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim) + .permute(0, 3, 1, 2, 4, 5) + .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim) + ) x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) - cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype) - cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1) + cu_q_len_tmp = torch.arange( + 0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype + ) + cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat( + bsz, 1 + ) + cu_q_lens[:-1].unsqueeze(-1) cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1) x_unpad = rearrange( @@ -328,8 +341,11 @@ def flashattn_forward_with_s2attn( "b s (h d) -> b s h d", h=nheads // 2, ) - output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads, - self.head_dim) + output = ( + output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim) + .transpose(1, 2) + .reshape(bsz, q_len, nheads, self.head_dim) + ) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 7e30145961..1ca8a31603 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -257,15 +257,18 @@ def load_model( replace_stablelm_attn_with_flash_attn(cfg.base_model) if cfg.sample_packing and cfg.s2_attention: - raise ValueError("Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not current support sample packing.") + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not current support sample packing." 
+ ) # Modify all llama derived models in one block if cfg.is_llama_derived_model: if cfg.flash_attention: from axolotl.monkeypatch.llama_attn_hijack_flash import ( - replace_llama_attn_with_flash_attn, - ) + replace_llama_attn_with_flash_attn, + ) + if cfg.sample_packing: if cfg.device not in ["mps", "cpu"] and not inference: LOG.info("patching with flash attention for sample packing") @@ -280,16 +283,20 @@ def load_model( packed=False, cross_entropy=cfg.flash_attn_cross_entropy, rms_norm=cfg.flash_attn_rms_norm, - use_shifted_sparse_attn=True + use_shifted_sparse_attn=True, ) elif cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, ) + LOG.info("patching with xformers attention") hijack_llama_attention() elif cfg.sdp_attention: - from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention + from axolotl.monkeypatch.llama_attn_hijack_sdp import ( + hijack_llama_sdp_attention, + ) + LOG.info("patching with sdp attention") hijack_llama_sdp_attention() elif cfg.landmark_attention: @@ -304,7 +311,9 @@ def load_model( # Note: This might overwrite previous additional_special_tokens tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) elif cfg.s2_attention: - raise NotImplementedError("Shifted-sparse attention not currently implemented without flash attention.") + raise NotImplementedError( + "Shifted-sparse attention not currently implemented without flash attention." + ) if cfg.xpos_rope: from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( From dcb569452b269c42eeaf11a5926a7b7731e2e5d3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 22 Dec 2023 11:18:41 -0500 Subject: [PATCH 07/17] chore: lint --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 41a241463a..4bded9b027 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -233,9 +233,9 @@ def flashattn_forward_with_s2attn( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument + cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel @@ -270,6 +270,7 @@ def flashattn_forward_with_s2attn( ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] + # pylint: disable=duplicate-code kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -318,7 +319,9 @@ def flashattn_forward_with_s2attn( .permute(0, 3, 1, 2, 4, 5) .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim) ) - x = rearrange(qkv, "b s three h d -> b s (three h d)") + x = rearrange( # pylint: disable=invalid-name + qkv, "b s three h d -> b s (three h d)" + ) x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) cu_q_len_tmp = torch.arange( 0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype From 4135039475dea6d7c380ddfa8eca96aa4d0db064 Mon Sep 17 00:00:00 2001 From: joecummings 
Date: Thu, 11 Jan 2024 07:00:00 -0800 Subject: [PATCH 08/17] Remove xpos and llama-landmark [bad merge] --- README.md | 7 ------- src/axolotl/utils/models.py | 19 ------------------- 2 files changed, 26 deletions(-) diff --git a/README.md b/README.md index 8812f24e69..1f472cba3d 100644 --- a/README.md +++ b/README.md @@ -828,15 +828,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: -# Landmark attention (only llama) -landmark_attention: # Shifted-sparse attention (only llama) s2_attention: - -# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py -# LLaMA only -xpos_rope: - # Resume from a specific checkpoint dir resume_from_checkpoint: # If resume_from_checkpoint isn't set and you simply want it to start where it left off. diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1ca8a31603..b305fc3d8c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -299,30 +299,11 @@ def load_model( LOG.info("patching with sdp attention") hijack_llama_sdp_attention() - elif cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( - MEM_TOKEN, - patch_llama_with_landmark_attn, - ) - - LOG.info("patching with landmark attention") - patch_llama_with_landmark_attn() - - # Note: This might overwrite previous additional_special_tokens - tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) elif cfg.s2_attention: raise NotImplementedError( "Shifted-sparse attention not currently implemented without flash attention." ) - if cfg.xpos_rope: - from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( - replace_llama_rope_with_xpos_rope, - ) - - LOG.info("patching with xpos rope") - replace_llama_rope_with_xpos_rope() - LOG.info("patching with sdp attention") hijack_llama_sdp_attention() From 34c62fbb60f58e51c70878f9a81b5fa3fc10d1d4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Jan 2024 14:22:57 -0500 Subject: [PATCH 09/17] add e2e smoke tests for shifted sparse attention --- tests/e2e/patched/test_llama_s2_attention.py | 119 +++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 tests/e2e/patched/test_llama_s2_attention.py diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py new file mode 100644 index 0000000000..947f8b67eb --- /dev/null +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -0,0 +1,119 @@ +""" +E2E tests for llama w/ S2 attn +""" + +import logging +import os +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestLlamaShiftedSparseAttention(unittest.TestCase): + """ + Test case for Llama models using S2 Attn + """ + + @with_temp_dir + def test_lora_s2_attn(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "sample_packing": False, + "flash_attention": 
True, + "s2_attention": True, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 10, + "save_steps": 5, + "eval_steps": 5, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_fft_s2_attn(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "sample_packing": False, + "flash_attention": True, + "s2_attention": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 10, + "save_steps": 5, + "eval_steps": 5, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() From cb899d9cde7ffb9c9d977042d97b21ef865f2410 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Jan 2024 16:30:02 -0500 Subject: [PATCH 10/17] remove stray patch from merge --- src/axolotl/utils/models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b305fc3d8c..cd3804324a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -304,9 +304,6 @@ def load_model( "Shifted-sparse attention not currently implemented without flash attention." 
) - LOG.info("patching with sdp attention") - hijack_llama_sdp_attention() - # Modify mistral derived models if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( From 02d1e9078ba531caab1452f3bd51a9ff841f6bd7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Jan 2024 16:45:10 -0500 Subject: [PATCH 11/17] update yml with link to paper for s2_attention/longlora --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1f472cba3d..606b47b918 100644 --- a/README.md +++ b/README.md @@ -828,7 +828,7 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: -# Shifted-sparse attention (only llama) +# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf s2_attention: # Resume from a specific checkpoint dir resume_from_checkpoint: From e8ba3fe9ae69d9ccbb5e1a8c6c418eca825688ef Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Jan 2024 16:49:48 -0500 Subject: [PATCH 12/17] fix assertion check for full fine tune --- tests/e2e/patched/test_llama_s2_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 947f8b67eb..a4ecc86529 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -116,4 +116,4 @@ def test_fft_s2_attn(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "pytorch_model.bin").exists() From 9292665f325217b22dc1aaba46a4b63d17ab7819 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jan 2024 11:18:49 -0500 Subject: [PATCH 13/17] increase sequence len for tests and PR feedback updates --- src/axolotl/utils/models.py | 2 +- tests/e2e/patched/test_llama_s2_attention.py | 20 ++++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index cd3804324a..c5685b274a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -259,7 +259,7 @@ def load_model( if cfg.sample_packing and cfg.s2_attention: raise ValueError( "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not current support sample packing." + shifted-sparse attention does not currently support sample packing." 
) # Modify all llama derived models in one block diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index a4ecc86529..ab4952a479 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -7,8 +7,6 @@ import unittest from pathlib import Path -from transformers.utils import is_torch_bf16_gpu_available - from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.train import train @@ -33,21 +31,21 @@ def test_lora_s2_attn(self, temp_dir): { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, + "sequence_len": 65536, "sample_packing": False, "flash_attention": True, "s2_attention": True, "load_in_8bit": True, "adapter": "lora", "lora_r": 32, - "lora_alpha": 64, + "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_linear": True, "val_set_size": 0.1, "special_tokens": {}, "datasets": [ { - "path": "mhenrichsen/alpaca_2k_test", + "path": "Yukang/LongAlpaca-12k", "type": "alpaca", }, ], @@ -61,12 +59,9 @@ def test_lora_s2_attn(self, temp_dir): "max_steps": 10, "save_steps": 5, "eval_steps": 5, + "bf16": "auto", } ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True normalize_config(cfg) cli_args = TrainerCliArgs() @@ -82,7 +77,7 @@ def test_fft_s2_attn(self, temp_dir): { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, + "sequence_len": 65536, "sample_packing": False, "flash_attention": True, "s2_attention": True, @@ -104,12 +99,9 @@ def test_fft_s2_attn(self, temp_dir): "max_steps": 10, "save_steps": 5, "eval_steps": 5, + "bf16": "auto", } ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True normalize_config(cfg) cli_args = TrainerCliArgs() From 5e0890ddfb10c39dbccad8179210285c1b03e92b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jan 2024 18:30:48 -0500 Subject: [PATCH 14/17] reduce context len to 16k for tests --- tests/e2e/patched/test_llama_s2_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index ab4952a479..9f3c8af1a5 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -31,7 +31,7 @@ def test_lora_s2_attn(self, temp_dir): { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", - "sequence_len": 65536, + "sequence_len": 32768, "sample_packing": False, "flash_attention": True, "s2_attention": True, @@ -77,7 +77,7 @@ def test_fft_s2_attn(self, temp_dir): { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", - "sequence_len": 65536, + "sequence_len": 32768, "sample_packing": False, "flash_attention": True, "s2_attention": True, @@ -85,7 +85,7 @@ def test_fft_s2_attn(self, temp_dir): "special_tokens": {}, "datasets": [ { - "path": "mhenrichsen/alpaca_2k_test", + "path": "Yukang/LongAlpaca-12k", "type": "alpaca", }, ], From bee8f8c71774372c876b74b722f693e162b67bb2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jan 2024 20:39:11 -0500 Subject: [PATCH 15/17] reduce context len to 16k for tests --- tests/e2e/patched/test_llama_s2_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 9f3c8af1a5..0746b79f56 100644 --- 
a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -31,7 +31,7 @@ def test_lora_s2_attn(self, temp_dir): { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", - "sequence_len": 32768, + "sequence_len": 16384, "sample_packing": False, "flash_attention": True, "s2_attention": True, @@ -77,7 +77,7 @@ def test_fft_s2_attn(self, temp_dir): { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", - "sequence_len": 32768, + "sequence_len": 16384, "sample_packing": False, "flash_attention": True, "s2_attention": True, From e6e67dd174efa50c4493e1f4b6cc1dcbdab2e2c3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 18 Jan 2024 04:41:32 -0500 Subject: [PATCH 16/17] reduce batch size for larger context len and udpate test to check message --- tests/e2e/patched/test_llama_s2_attention.py | 4 ++-- tests/utils/test_models.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 0746b79f56..f1d37eb3ca 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -50,7 +50,7 @@ def test_lora_s2_attn(self, temp_dir): }, ], "num_epochs": 2, - "micro_batch_size": 8, + "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, @@ -90,7 +90,7 @@ def test_fft_s2_attn(self, temp_dir): }, ], "num_epochs": 2, - "micro_batch_size": 8, + "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py index 0515a2e57a..bfa82ccd1a 100644 --- a/tests/utils/test_models.py +++ b/tests/utils/test_models.py @@ -28,6 +28,10 @@ def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): "axolotl.utils.models.load_model_config" ) as mocked_load_model_config: mocked_load_model_config.return_value = {} - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc: # Should error before hitting tokenizer, so we pass in an empty str load_model(cfg, tokenizer="") + assert ( + "shifted-sparse attention does not currently support sample packing" + in exc.value.message + ) From 4f09ef4bdf9b09fe537df1cd9b6e2f2eb9ff64ef Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 18 Jan 2024 05:09:49 -0500 Subject: [PATCH 17/17] fix test for message --- tests/utils/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py index bfa82ccd1a..e06bb6c250 100644 --- a/tests/utils/test_models.py +++ b/tests/utils/test_models.py @@ -33,5 +33,5 @@ def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): load_model(cfg, tokenizer="") assert ( "shifted-sparse attention does not currently support sample packing" - in exc.value.message + in str(exc.value) )
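
Note on what the patched forward computes: shifted-sparse attention (S2-Attn, from the LongLoRA paper linked in the README entry) restricts each token to attend within a fixed-size group of the sequence, and for half of the attention heads shifts those groups by half a group so information can still cross group boundaries. The sketch below is a plain-PyTorch illustration of that idea only, not the code added by these patches: the function name `s2_attention_sketch` is invented for this note, it uses `torch.nn.functional.scaled_dot_product_attention` instead of the flash varlen kernel, it assumes `seq_len` is divisible by `group_size`, and it handles the sequence boundary by rolling (so the first shifted group wraps around), which differs from the `cu_seqlens` bookkeeping in the actual patch.

import torch
import torch.nn.functional as F


def s2_attention_sketch(q, k, v, group_size):
    # q, k, v: [bsz, n_heads, seq_len, head_dim]; seq_len divisible by group_size.
    bsz, n_heads, seq_len, head_dim = q.shape
    assert seq_len % group_size == 0
    half = n_heads // 2
    shift = group_size // 2
    n_groups = seq_len // group_size

    def roll_second_half_heads(x, amount):
        # Shift the token dimension for the second half of the heads only.
        x = x.clone()
        x[:, half:] = x[:, half:].roll(amount, dims=2)
        return x

    def to_groups(x):
        # Fold fixed-size groups into the batch dim:
        # [bsz, n_heads, seq_len, hd] -> [bsz * n_groups, n_heads, group_size, hd]
        return (
            x.reshape(bsz, n_heads, n_groups, group_size, head_dim)
            .permute(0, 2, 1, 3, 4)
            .reshape(bsz * n_groups, n_heads, group_size, head_dim)
        )

    # Half of the heads see groups offset by half a group.
    q, k, v = (roll_second_half_heads(t, -shift) for t in (q, k, v))

    # Ordinary causal attention, but only inside each (possibly shifted) group.
    out = F.scaled_dot_product_attention(*map(to_groups, (q, k, v)), is_causal=True)

    # Undo the grouping and the shift.
    out = (
        out.reshape(bsz, n_groups, n_heads, group_size, head_dim)
        .permute(0, 2, 1, 3, 4)
        .reshape(bsz, n_heads, seq_len, head_dim)
    )
    return roll_second_half_heads(out, shift)

In axolotl terms, the feature added here is enabled by setting both `flash_attention: true` and `s2_attention: true` in the YAML config (as the example configs touched in PATCH 04 show), and it cannot be combined with `sample_packing: true`, which is exactly the ValueError the new unit test checks.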
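
The least obvious part of `flashattn_forward_with_s2attn` is how it encodes the shifted groups for `flash_attn_varlen_qkvpacked_func`: it duplicates the batch (one copy per half of the heads) and then builds `cu_q_lens` so that the second copy's group boundaries are offset by half a group. The snippet below replays that arithmetic from the patch with toy numbers chosen only for illustration (bsz=1, q_len=8, no padding, so group_size=2 with the 1/4 ratio); the hard-coded `cu_q_lens = [0, 8, 16]` and `max_s = 8` stand in for what `unpad_input` would return for two unpadded length-8 sequences.

import torch

bsz, q_len = 1, 8
group_size = q_len // 4  # GROUP_SIZE_RATIO = 1/4 -> 2
cu_q_lens = torch.tensor([0, 8, 16], dtype=torch.int32)
max_s = 8

# Same arithmetic as in flashattn_forward_with_s2attn:
cu_q_len_tmp = torch.arange(0, max_s, group_size, dtype=cu_q_lens.dtype)
cu_q_len_tmp = torch.stack(
    [cu_q_len_tmp, cu_q_len_tmp + group_size // 2]
).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)

print(cu_q_lens.tolist())
# [0, 2, 4, 6, 8, 9, 11, 13, 15, 16]
# First batch copy (first half of the heads): groups [0,2) [2,4) [4,6) [6,8).
# Second copy (second half of the heads): a half-group [8,9), shifted groups
# [9,11) [11,13) [13,15), and a trailing half-group [15,16).

The shifted copy therefore gets half-sized groups at the two ends rather than a wrap-around group, and `group_size` is passed to the varlen kernel as `max_seqlen`, which is why the forward raises a ValueError when `q_len` is not divisible by the group size.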