Commit 4fe469a: fix support for 8bit loras too for dbrx

winglian committed Mar 30, 2024
1 parent c4ad7a8

Showing 5 changed files with 98 additions and 14 deletions.
3 changes: 2 additions & 1 deletion examples/dbrx/16bit-lora.yaml
@@ -51,7 +51,7 @@ bf16: auto
 fp16:
 tf32: false

-gradient_checkpointing: true
+gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
 gradient_checkpointing_kwargs:
   use_reentrant: false
 early_stopping_patience:
@@ -78,3 +78,4 @@ fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: DbrxBlock
   fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_activation_checkpointing: true
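The change above trades the HF-level `gradient_checkpointing` flag for FSDP's `fsdp_activation_checkpointing`, so recomputation is applied by the FSDP wrapper around each `DbrxBlock` rather than by the model itself; enabling both would checkpoint the same blocks twice. As a rough illustration only (not the axolotl/accelerate code path), FSDP-side activation checkpointing is typically wired up with PyTorch's `apply_activation_checkpointing`:

```python
import functools

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def enable_fsdp_activation_checkpointing(model: nn.Module, block_cls: type) -> None:
    """Wrap every transformer block so its activations are recomputed in backward.

    Sketch of what fsdp_activation_checkpointing: true asks the trainer to do;
    with this in place, HF-level gradient_checkpointing should stay off.
    """
    wrapper = functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,  # matches use_reentrant: false
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        check_fn=lambda module: isinstance(module, block_cls),
    )
```

Under the accelerate/axolotl stack this is driven by the `fsdp_activation_checkpointing` config key rather than called directly.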
81 changes: 81 additions & 0 deletions examples/dbrx/8bit-lora.yaml
@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
  - q_proj # attn
  - k_proj # attn
  - v_proj # attn
  - out_proj # attn
  - layer # router
  # - w1
  # - w2
  # - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_activation_checkpointing: true
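For readers mapping this config onto the underlying libraries, the 8-bit quantization plus LoRA adapter corresponds roughly to the following transformers/peft calls (a hedged sketch, not axolotl's loader; the target-module list simply mirrors the YAML above, and device placement/FSDP wrapping are omitted):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 8-bit base weights, as in load_in_8bit: true
quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "LnL-AI/dbrx-base-converted-v2",
    trust_remote_code=True,            # trust_remote_code: true
    quantization_config=quant_config,
)

# LoRA settings copied from the YAML above
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "layer"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```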
2 changes: 1 addition & 1 deletion examples/dbrx/README.md
@@ -16,7 +16,7 @@ The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers
 - 16-bit LoRA w/ FSDP
   - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
   - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
-- 8-bit LoRA w/ FSDP - WIP, need to handle loading 8-bit quantized weights
+- 8-bit LoRA w/ FSDP
 - ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
 - ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)

3 changes: 3 additions & 0 deletions src/axolotl/train.py
@@ -9,6 +9,7 @@

 import torch
 import transformers.modelcard
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
+    # we wait until the last possible moment to set up the Accelerator
+    Accelerator()
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True

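The `Accelerator()` call is inserted just before `load_model` so that the distributed/FSDP state it configures already exists when the quantized weights are materialized. A rough illustration of that ordering, assuming a generic loader in place of axolotl's (the `load_model` call at the end is hypothetical and shown only as a placeholder):

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState

# Instantiating Accelerator initializes the process group, device placement,
# and any FSDP/DeepSpeed plugin configured via the launch environment.
accelerator = Accelerator()

# The state is a singleton, so code that loads the model afterwards can query
# rank/device information without the Accelerator being passed in explicitly.
state = AcceleratorState()
print(f"rank={state.process_index} local_rank={state.local_process_index} device={state.device}")

# Only now load the quantized base model, so sharded / RAM-efficient loading
# sees the distributed state set up above.
# model, peft_config = load_model(cfg, tokenizer)
```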
23 changes: 11 additions & 12 deletions src/axolotl/utils/models.py
@@ -468,6 +468,13 @@ def load_model(
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
         )
+    elif cfg.adapter == "lora" and cfg.load_in_8bit:
+        bnb_config = {
+            "load_in_8bit": True,
+        }
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            **bnb_config,
+        )

     if cfg.load_in_8bit and cfg.adapter is not None:
         model_kwargs["load_in_8bit"] = True
@@ -517,10 +524,8 @@ def load_model(
     try:
         skip_move_to_device = False
         if (
-            cfg.fsdp
-            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-            and not qlora_fsdp
-        ):
+            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        ) and not qlora_fsdp:
             model = load_sharded_model(
                 base_model,
                 model_config,
@@ -704,7 +709,8 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True

-    if qlora_fsdp:
+    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
+        # make sure everything is in the same dtype
         skip_prepare_model_for_kbit_training = True

     if cfg.adapter in ["lora", "qlora"]:
@@ -750,13 +756,6 @@ def load_model(
         # TODO revalidate this conditional
         model.to(f"cuda:{cfg.local_rank}")

-    if (
-        cfg.fsdp
-        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and cfg.local_rank != 0
-    ):
-        setup_quantized_peft_meta_for_training(model)
-
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
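For context on the hunks above: `prepare_model_for_kbit_training` is the peft helper that normally freezes the quantized base weights and upcasts the remaining non-quantized parameters (e.g. layer norms) to fp32; the commit skips it whenever FSDP's CPU-RAM-efficient loading is active so that every parameter stays in a single dtype for sharding. A hedged sketch of that decision outside the axolotl code path (the helper name `maybe_prepare_for_kbit` is illustrative only):

```python
from peft import prepare_model_for_kbit_training


def maybe_prepare_for_kbit(model, load_in_8bit: bool, fsdp_ram_efficient: bool):
    """Mirror the commit's logic: skip peft's k-bit prep when FSDP's
    cpu_ram_efficient_loading is enabled, keeping all params in one dtype."""
    if load_in_8bit and not fsdp_ram_efficient:
        # freezes base weights and upcasts non-quantized params (layer norms,
        # lm_head) to fp32 for training stability
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=False
        )
    return model
```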
