diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 90ed409b9c..a31f34b73e 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -29,7 +29,7 @@ def get_device(): cfg.device_map = "auto" else: if cfg.device.startswith("cuda"): - cfg.device_map = {"": cfg.local_rank} + cfg.device_map = {"": torch.cuda.current_device()} else: cfg.device_map = {"": cfg.device}