diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index adc991456d..b738e5c222 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -19,7 +19,7 @@ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) + parsed_cfg = load_cfg(config, inference=True, **kwargs) parsed_cfg.sample_packing = False parser = transformers.HfArgumentParser((TrainerCliArgs)) parsed_cli_args, _ = parser.parse_args_into_dataclasses( diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0f01a7cadc..c9170b7a84 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -323,11 +323,13 @@ class LoraConfig(BaseModel): @model_validator(mode="before") @classmethod def validate_adapter(cls, data): - if not data.get("adapter") and ( - data.get("load_in_8bit") or data.get("load_in_4bit") + if ( + not data.get("adapter") + and not data.get("inference") + and (data.get("load_in_8bit") or data.get("load_in_4bit")) ): raise ValueError( - "load_in_8bit and load_in_4bit are not supported without setting an adapter." + "load_in_8bit and load_in_4bit are not supported without setting an adapter for training. " "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." ) return data