qwen2_moe support w multipack (#1455)
winglian authored Mar 29, 2024
1 parent 4a92a3b commit 6086be8
Showing 6 changed files with 147 additions and 4 deletions.
10 changes: 10 additions & 0 deletions examples/qwen/README.md
@@ -0,0 +1,10 @@
# Qwen

TODO

# Qwen2 MoE

✅ multipack
✅ qwen2_moe 4-bit QLoRA
✅ qwen2_moe 16-bit LoRA
❓ qwen2_moe 8-bit LoRA
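
Either config can be launched with axolotl's standard CLI, e.g. `accelerate launch -m axolotl.cli.train examples/qwen/qwen2-moe-qlora.yaml` (launch command assumed from the repo's usual entry point, not part of this commit).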
64 changes: 64 additions & 0 deletions examples/qwen/qwen2-moe-lora.yaml
@@ -0,0 +1,64 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./out

sequence_len: 1024 # supports up to 32k
sample_packing: false
pad_to_sequence_len: false

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
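
This config (and the QLoRA variant below) ships with `sample_packing: false`; to exercise the multipack support this commit adds, set `sample_packing: true`. The packing path goes through the flash-attention unpad hook patched in `src/axolotl/monkeypatch/multipack.py` below, so keep `flash_attention: true`; pairing packing with `pad_to_sequence_len: true` is the usual recommendation to keep batch shapes stable.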
64 changes: 64 additions & 0 deletions examples/qwen/qwen2-moe-qlora.yaml
@@ -0,0 +1,64 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./out

sequence_len: 1024 # supports up to 32k
sample_packing: false
pad_to_sequence_len: false

adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
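
After a QLoRA run finishes, the adapter saved to `output_dir` is typically merged into a full-precision copy of the base model for inference. A minimal sketch against the pinned peft 0.10.0 API; the model id and `./out` path come from the config above, the merged path is an assumption:

```python
# Minimal sketch (not part of this commit): merge the trained QLoRA adapter
# from output_dir into a bf16 copy of the base model and save the result.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
merged = PeftModel.from_pretrained(base, "./out").merge_and_unload()
merged.save_pretrained("./out/merged")  # hypothetical output path
```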
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.9.0
-transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78
+peft==0.10.0
+transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
 tokenizers==0.15.0
 bitsandbytes==0.43.0
 accelerate==0.28.0
@@ -39,4 +39,4 @@ s3fs
 gcsfs
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
+trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
5 changes: 5 additions & 0 deletions src/axolotl/monkeypatch/multipack.py
@@ -12,6 +12,7 @@
 SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "mixtral",
     "qwen2",
+    "qwen2_moe",
     "falcon",
     "phi",
     "gemma",
@@ -31,6 +32,10 @@ def patch_for_multipack(model_type, model_name=None):
         transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
+    elif model_type == "qwen2_moe":
+        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
     elif model_type == "falcon":
         transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
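
For context: `get_unpad_data` here is axolotl's packing-aware replacement (imported from `axolotl.monkeypatch.utils`) for the `_get_unpad_data` helper in each transformers model file. With multipack, the attention mask labels each packed sample with a distinct integer id (1, 2, 3, ...), and the replacement derives per-sample lengths from those ids so flash-attention's varlen kernels never attend across sample boundaries. A minimal sketch of the idea, not axolotl's exact implementation:

```python
import torch
import torch.nn.functional as F

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # attention_mask: 0 for padding, 1, 2, 3, ... marking each packed sample
    per_row = [
        torch.unique_consecutive(row[row > 0], return_counts=True)[1]
        for row in attention_mask
    ]
    seqlens_in_batch = torch.cat(per_row).to(torch.int32)
    # flattened indices of all non-pad tokens, as the stock helper returns
    indices = torch.nonzero(attention_mask.flatten() > 0, as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative lengths in the format flash-attn's varlen kernels expect
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

# e.g. two samples of lengths 3 and 2 packed into one row, plus one pad token
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
print(get_unpad_data_sketch(mask))  # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
```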
2 changes: 1 addition & 1 deletion src/axolotl/utils/models.py
@@ -456,7 +456,7 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
-        if cfg.model_config_type == "jamba" and not cfg.deepspeed:
+        if not cfg.deepspeed:
             # for some reason, this causes the loss to be off by an order of magnitude
             # but deepspeed needs this still in bfloat16
             bnb_config["bnb_4bit_quant_storage"] = torch.float32
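
The kwargs in this hunk feed transformers' `BitsAndBytesConfig`. A hedged sketch of the resulting quantization config; only the two keys visible in the hunk are taken from the commit, the rest is illustrative:

```python
# Sketch (illustrative, not this repo's code): the quant-storage dtype is kept
# at float32 outside deepspeed because bf16 storage was observed to skew the
# loss by an order of magnitude; deepspeed still needs bfloat16 storage.
import torch
from transformers import BitsAndBytesConfig

use_deepspeed = False  # stand-in for cfg.deepspeed
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_quant_storage=torch.bfloat16 if use_deepspeed else torch.float32,
)
```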
