From 19a600a8b859c40cf4c3749b0b8a3db17b82a0c0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 26 Sep 2023 22:53:28 +0900 Subject: [PATCH] Feat: Add support for upstream FA2 (#626) * Feat: Add support for upstream FA2 * chore: add is_falcon_derived_model: true to examples * chore: add config to readme for documentation * feat: add extra model types * fix: remove old falcon flash patch * chore: pin transformers and accelerate --- README.md | 4 + examples/falcon/config-7b-lora.yml | 1 + examples/falcon/config-7b-qlora.yml | 1 + examples/falcon/config-7b.yml | 1 + requirements.txt | 4 +- .../monkeypatch/falcon_attn_hijack_flash.py | 101 ------------------ src/axolotl/utils/config.py | 16 +++ src/axolotl/utils/models.py | 20 ++-- 8 files changed, 31 insertions(+), 117 deletions(-) delete mode 100644 src/axolotl/monkeypatch/falcon_attn_hijack_flash.py diff --git a/README.md b/README.md index 25044a2361..c2d4d8ef8c 100644 --- a/README.md +++ b/README.md @@ -408,6 +408,10 @@ tokenizer_legacy: # this is reported to improve training speed on some models resize_token_embeddings_to_32x: +# used to identify if the model is falcon/llama based +is_falcon_derived_model: +is_llama_derived_model: + # whether you are training a 4-bit GPTQ quantized model gptq: true gptq_groupsize: 128 # group size diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index a5cbdc00df..738068a474 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b trust_remote_code: true model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer +is_falcon_derived_model: true load_in_8bit: true load_in_4bit: false gptq: false diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index 72b09b87d9..554081fcba 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b trust_remote_code: true model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer +is_falcon_derived_model: true load_in_8bit: false # enable 4bit for QLoRA load_in_4bit: true diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 46f4caff15..25e67a53b1 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b trust_remote_code: true model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer +is_falcon_derived_model: true load_in_8bit: false load_in_4bit: false gptq: false diff --git a/requirements.txt b/requirements.txt index 5aba20b161..33a2157d96 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,9 @@ torch==2.0.1 auto-gptq packaging peft @ git+https://github.com/huggingface/peft.git -transformers @ git+https://github.com/huggingface/transformers.git +transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb bitsandbytes>=0.41.1 -accelerate @ git+https://github.com/huggingface/accelerate +accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 deepspeed addict evaluate diff --git a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py b/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py deleted file mode 100644 index ed11c55234..0000000000 --- a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -Flash Attention monkey patch for Falcon - -copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py -""" - -from typing import Optional, Tuple - -import torch -import transformers -from flash_attn import flash_attn_func - - -def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, # pylint: disable=unused-argument - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - use_cache: bool = False, - output_attentions: bool = False, # pylint: disable=unused-argument -): - fused_qkv = self.query_key_value( - hidden_states - ) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = ( - self.num_heads if self.new_decoder_architecture else self.num_kv_heads - ) - # 3 x [batch_size, seq_length, num_heads, head_dim] - ( - query_layer, - key_layer, - value_layer, - ) = self._split_heads( # pylint: disable=protected-access - fused_qkv - ) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape( - batch_size * self.num_heads, query_length, self.head_dim - ) - key_layer = key_layer.transpose(1, 2).reshape( - batch_size * num_kv_heads, - query_length, - self.head_dim, - ) - value_layer = value_layer.transpose(1, 2).reshape( - batch_size * num_kv_heads, query_length, self.head_dim - ) - - past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] - query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, kv_length, head_dim] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) - - # unused - # _, kv_length, _ = key_layer.shape - if use_cache: - present = (key_layer, value_layer) - else: - present = None - # unused - # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) - query_layer_ = ( - query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) - .transpose(1, 2) - .to(torch.bfloat16) - ) - key_layer_ = ( - key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) - .transpose(1, 2) - .to(torch.bfloat16) - ) - value_layer_ = ( - value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) - .transpose(1, 2) - .to(torch.bfloat16) - ) - - if alibi is not None: - raise ValueError("`alibi` is not supported when `use_flash_attn` is True") - - # below output will have shape (batch_size, seqlen, nheads, headdim) - attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True) - attn_output = attn_output.reshape( - batch_size, query_length, self.num_heads * self.head_dim - ) - output_tensor = self.dense(attn_output) - return output_tensor, present - - -def replace_falcon_attn_with_flash_attn(): - transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index cb18380cb7..3a574cefcc 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -86,6 +86,22 @@ def normalize_config(cfg): or (cfg.model_type and "llama" in cfg.model_type.lower()) ) + # figure out if the model is falcon + cfg.is_falcon_derived_model = ( + ( + hasattr(model_config, "model_type") + and model_config.model_type + in [ + "falcon", + "RefinedWebModel", + "RefinedWeb", + ] + ) + or cfg.is_falcon_derived_model + or "falcon" in cfg.base_model + or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower()) + ) + log_gpu_memory_usage(LOG, "baseline", cfg.device) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 543a0e1a13..361440931f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -114,25 +114,13 @@ def load_model( replace_btlm_attn_with_flash_attn(cfg.base_model) - if hasattr(model_config, "model_type") and model_config.model_type in [ - "falcon", - "RefinedWebModel", - "RefinedWeb", - ]: - if cfg.flash_attention: - from axolotl.monkeypatch.falcon_attn_hijack_flash import ( - replace_falcon_attn_with_flash_attn, - ) - - replace_falcon_attn_with_flash_attn() - - if cfg.is_llama_derived_model and cfg.flash_attention: + if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: if cfg.device not in ["mps", "cpu"] and not inference: from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - LOG.info("patching with flash attention") + LOG.info("patching with flash attention for sample packing") replace_llama_attn_with_flash_attn(packed=cfg.sample_packing) elif cfg.is_llama_derived_model and cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( @@ -213,6 +201,10 @@ def load_model( bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) + # sample packing uses custom FA2 patch + if cfg.flash_attention and not cfg.sample_packing: + if cfg.is_llama_derived_model or cfg.is_falcon_derived_model: + model_kwargs["use_flash_attention_2"] = True try: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: from transformers import LlamaForCausalLM