Skip to content

Commit

Permalink
btlm and falcon monkey patches for flash attn (#566)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Sep 17, 2023
1 parent 5610072 commit 9dd1ea9
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 0 deletions.
90 changes: 90 additions & 0 deletions examples/cerebras/btlm-ft.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
base_model: cerebras/btlm-3b-8k-base
base_model_config: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer
trust_remote_code: true
tokenizer_use_fast: true
tokenizer_legacy: true

load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_prepared_run
val_set_size: 0.01

adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false
sample_packing_eff_est:
sample_packing_seq_len_multiplier:
total_num_tokens:

lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

output_dir: btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0

torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000085
train_on_inputs: true
group_by_length: false
bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1

xformers_attention:
flash_attention: true
sdp_attention:
flash_optimum:

gptq_groupsize:
gptq_model_v1:

warmup_steps: 32
eval_steps:
save_steps:
save_total_limit:

debug:
deepspeed:
weight_decay: 0.1
special_tokens:
pad_token: "<|endoftext|>"
fsdp:
# - full_shard
# - auto_wrap
fsdp_config:
# fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
64 changes: 64 additions & 0 deletions src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Flash attention monkey patch for cerebras btlm model
"""

import importlib
import logging
from typing import Optional, Tuple

import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM

LOG = logging.getLogger("axolotl")


def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)
modeling_btlm = importlib.import_module(module_name)
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
flashattn_attn
)


def flashattn_attn(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
head_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
softmax_scale = (
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
)

query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)

# Perform Flash attention
attn_output = flash_attn_func(
query,
key,
value,
dropout_p=0.0, # Assuming you have this attribute
softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
causal=not self.is_cross_attention, # Assuming you have this attribute
return_attn_probs=False, # Set this based on your needs
)

# Optional: Apply head mask if it's not None
if head_mask is not None:
attn_output *= head_mask

attn_output = attn_output.permute(0, 2, 1, 3)

return attn_output, None # We don't have explicit attn_weights in Flash attention
101 changes: 101 additions & 0 deletions src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Flash Attention monkey patch for Falcon
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
"""

from typing import Optional, Tuple

import torch
import transformers
from flash_attn import flash_attn_func


def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor, # pylint: disable=unused-argument
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
use_cache: bool = False,
output_attentions: bool = False, # pylint: disable=unused-argument
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = (
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(
query_layer,
key_layer,
value_layer,
) = self._split_heads( # pylint: disable=protected-access
fused_qkv
)

batch_size, query_length, _, _ = query_layer.shape

query_layer = query_layer.transpose(1, 2).reshape(
batch_size * self.num_heads, query_length, self.head_dim
)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads, query_length, self.head_dim
)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

# unused
# _, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
# unused
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = (
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
key_layer_ = (
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
value_layer_ = (
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)

if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

# below output will have shape (batch_size, seqlen, nheads, headdim)
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
attn_output = attn_output.reshape(
batch_size, query_length, self.num_heads * self.head_dim
)
output_tensor = self.dense(attn_output)
return output_tensor, present


def replace_falcon_attn_with_flash_attn():
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
24 changes: 24 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,31 @@ def load_model(
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
model_config = load_model_config(cfg)

# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit

if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
)

replace_btlm_attn_with_flash_attn(cfg.base_model)

if hasattr(model_config, "model_type") and model_config.model_type in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]:
if cfg.flash_attention:
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
replace_falcon_attn_with_flash_attn,
)

replace_falcon_attn_with_flash_attn()

if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
Expand Down Expand Up @@ -338,6 +359,9 @@ def load_model(
for name, module in model.named_modules():
if "norm" in name:
module.to(torch.float32)
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch.float32)
Expand Down

0 comments on commit 9dd1ea9

Please sign in to comment.