From 3e3229e2d99bb509784ac72e6589f8a8e406247f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 30 Nov 2023 12:45:50 -0500
Subject: [PATCH] fix for qwen w lora (#906)

---
 src/axolotl/utils/models.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0d8c812f3b..acc6f41fa6 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -412,15 +412,22 @@ def load_model(
                 module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
+    skip_prepare_model_for_kbit_training = False
+
+    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
+        # Qwen doesn't play nicely with LoRA if this is enabled
+        skip_prepare_model_for_kbit_training = True
+
     if (cfg.adapter == "lora" and load_in_8bit) or (
         cfg.adapter == "qlora" and cfg.load_in_4bit
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        model = prepare_model_for_kbit_training(
-            model, use_gradient_checkpointing=cfg.gradient_checkpointing
-        )
+        if not skip_prepare_model_for_kbit_training:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=cfg.gradient_checkpointing
+            )
         needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
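
For readers applying the hunk mentally, the following is a minimal sketch (not part of the patch) of the resulting control flow inside load_model after this change. The helper name maybe_prepare_for_kbit is hypothetical; cfg, load_in_8bit, and LOG stand in for the objects already in scope in src/axolotl/utils/models.py, and prepare_model_for_kbit_training comes from peft.

    # Hypothetical standalone helper; in the patch this logic lives inline in load_model().
    from peft import prepare_model_for_kbit_training

    def maybe_prepare_for_kbit(model, cfg, load_in_8bit):
        skip_prepare_model_for_kbit_training = False

        if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
            # Qwen doesn't play nicely with LoRA if this is enabled
            skip_prepare_model_for_kbit_training = True

        if (cfg.adapter == "lora" and load_in_8bit) or (
            cfg.adapter == "qlora" and cfg.load_in_4bit
        ):
            if cfg.gradient_checkpointing:
                model.gradient_checkpointing_enable()
            # For Qwen + LoRA the kbit preparation step is skipped entirely;
            # gradient checkpointing above is still applied.
            if not skip_prepare_model_for_kbit_training:
                model = prepare_model_for_kbit_training(
                    model, use_gradient_checkpointing=cfg.gradient_checkpointing
                )
        return model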