From 3124670306228f6d144fb9940022828c632f4ac7 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 18 Sep 2023 18:49:44 -0400
Subject: [PATCH] prepare ia3 for 8bit

---
 .pylintrc                   | 1 +
 src/axolotl/utils/models.py | 6 ++++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index ed973d2859..9f0e453d5d 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
     too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
     too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+    too-many-boolean-expressions,

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 36607f2a2d..76a6fb097f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -367,8 +367,10 @@ def load_model(
             module.to(torch.float32)

     needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
+    if (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
+        or (cfg.adapter == "ia3" and load_in_8bit)
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
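
For context, the new `cfg.adapter == "ia3" and load_in_8bit` branch routes IA3 runs through the same prepare_model_for_kbit_training call that LoRA (8-bit) and QLoRA (4-bit) already take. Below is a minimal sketch of that path using PEFT directly, outside axolotl; the base model name and the target/feedforward module lists are illustrative assumptions, not taken from this patch.

    import torch
    from transformers import AutoModelForCausalLM
    from peft import IA3Config, get_peft_model, prepare_model_for_kbit_training

    # Load the base model with 8-bit quantization (bitsandbytes).
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",  # hypothetical base model, not from the patch
        load_in_8bit=True,
        torch_dtype=torch.float16,
    )

    # The call the patched condition now reaches for ia3 + 8-bit: casts
    # norm/embedding layers to fp32 and enables input gradients so the
    # quantized base model trains stably under an adapter.
    model = prepare_model_for_kbit_training(model)

    ia3_config = IA3Config(
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj", "down_proj"],  # illustrative
        feedforward_modules=["down_proj"],  # must be a subset of target_modules
    )
    model = get_peft_model(model, ia3_config)

Without the prepare_model_for_kbit_training step, an IA3 adapter over an 8-bit base can hit dtype mismatches during backprop, which is presumably why the condition is widened here rather than adding a separate code path.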