From ad47dbf9246bfa4c12d7212324950465bdfe08e1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 30 Mar 2024 13:56:34 -0400
Subject: [PATCH] more fixes for dbrx loras

---
 examples/dbrx/README.md     |  3 ++-
 src/axolotl/utils/models.py | 21 +++++++++++++++++++--
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/examples/dbrx/README.md b/examples/dbrx/README.md
index 2efeee3f19..c0f69d328c 100644
--- a/examples/dbrx/README.md
+++ b/examples/dbrx/README.md
@@ -7,9 +7,10 @@ where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer
 is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers results in the trainer hanging.
 
 
-We recommend using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as your base model for the time being.
 
 ### FSDP
+We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
+
 The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
 
 - 16-bit LoRA w/ FSDP
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 444742fe8e..877b289b5c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -750,6 +750,13 @@ def load_model(
         # TODO revaldate this conditional
         model.to(f"cuda:{cfg.local_rank}")
 
+    if (
+        cfg.fsdp
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and cfg.local_rank != 0
+    ):
+        setup_quantized_peft_meta_for_training(model)
+
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
@@ -901,7 +908,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+    if (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:
@@ -921,7 +933,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     if rank == 0:
         model.print_trainable_parameters()
-    elif cfg.fsdp and cfg.adapter == "qlora":
+    elif (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
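
Note: the new guards key off FSDP's CPU-RAM-efficient loading, where only rank 0 materializes real weights and the other ranks hold meta-device placeholders until FSDP syncs the shards from rank 0. Below is a minimal, self-contained sketch of that rank-gating pattern; `move_params_to_meta` and the boolean flags are illustrative stand-ins for `cfg.fsdp`, `cfg.fsdp_config.fsdp_cpu_ram_efficient_loading`, and the `setup_quantized_meta_for_peft()` / `setup_quantized_peft_meta_for_training()` helpers called in the patch, not their actual implementations.

```python
import os

import torch.nn as nn


def move_params_to_meta(model: nn.Module) -> None:
    """Illustrative stand-in: keep only shape/dtype metadata on non-zero ranks."""
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            # Replacing the parameter with a meta-device copy frees the CPU copy;
            # FSDP later re-materializes it from rank 0's weights.
            meta_param = nn.Parameter(param.to("meta"), requires_grad=param.requires_grad)
            setattr(module, name, meta_param)


def main() -> None:
    rank = int(os.environ.get("LOCAL_RANK", 0))

    # Placeholder for the real (quantized) base model loaded by load_model().
    model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))

    # Mirrors the patched condition: only skip holding real weights when FSDP is
    # enabled, cpu_ram_efficient_loading is requested, and we are not rank 0.
    fsdp_enabled = True               # stands in for cfg.fsdp
    cpu_ram_efficient_loading = True  # stands in for cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
    if fsdp_enabled and cpu_ram_efficient_loading and rank != 0:
        move_params_to_meta(model)

    print(f"rank {rank}: params on {next(model.parameters()).device}")


if __name__ == "__main__":
    main()
```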