From 6b9b229356f14571c4d810f215d20e7f2d245db2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 17 Sep 2023 13:49:18 -0400
Subject: [PATCH] btlm and falcon monkey patches for flash attn (#566)

---
 examples/cerebras/btlm-ft.yml                 |  90 ++++++++++++++++
 .../monkeypatch/btlm_attn_hijack_flash.py     |  64 +++++++++++
 .../monkeypatch/falcon_attn_hijack_flash.py   | 101 ++++++++++++++++++
 src/axolotl/utils/models.py                   |  24 +++++
 4 files changed, 279 insertions(+)
 create mode 100644 examples/cerebras/btlm-ft.yml
 create mode 100644 src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
 create mode 100644 src/axolotl/monkeypatch/falcon_attn_hijack_flash.py

diff --git a/examples/cerebras/btlm-ft.yml b/examples/cerebras/btlm-ft.yml
new file mode 100644
index 0000000000..4fd34aa5f1
--- /dev/null
+++ b/examples/cerebras/btlm-ft.yml
@@ -0,0 +1,90 @@
+base_model: cerebras/btlm-3b-8k-base
+base_model_config: cerebras/btlm-3b-8k-base
+model_type: AutoModelForCausalLM
+tokenizer_type: GPT2Tokenizer
+trust_remote_code: true
+tokenizer_use_fast: true
+tokenizer_legacy: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+hf_use_auth_token: true
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_prepared_run
+val_set_size: 0.01
+
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len:
+sample_packing: false
+sample_packing_eff_est:
+sample_packing_seq_len_multiplier:
+total_num_tokens:
+
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_modules:
+lora_target_linear:
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+output_dir: btlm-out
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch
+adam_beta2: 0.95
+adam_eps: 0.000000001
+max_grad_norm: 1.0
+
+torchdistx_path:
+lr_scheduler: cosine
+lr_quadratic_warmup: true
+learning_rate: 0.000085
+train_on_inputs: true
+group_by_length: false
+bf16: true
+fp16: false
+tf32: true
+
+gradient_checkpointing: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+
+xformers_attention:
+flash_attention: true
+sdp_attention:
+flash_optimum:
+
+gptq_groupsize:
+gptq_model_v1:
+
+warmup_steps: 32
+eval_steps:
+save_steps:
+save_total_limit:
+
+debug:
+deepspeed:
+weight_decay: 0.1
+special_tokens:
+  pad_token: "<|endoftext|>"
+fsdp:
+# - full_shard
+# - auto_wrap
+fsdp_config:
+# fsdp_state_dict_type: FULL_STATE_DICT
+# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
new file mode 100644
index 0000000000..be5a705595
--- /dev/null
+++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
@@ -0,0 +1,64 @@
+"""
+Flash attention monkey patch for cerebras btlm model
+"""
+
+import importlib
+import logging
+from typing import Optional, Tuple
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_func
+from transformers import AutoConfig, AutoModelForCausalLM
+
+LOG = logging.getLogger("axolotl")
+
+
+def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
+    # this is a wonky hack to get the remotely loaded module
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_btlm to be available
+    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(
+        ".configuration_btlm", ".modeling_btlm"
+    )
+    modeling_btlm = importlib.import_module(module_name)
+    modeling_btlm.BTLMAttention._attn = (  # pylint: disable=protected-access
+        flashattn_attn
+    )
+
+
+def flashattn_attn(
+    self,
+    query: torch.Tensor,
+    key: Optional[torch.Tensor] = None,
+    value: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+    head_mask: Optional[torch.Tensor] = None,
+    position_bias: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    softmax_scale = (
+        1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
+    )
+
+    query = query.permute(0, 2, 1, 3)
+    key = key.permute(0, 2, 1, 3)
+    value = value.permute(0, 2, 1, 3)
+
+    # Perform Flash attention
+    attn_output = flash_attn_func(
+        query,
+        key,
+        value,
+        dropout_p=0.0,  # no attention dropout
+        softmax_scale=softmax_scale,  # scaling derived from attn_scale_power above
+        causal=not self.is_cross_attention,  # causal mask for self-attention
+        return_attn_probs=False,  # attention probabilities are not needed
+    )
+
+    # Apply the head mask if one was provided
+    if head_mask is not None:
+        attn_output *= head_mask
+
+    attn_output = attn_output.permute(0, 2, 1, 3)
+
+    return attn_output, None  # We don't have explicit attn_weights in Flash attention
diff --git a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py b/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
new file mode 100644
index 0000000000..ed11c55234
--- /dev/null
+++ b/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
@@ -0,0 +1,101 @@
+"""
+Flash Attention monkey patch for Falcon
+
+copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
+"""
+
+from typing import Optional, Tuple
+
+import torch
+import transformers
+from flash_attn import flash_attn_func
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    alibi: Optional[torch.Tensor],
+    attention_mask: torch.Tensor,  # pylint: disable=unused-argument
+    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    head_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+    use_cache: bool = False,
+    output_attentions: bool = False,  # pylint: disable=unused-argument
+):
+    fused_qkv = self.query_key_value(
+        hidden_states
+    )  # [batch_size, seq_length, 3 x hidden_size]
+    num_kv_heads = (
+        self.num_heads if self.new_decoder_architecture else self.num_kv_heads
+    )
+    # 3 x [batch_size, seq_length, num_heads, head_dim]
+    (
+        query_layer,
+        key_layer,
+        value_layer,
+    ) = self._split_heads(  # pylint: disable=protected-access
+        fused_qkv
+    )
+
+    batch_size, query_length, _, _ = query_layer.shape
+
+    query_layer = query_layer.transpose(1, 2).reshape(
+        batch_size * self.num_heads, query_length, self.head_dim
+    )
+    key_layer = key_layer.transpose(1, 2).reshape(
+        batch_size * num_kv_heads,
+        query_length,
+        self.head_dim,
+    )
+    value_layer = value_layer.transpose(1, 2).reshape(
+        batch_size * num_kv_heads, query_length, self.head_dim
+    )
+
+    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
+    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+
+    if layer_past is not None:
+        past_key, past_value = layer_past
+        # concatenate along seq_length dimension:
+        #  - key: [batch_size * self.num_heads, kv_length, head_dim]
+        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+        key_layer = torch.cat((past_key, key_layer), dim=1)
+        value_layer = torch.cat((past_value, value_layer), dim=1)
+
+    # unused
+    # _, kv_length, _ = key_layer.shape
+    if use_cache:
+        present = (key_layer, value_layer)
+    else:
+        present = None
+    # unused
+    # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+    query_layer_ = (
+        query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+        .transpose(1, 2)
+        .to(torch.bfloat16)
+    )
+    key_layer_ = (
+        key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+        .transpose(1, 2)
+        .to(torch.bfloat16)
+    )
+    value_layer_ = (
+        value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+        .transpose(1, 2)
+        .to(torch.bfloat16)
+    )
+
+    if alibi is not None:
+        raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
+
+    # below output will have shape (batch_size, seqlen, nheads, headdim)
+    attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
+    attn_output = attn_output.reshape(
+        batch_size, query_length, self.num_heads * self.head_dim
+    )
+    output_tensor = self.dense(attn_output)
+    return output_tensor, present
+
+
+def replace_falcon_attn_with_flash_attn():
+    transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 9582205f9e..36607f2a2d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -100,10 +100,31 @@ def load_model(
     base_model = cfg.base_model
     base_model_config = cfg.base_model_config
     model_type = cfg.model_type
+    model_config = load_model_config(cfg)
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
+    if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
+        if cfg.flash_attention:
+            from axolotl.monkeypatch.btlm_attn_hijack_flash import (
+                replace_btlm_attn_with_flash_attn,
+            )
+
+            replace_btlm_attn_with_flash_attn(cfg.base_model)
+
+    if hasattr(model_config, "model_type") and model_config.model_type in [
+        "falcon",
+        "RefinedWebModel",
+        "RefinedWeb",
+    ]:
+        if cfg.flash_attention:
+            from axolotl.monkeypatch.falcon_attn_hijack_flash import (
+                replace_falcon_attn_with_flash_attn,
+            )
+
+            replace_falcon_attn_with_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -338,6 +359,9 @@ def load_model(
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(torch.float32)
+            if model_config.model_type == "btlm":
+                # don't upcast lm_head for btlm
+                continue
             if "lm_head" in name or "embed_tokens" in name:
                 if hasattr(module, "weight"):
                     module.to(torch.float32)
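
Usage note (illustrative, not part of the diff above): within axolotl these hijacks are applied automatically by load_model() when flash_attention: true is set and the detected model_type is btlm, falcon, RefinedWebModel, or RefinedWeb, as in examples/cerebras/btlm-ft.yml. The sketch below exercises the BTLM patch standalone under stated assumptions: flash-attn installed, a bf16-capable CUDA GPU, and an illustrative prompt string. The Falcon patch is applied the same way by calling replace_falcon_attn_with_flash_attn() before loading a Falcon model.

# Minimal sketch, assuming flash-attn and a bf16-capable CUDA GPU are available.
# Nothing here is part of the patch itself; the prompt text is illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.monkeypatch.btlm_attn_hijack_flash import (
    replace_btlm_attn_with_flash_attn,
)

# swap BTLMAttention._attn for the flash_attn_func wrapper before the model
# used for training or inference is instantiated
replace_btlm_attn_with_flash_attn("cerebras/btlm-3b-8k-base")

tokenizer = AutoTokenizer.from_pretrained(
    "cerebras/btlm-3b-8k-base", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "cerebras/btlm-3b-8k-base",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

inputs = tokenizer("flash attention sanity check", return_tensors="pt").to("cuda")
with torch.no_grad():
    # this forward pass is now routed through flash_attn_func
    logits = model(**inputs).logits
print(logits.shape)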