prepare ia3 for 8bit
winglian committed Sep 18, 2023
1 parent 31b9e0c commit 3124670
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions .pylintrc
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+        too-many-boolean-expressions,
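(The new disable entry is presumably needed because the widened `if` condition in the models.py hunk below would otherwise trip pylint's too-many-boolean-expressions check.)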
6 changes: 4 additions & 2 deletions src/axolotl/utils/models.py
@@ -367,8 +367,10 @@ def load_model(
             module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
+    if (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
+        or (cfg.adapter == "ia3" and load_in_8bit)
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
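For context on what the new branch triggers: with a config that sets `adapter: ia3` and `load_in_8bit: true`, the model is now routed through PEFT's `prepare_model_for_kbit_training`, the same path LoRA in 8-bit already takes. A minimal sketch of that call pattern, assuming the standard `peft`/`transformers` APIs; the model id and loading flags below are illustrative placeholders, not what axolotl actually loads:

```python
# Sketch of the preparation path this commit enables for IA3 in 8-bit,
# assuming peft's prepare_model_for_kbit_training API. The checkpoint
# name is a placeholder for illustration only.
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder checkpoint
    load_in_8bit=True,    # 8-bit quantized weights via bitsandbytes
)

# Freezes the quantized base weights, upcasts layer norms to fp32, and
# optionally enables gradient checkpointing, mirroring the
# cfg.gradient_checkpointing branch in load_model above.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
```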
