From 5f8d74f4ce075a368f59a7d7f26c0f83c9a85bb2 Mon Sep 17 00:00:00 2001
From: Wing Lian <wing.lian@gmail.com>
Date: Wed, 13 Sep 2023 11:45:30 -0400
Subject: [PATCH] Model parallel (#538)

* model-parallel for single process

* fix device/device_map

* fix handling for device
---
 src/axolotl/utils/bench.py  | 2 +-
 src/axolotl/utils/config.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index d19e81ecdc..30f0985e75 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
+    if not torch.cuda.is_available() or device == "auto":
         return (0, 0, 0)
 
     usage, cache, misc = gpu_memory_usage_all(device)
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 9b9f3cdb89..90ed409b9c 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -25,7 +25,9 @@ def get_device():
             return "cpu"
 
     cfg.device = get_device()
-    if cfg.device_map != "auto":
+    if cfg.world_size == 1:
+        cfg.device_map = "auto"
+    else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else: