From a581e9f8f66e14c22ec914ee792dd4fe073e62f6 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 5 Dec 2023 01:20:06 +0900 Subject: [PATCH] feat: add check for quantized model (#913) * feat: add check for quantized model * chore: refactor and add another check * Update src/axolotl/utils/models.py --------- Co-authored-by: Wing Lian --- src/axolotl/utils/models.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3037901761..40a0a89474 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -28,6 +28,27 @@ LOG = logging.getLogger("axolotl") +def check_model_config(cfg: DictDefault, model_config: AutoConfig): + quant_config_exists = hasattr(model_config, "quantization_config") + quant_config_method_is_gptq = ( + quant_config_exists + and "quant_method" in model_config.quantization_config + and model_config.quantization_config["quant_method"] == "gptq" + ) + + if cfg.gptq and not quant_config_method_is_gptq: + raise ValueError( + "model_config.quantization_config is not set or quant_method is not set to gptq. " + "Please make sure to point to a GPTQ model." + ) + + if not cfg.gptq and quant_config_exists: + raise ValueError( + "model_config.quantization_config is set but `gptq` flag is not. " + "Please use the `gptq` flag to train quantized model or point to a non-quantized model." + ) + + def load_model_config(cfg): model_config_name = cfg.base_model_config or cfg.base_model trust_remote_code = cfg.trust_remote_code is True @@ -38,6 +59,8 @@ def load_model_config(cfg): for key, val in cfg.model_config.items(): setattr(model_config, key, val) + check_model_config(cfg, model_config) + return model_config