diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 1c0487ff8e..7f808d5907 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -40,6 +40,11 @@ def get_device(): if accelerate_vars: cfg.device_map = None + # Override the device_map if the model we are trying to train is gptq-based" + # Otherwise training will not work + if cfg.gptq: + cfg.device_map = "auto" + def normalize_config(cfg): # setup some derived config / hyperparams