more fixes for dbrx loras
winglian committed Mar 30, 2024 · parent 086e48b · commit ad47dbf
Showing 2 changed files with 21 additions and 3 deletions.
examples/dbrx/README.md (2 additions, 1 deletion)

```diff
@@ -7,9 +7,10 @@ where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer
 is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
 results in the trainer hanging.
 
-We recommend using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as your base model for the time being.
 
 ### FSDP
+We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
+
 The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
 
 - 16-bit LoRA w/ FSDP
```
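The README's caveat is easier to see in code. Below is an illustrative sketch (class names and dimensions are invented for illustration, not the actual DBRX modeling code): PEFT injects LoRA by wrapping `nn.Linear` submodules it finds by name, so expert weights fused into a single flat `nn.Parameter` give the adapter nothing to wrap.

```python
import torch
import torch.nn as nn

class FusedExperts(nn.Module):
    """Sketch of the fused layout: all experts' `w1` weights live in one
    flat nn.Parameter, so no nn.Linear submodule exists for PEFT's LoRA
    injection to wrap (the situation behind the hangs noted above)."""

    def __init__(self, num_experts: int, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts * d_ff, d_model))

class UnfusedExperts(nn.Module):
    """Sketch of a converted layout: each expert exposes a real nn.Linear,
    which appears in named_modules() and can be targeted by LoRA."""

    def __init__(self, num_experts: int, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.ModuleList(
            nn.Linear(d_model, d_ff, bias=False) for _ in range(num_experts)
        )

# Only the unfused variant has Linear submodules for an adapter to find:
fused, unfused = FusedExperts(4, 8, 16), UnfusedExperts(4, 8, 16)
print(any(isinstance(m, nn.Linear) for m in fused.modules()))    # False
print(any(isinstance(m, nn.Linear) for m in unfused.modules()))  # True
```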
src/axolotl/utils/models.py (19 additions, 2 deletions)

```diff
@@ -750,6 +750,13 @@ def load_model(
         # TODO revalidate this conditional
         model.to(f"cuda:{cfg.local_rank}")
 
+    if (
+        cfg.fsdp
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and cfg.local_rank != 0
+    ):
+        setup_quantized_peft_meta_for_training(model)
+
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
```
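For context on the `load_model` hunk above: with FSDP's CPU-RAM-efficient loading, only local rank 0 materializes real weights, while the other ranks hold meta-device placeholders until FSDP shards and syncs parameters. Here is a minimal restatement of the new guard, runnable in isolation (`needs_meta_weight_setup` and the `SimpleNamespace` config are stand-ins, not axolotl APIs):

```python
from types import SimpleNamespace

def needs_meta_weight_setup(cfg) -> bool:
    # Mirrors the guard added in load_model above: under FSDP
    # CPU-RAM-efficient loading, only local rank 0 holds real weights,
    # so every other rank must have its quantized PEFT state prepared
    # for meta-device training before the trainer starts.
    return bool(
        cfg.fsdp
        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
        and cfg.local_rank != 0
    )

cfg = SimpleNamespace(
    fsdp=True,
    fsdp_config=SimpleNamespace(fsdp_cpu_ram_efficient_loading=True),
    local_rank=1,
)
print(needs_meta_weight_setup(cfg))  # True: this rank gets the meta setup
```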
```diff
@@ -901,7 +908,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+    if (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:
```
```diff
@@ -921,7 +933,12 @@
 
     if rank == 0:
         model.print_trainable_parameters()
-    elif cfg.fsdp and cfg.adapter == "qlora":
+    elif (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
```
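Taken together, the two `load_lora` hunks bracket PEFT wrapping on non-zero ranks: `setup_quantized_meta_for_peft` runs before the adapter attaches and `setup_quantized_peft_meta_for_training` runs after. Note the gate also widened from `cfg.adapter == "qlora"` to any truthy `cfg.adapter`, so 16-bit LoRA under FSDP now takes the same path. A condensed, runnable sketch of that ordering (the helper bodies below are print stubs, not axolotl's implementations):

```python
from types import SimpleNamespace

# Print stubs standing in for the axolotl helpers referenced in the diff;
# the real implementations live in src/axolotl/utils/models.py.
def setup_quantized_meta_for_peft(model):
    print("pre-PEFT: prepare quantized weights for meta-device wrapping")

def setup_quantized_peft_meta_for_training(model):
    print("post-PEFT: restore quantized state for training")

def load_lora_flow(model, cfg, rank: int):
    # Condensed restatement of the control flow in the two hunks above.
    gated = bool(
        cfg.fsdp
        and cfg.adapter
        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
        and rank != 0
    )
    if gated:
        setup_quantized_meta_for_peft(model)  # before the PEFT adapter attaches
    # ... PEFT wrapping (get_peft_model) happens here in the real code ...
    if rank == 0:
        print("rank 0 reports trainable parameters")
    elif gated:
        setup_quantized_peft_meta_for_training(model)  # after wrapping
    return model

cfg = SimpleNamespace(
    fsdp=True,
    adapter="lora",
    fsdp_config=SimpleNamespace(fsdp_cpu_ram_efficient_loading=True),
)
load_lora_flow(model=None, cfg=cfg, rank=1)  # non-zero rank hits both calls
```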
