From 8df288dbb13e78beedb6b81920d8918f3fa21da6 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Sep 2023 16:10:03 -0400
Subject: [PATCH 1/3] model-parallel for single process

---
 src/axolotl/utils/config.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 6de807eab9..92eb75e053 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -24,7 +24,11 @@ def get_device():
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"
 
-    cfg.device = get_device()
+    if cfg.world_size == 1:
+        cfg.device = "auto"
+    else:
+        cfg.device = get_device()
+
     if cfg.device_map != "auto":
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}

From 04625e075da16a7416f3e631691e5f8e61af45ea Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Sep 2023 17:26:08 -0400
Subject: [PATCH 2/3] fix device/device_map

---
 src/axolotl/utils/config.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 92eb75e053..4561a2e499 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -25,11 +25,9 @@ def get_device():
             return "cpu"
 
     if cfg.world_size == 1:
-        cfg.device = "auto"
+        cfg.device_map = "auto"
     else:
         cfg.device = get_device()
-
-    if cfg.device_map != "auto":
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else:

From 8ff01097ef51278671797d48ca7e0794c2f17dc8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Sep 2023 17:31:01 -0400
Subject: [PATCH 3/3] fix handling for device

---
 src/axolotl/utils/bench.py  | 2 +-
 src/axolotl/utils/config.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index d19e81ecdc..30f0985e75 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
+    if not torch.cuda.is_available() or device == "auto":
         return (0, 0, 0)
 
     usage, cache, misc = gpu_memory_usage_all(device)
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 4561a2e499..a1fe85d597 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -24,10 +24,10 @@ def get_device():
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"
 
+    cfg.device = get_device()
     if cfg.world_size == 1:
         cfg.device_map = "auto"
     else:
-        cfg.device = get_device()
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else:
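
Net effect of the series: with a single process (cfg.world_size == 1), cfg.device_map
becomes "auto" so the model can be sharded across every visible GPU (naive model
parallelism in one process); with multiple processes, each rank pins the whole model
to its own device via {"": cfg.local_rank}. The bench.py change skips memory logging
when the device is "auto", since there is then no single CUDA index to query. Below
is a minimal sketch of how such a device_map is typically consumed at model-load
time; the model id and the world_size/local_rank variables are illustrative
assumptions, not part of the patch.

# Sketch only. Assumptions (not from the patch): model id "gpt2" and the
# world_size/local_rank variables; from_pretrained(device_map=...) and
# hf_device_map are real transformers/accelerate APIs.
import torch
from transformers import AutoModelForCausalLM

world_size = 1  # single process, possibly several GPUs
local_rank = 0

if world_size == 1:
    # One process: let accelerate place layers across all visible devices.
    device_map = "auto"
elif torch.cuda.is_available():
    # One process per GPU: keep the whole model on this rank's device.
    device_map = {"": local_rank}
else:
    device_map = {"": "cpu"}

model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
print(model.hf_device_map)  # which device each module was placed on

With two visible GPUs and device_map="auto", hf_device_map would show roughly half
of the transformer layers on cuda:0 and half on cuda:1, which is the single-process
model parallelism the first commit's subject refers to.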