btlm and falcon monkey patches for flash attn (#566)
Showing 4 changed files with 279 additions and 0 deletions.
@@ -0,0 +1,90 @@
base_model: cerebras/btlm-3b-8k-base
base_model_config: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer
trust_remote_code: true
tokenizer_use_fast: true
tokenizer_legacy: true

load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_prepared_run
val_set_size: 0.01

adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false
sample_packing_eff_est:
sample_packing_seq_len_multiplier:
total_num_tokens:

lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

output_dir: btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0

torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000085
train_on_inputs: true
group_by_length: false
bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1

xformers_attention:
flash_attention: true
sdp_attention:
flash_optimum:

gptq_groupsize:
gptq_model_v1:

warmup_steps: 32
eval_steps:
save_steps:
save_total_limit:

debug:
deepspeed:
weight_decay: 0.1
special_tokens:
  pad_token: "<|endoftext|>"
fsdp:
# - full_shard
# - auto_wrap
fsdp_config:
# fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
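The optimizer section of this config (optimizer: adamw_torch with adam_beta2: 0.95, adam_eps: 0.000000001, weight_decay: 0.1, learning_rate: 0.000085) corresponds roughly to the torch AdamW shown in the sketch below. This is only a minimal illustration of the hyperparameters, not axolotl's trainer code, and it assumes beta1 stays at the torch default of 0.9; the Linear module stands in for the BTLM model the real run would optimize.

import torch

# stand-in module; the real run optimizes the BTLM model loaded from this config
model = torch.nn.Linear(8, 8)

# minimal sketch of the optimizer the config above asks for
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.000085,        # learning_rate
    betas=(0.9, 0.95),  # adam_beta2: 0.95; beta1 assumed at the torch default 0.9
    eps=1e-9,           # adam_eps: 0.000000001
    weight_decay=0.1,   # weight_decay
)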
@@ -0,0 +1,64 @@
"""
Flash attention monkey patch for cerebras btlm model
"""

import importlib
import logging
from typing import Optional, Tuple

import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM

LOG = logging.getLogger("axolotl")


def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
    # this is a wonky hack to get the remotely loaded module
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # we need to load the model here in order for modeling_btlm to be available
    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    module_name = model_config.__class__.__module__.replace(
        ".configuration_btlm", ".modeling_btlm"
    )
    modeling_btlm = importlib.import_module(module_name)
    modeling_btlm.BTLMAttention._attn = (  # pylint: disable=protected-access
        flashattn_attn
    )


def flashattn_attn(
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    value: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    head_mask: Optional[torch.Tensor] = None,
    position_bias: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    softmax_scale = (
        1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
    )

    # flash_attn_func expects (batch, seq_len, num_heads, head_dim)
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)

    # Perform Flash attention
    attn_output = flash_attn_func(
        query,
        key,
        value,
        dropout_p=0.0,  # attention dropout is hardcoded off here
        softmax_scale=softmax_scale,  # computed above from attn_scale_power
        causal=not self.is_cross_attention,  # causal masking for self-attention
        return_attn_probs=False,  # attention probabilities are not materialized
    )

    # Optional: apply head mask if it's not None
    if head_mask is not None:
        attn_output *= head_mask

    # back to (batch, num_heads, seq_len, head_dim), as the caller expects
    attn_output = attn_output.permute(0, 2, 1, 3)

    return attn_output, None  # flash attention does not return explicit attn_weights
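A minimal usage sketch follows: apply the patch first, then load the model you actually train, so the subsequent from_pretrained call picks up the patched BTLMAttention._attn. The import path is an assumption (the diff does not show where this file lives), and flash-attn kernels require fp16/bf16 tensors on a CUDA device.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# assumed import path for the module above; adjust to wherever the file is placed
from axolotl.monkeypatch.btlm_attn_hijack_flash import replace_btlm_attn_with_flash_attn

# swap BTLMAttention._attn for the flash-attn implementation, then load the model
replace_btlm_attn_with_flash_attn("cerebras/btlm-3b-8k-base")

tokenizer = AutoTokenizer.from_pretrained("cerebras/btlm-3b-8k-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "cerebras/btlm-3b-8k-base",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # flash-attn kernels need fp16/bf16 inputs
).to("cuda")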
@@ -0,0 +1,101 @@
"""
Flash Attention monkey patch for Falcon
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
"""

from typing import Optional, Tuple

import torch
import transformers
from flash_attn import flash_attn_func


def forward(
    self,
    hidden_states: torch.Tensor,
    alibi: Optional[torch.Tensor],
    attention_mask: torch.Tensor,  # pylint: disable=unused-argument
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    use_cache: bool = False,
    output_attentions: bool = False,  # pylint: disable=unused-argument
):
    fused_qkv = self.query_key_value(
        hidden_states
    )  # [batch_size, seq_length, 3 x hidden_size]
    num_kv_heads = (
        self.num_heads if self.new_decoder_architecture else self.num_kv_heads
    )
    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (
        query_layer,
        key_layer,
        value_layer,
    ) = self._split_heads(  # pylint: disable=protected-access
        fused_qkv
    )

    batch_size, query_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads, query_length, self.head_dim
    )
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads,
        query_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads, query_length, self.head_dim
    )

    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        # - key: [batch_size * self.num_heads, kv_length, head_dim]
        # - value: [batch_size * self.num_heads, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=1)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    # unused
    # _, kv_length, _ = key_layer.shape
    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None
    # unused
    # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
    query_layer_ = (
        query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        .transpose(1, 2)
        .to(torch.bfloat16)
    )
    key_layer_ = (
        key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        .transpose(1, 2)
        .to(torch.bfloat16)
    )
    value_layer_ = (
        value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        .transpose(1, 2)
        .to(torch.bfloat16)
    )

    if alibi is not None:
        raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

    # below output will have shape (batch_size, seqlen, nheads, headdim)
    attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
    attn_output = attn_output.reshape(
        batch_size, query_length, self.num_heads * self.head_dim
    )
    output_tensor = self.dense(attn_output)
    return output_tensor, present


def replace_falcon_attn_with_flash_attn():
    transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
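Applying the Falcon patch looks much the same: since replace_falcon_attn_with_flash_attn rebinds FalconAttention.forward at the class level, calling it before loading keeps things simple. A hedged sketch follows, assuming the module is importable under the path shown (the diff does not give the file location) and using a rotary-embedding Falcon checkpoint, since the patched forward raises on ALiBi.

import torch
from transformers import AutoModelForCausalLM

# assumed import path for the module above; adjust to wherever the file is placed
from axolotl.monkeypatch.falcon_attn_hijack_flash import replace_falcon_attn_with_flash_attn

replace_falcon_attn_with_flash_attn()  # rebinds FalconAttention.forward globally

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",          # rotary Falcon; the patched forward raises if ALiBi is used
    torch_dtype=torch.bfloat16,  # q/k/v are cast to bfloat16 before flash_attn_func
    device_map="auto",
)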