From 972bd42d0a7a6e08e1155ae492e99a2c1a96a6da Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 19 Sep 2023 17:49:11 -0400
Subject: [PATCH 1/4] skip the gpu memory checks if the device is set to 'auto'

---
 src/axolotl/utils/bench.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index b460b2ba7c..794c1e50df 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -1,14 +1,40 @@
 """Benchmarking and measurement utilities"""
+import functools
 
 import pynvml
 import torch
 from pynvml.nvml import NVMLError
 
 
+def check_cuda_device(default_value):
+    """
+    wraps a function and returns the default value instead of running the
+    wrapped function if cuda isn't available or the device is auto
+    :param default_value:
+    :return:
+    """
+
+    def actual_decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            device = kwargs.get("device", args[0] if args else None)
+
+            if not torch.cuda.is_available() or device == "auto":
+                return default_value
+
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return actual_decorator
+
+
+@check_cuda_device(0.0)
 def gpu_memory_usage(device=0):
     return torch.cuda.memory_allocated(device) / 1024.0**3
 
 
+@check_cuda_device((0.0, 0.0, 0.0))
 def gpu_memory_usage_all(device=0):
     usage = torch.cuda.memory_allocated(device) / 1024.0**3
     reserved = torch.cuda.memory_reserved(device) / 1024.0**3
@@ -16,6 +42,7 @@ def gpu_memory_usage_all(device=0):
     return usage, reserved - usage, max(0, smi - reserved)
 
 
+@check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index

From b1f1f44757ffa0025760eb1fc11b6cec3b3723df Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 21 Sep 2023 09:13:25 -0400
Subject: [PATCH 2/4] skip gpu mem logging if cpu too

---
 src/axolotl/utils/bench.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 794c1e50df..82f4dfddf2 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -19,7 +19,7 @@ def actual_decorator(func):
         def wrapper(*args, **kwargs):
             device = kwargs.get("device", args[0] if args else None)
 
-            if not torch.cuda.is_available() or device == "auto":
+            if not torch.cuda.is_available() or device == "auto" or device == "cpu":
                 return default_value
 
             return func(*args, **kwargs)

From 07c943670bd6bbca01e0f0c9c9e07d993c6fc928 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 21 Sep 2023 09:15:46 -0400
Subject: [PATCH 3/4] don't worry about log_gpu_memory_usage since it calls
 another annotated fn

---
 src/axolotl/utils/bench.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 82f4dfddf2..e3b445023c 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -58,9 +58,6 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available() or device == "auto":
-        return (0, 0, 0)
-
     usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
     if cache > 0:

From b72827e1257f73429dcf3a1c49c4047dc1ac1054 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 21 Sep 2023 09:17:01 -0400
Subject: [PATCH 4/4] rename decorator internal

---
 src/axolotl/utils/bench.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index e3b445023c..685be526f0 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -14,7 +14,7 @@ def check_cuda_device(default_value):
     :return:
     """
 
-    def actual_decorator(func):
+    def deco(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
             device = kwargs.get("device", args[0] if args else None)
@@ -26,7 +26,7 @@ def wrapper(*args, **kwargs):
 
         return wrapper
 
-    return actual_decorator
+    return deco
 
 
 @check_cuda_device(0.0)
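For reference, a minimal standalone sketch of the behavior these four patches converge on: the decorator reads the device from the first positional argument or the "device" keyword and returns the supplied default instead of calling the wrapped function when CUDA is unavailable or the device is "auto" or "cpu". The probe function fake_memory_probe is a hypothetical stand-in for the bench helpers in src/axolotl/utils/bench.py, used only to show the short-circuit.

import functools

import torch


def check_cuda_device(default_value):
    """Return default_value instead of calling the wrapped function when
    CUDA is unavailable or the requested device is "auto" or "cpu"."""

    def deco(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The device may arrive positionally or as a keyword argument.
            device = kwargs.get("device", args[0] if args else None)

            if not torch.cuda.is_available() or device == "auto" or device == "cpu":
                return default_value

            return func(*args, **kwargs)

        return wrapper

    return deco


@check_cuda_device((0.0, 0.0, 0.0))
def fake_memory_probe(device=0):
    # Only reached when CUDA is available and device refers to a real GPU.
    return (1.0, 2.0, 3.0)


if __name__ == "__main__":
    # With device="auto" or "cpu" (or on a CPU-only machine), the decorator
    # returns the default tuple without touching torch.cuda.
    print(fake_memory_probe("auto"))  # (0.0, 0.0, 0.0)
    print(fake_memory_probe("cpu"))   # (0.0, 0.0, 0.0)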