diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index dbce2a6b34..0d96e4610b 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -2,7 +2,7 @@ base_model: TheBloke/Llama-2-7B-GPTQ
 base_model_config: TheBloke/Llama-2-7B-GPTQ
 is_llama_derived_model: false
 gptq: true
-gptq_bits: 4
+gptq_disable_exllama: true
 model_type: AutoModelForCausalLM
 tokenizer_type: LlamaTokenizer
 tokenizer_use_fast: true
@@ -62,8 +62,6 @@ xformers_attention:
 flash_attention:
 sdp_attention:
 flash_optimum:
-gptq_groupsize:
-gptq_model_v1:
 warmup_steps: 100
 eval_steps:
 save_steps:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7dc25996c2..a349776d77 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -196,6 +196,10 @@ def load_model(
         if not hasattr(model_config, "quantization_config"):
             LOG.warning("model config does not contain quantization_config information")
         else:
+            if cfg.gptq_disable_exllama is not None:
+                model_config.quantization_config[
+                    "disable_exllama"
+                ] = cfg.gptq_disable_exllama
             model_kwargs["quantization_config"] = GPTQConfig(
                 **model_config.quantization_config
             )
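
For context on the models.py hunk: the new `gptq_disable_exllama` option is layered onto whatever `quantization_config` the GPTQ checkpoint already ships in its `config.json`, and the merged dict is then passed to transformers' `GPTQConfig`. Below is a minimal standalone sketch of that merge; the literal dict values and the `gptq_disable_exllama` variable are illustrative stand-ins for the checkpoint metadata and `cfg.gptq_disable_exllama`, not code from this patch.

```python
# Minimal sketch of the override path, outside axolotl.
# The dict below mimics the quantization metadata a GPTQ checkpoint stores in
# config.json; `gptq_disable_exllama` stands in for cfg.gptq_disable_exllama.
from transformers import GPTQConfig

quantization_config = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "disable_exllama": False,
}

gptq_disable_exllama = True  # what `gptq_disable_exllama: true` in the YAML sets

# Same merge the patch performs: only override when the option is explicitly set.
if gptq_disable_exllama is not None:
    quantization_config["disable_exllama"] = gptq_disable_exllama

# The merged dict becomes the GPTQConfig handed to from_pretrained(...).
# Note: `disable_exllama` is the flag name in the transformers releases current
# when this patch was written; newer releases rename it to `use_exllama`.
gptq_config = GPTQConfig(**quantization_config)
print(gptq_config.to_dict())
```

The practical motivation, per the transformers/auto-gptq documentation of the time, is that the exllama kernels do not support fine-tuning, so LoRA/PEFT training on a GPTQ base model generally requires disabling them, which this flag now exposes directly in the axolotl config.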