flash_attention + sample packing for stablelm 3b (#671)
* stablelm epoch fa patch

* is causal for fa

* working stablelm fa w packing

* chore: pre-commit linting
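
The "is causal for fa" change refers to enabling causal masking in the flash-attention kernel for these decoder-only models. Below is a minimal sketch of such a call, assuming flash-attn 2.x's flash_attn_func signature and purely illustrative tensor sizes; it is not the forward pass from this commit.

# Hedged sketch only: demonstrates a causal flash-attention call, not the patch's code.
import torch
from flash_attn.flash_attn_interface import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 32, 80  # illustrative sizes only
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True applies the autoregressive mask inside the fused kernel.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)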
winglian authored Oct 5, 2023
1 parent 5b4e92b commit b838a67
Showing 3 changed files with 429 additions and 1 deletion.
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py (3 additions, 1 deletion)
@@ -7,6 +7,7 @@
 from typing import Optional, Tuple

 import torch
+from accelerate import init_empty_weights
 from flash_attn.flash_attn_interface import flash_attn_func
 from transformers import AutoConfig, AutoModelForCausalLM

@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
     # this is a wonky hack to get the remotely loaded module
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_btlm to be available
-    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
     module_name = model_config.__class__.__module__.replace(
         ".configuration_btlm", ".modeling_btlm"
     )
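
In the hunk above, the throwaway AutoModelForCausalLM.from_pretrained call (needed only so the remote modeling_btlm module gets loaded) is wrapped in accelerate's init_empty_weights context manager, so parameters are created on the meta device rather than materialized in memory. The following is a minimal sketch of the same pattern; the final importlib lookup is an assumption about what the collapsed code after this hunk does, not something visible in the diff.

# Sketch of the init_empty_weights pattern from the hunk above.
import importlib

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "cerebras/btlm-3b-8k-base"  # default from the patched function
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# Instantiating under init_empty_weights() still executes the remote modeling
# code (so the module becomes importable) but creates parameters on the
# "meta" device, avoiding a full in-memory copy of the 3B weights.
with init_empty_weights():
    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

module_name = model_config.__class__.__module__.replace(
    ".configuration_btlm", ".modeling_btlm"
)
# Assumed follow-up (collapsed in the diff): grab the loaded module to patch it.
modeling_btlm = importlib.import_module(module_name)
print(modeling_btlm.__name__)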
