diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index d19e81ecdc..30f0985e75 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
+    if not torch.cuda.is_available() or device == "auto":
         return (0, 0, 0)
 
     usage, cache, misc = gpu_memory_usage_all(device)
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index c18ec1be3f..479aa6b8f1 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -24,10 +24,10 @@ def get_device():
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"
 
+    cfg.device = get_device()
     if cfg.world_size == 1:
         cfg.device_map = "auto"
     else:
-        cfg.device = get_device()
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else: