From 6813b542b23ebd6cc03eaca930a2a5bcbaf3ebd9 Mon Sep 17 00:00:00 2001 From: Napuh Date: Wed, 20 Sep 2023 13:06:35 +0200 Subject: [PATCH 1/3] override device_map if the model is gptq-based --- src/axolotl/utils/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 1c0487ff8e..fed8abf3f3 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -39,6 +39,11 @@ def get_device(): accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")] if accelerate_vars: cfg.device_map = None + + # Override the device_map if the model we are trying to train is gptq-based" + # Otherwise training will not work + if cfg.gptq: + cfg.device_map="auto" def normalize_config(cfg): From 44355828fca9576682275f042e9e6a1cfc0af88b Mon Sep 17 00:00:00 2001 From: Napuh <55241721+Napuh@users.noreply.github.com> Date: Thu, 21 Sep 2023 18:14:57 +0200 Subject: [PATCH 2/3] fix code style --- src/axolotl/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index fed8abf3f3..9db6353154 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -43,7 +43,7 @@ def get_device(): # Override the device_map if the model we are trying to train is gptq-based" # Otherwise training will not work if cfg.gptq: - cfg.device_map="auto" + cfg.device_map = "auto" def normalize_config(cfg): From b1eea4d18e841408f98312dac7614359b2e8a753 Mon Sep 17 00:00:00 2001 From: napuh Date: Thu, 21 Sep 2023 22:25:30 +0200 Subject: [PATCH 3/3] removed trailing space --- src/axolotl/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 9db6353154..7f808d5907 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -39,7 +39,7 @@ def get_device(): accelerate_vars = [var for var in os.environ if
var.startswith("ACCELERATE_USE_")] if accelerate_vars: cfg.device_map = None - + # Override the device_map if the model we are trying to train is gptq-based" # Otherwise training will not work if cfg.gptq: