diff --git a/src/metatrain/utils/devices.py b/src/metatrain/utils/devices.py index 11ffabd14..86d624a63 100644 --- a/src/metatrain/utils/devices.py +++ b/src/metatrain/utils/devices.py @@ -4,18 +4,9 @@ import torch -def _get_available_devices() -> List[str]: - available_devices = ["cpu"] - if torch.cuda.is_available(): - available_devices.append("cuda") - if torch.cuda.device_count() > 1: - available_devices.append("multi-cuda") - # for torch<2.0 `torch.backends.mps.is_available()` is required for a reasonable - # check. - if torch.backends.mps.is_built() and torch.backends.mps.is_available(): - available_devices.append("mps") - - return available_devices +def _mps_is_available() -> bool: + # require `torch.backends.mps.is_available()` for a reasonable check in torch<2.0 + return torch.backends.mps.is_built() and torch.backends.mps.is_available() def pick_devices( @@ -31,10 +22,17 @@ def pick_devices( :param architecture_devices: Devices supported by the architecture. The list should be sorted by the preference of the architecture while the most prefferred device should be first and the least one last. - :param desired_device: desired device by the user + :param desired_device: desired device by the user. For example, ``"cpu"``, + "``cuda``", ``"multi-gpu"``, etc. """ - available_devices = _get_available_devices() + available_devices = ["cpu"] + if torch.cuda.is_available(): + available_devices.append("cuda") + if torch.cuda.device_count() > 1: + available_devices.append("multi-cuda") + if _mps_is_available(): + available_devices.append("mps") # intersect between available and architecture's devices. keep order of architecture possible_devices = [d for d in architecture_devices if d in available_devices] @@ -52,37 +50,55 @@ def pick_devices( else: desired_device = desired_device.lower() - # convert "gpu" and "multi-gpu" to "cuda" or "mps" if available - if desired_device == "gpu": - if torch.cuda.is_available(): - desired_device = "cuda" - elif torch.backends.mps.is_built() and torch.backends.mps.is_available(): - desired_device = "mps" - else: - raise ValueError( - "Requested 'gpu' device, but found no GPU (CUDA or MPS) devices." - ) - if desired_device == "multi-gpu": - desired_device = "multi-cuda" - - if desired_device not in possible_devices: + # convert "gpu" and "multi-gpu" to "cuda" or "mps" if available + if desired_device == "gpu": + if torch.cuda.is_available(): + desired_device = "cuda" + elif _mps_is_available(): + desired_device = "mps" + else: raise ValueError( - f"Unsupported desired device {desired_device!r}. " - f"Please choose from {', '.join(possible_devices)}." - ) - if desired_device == "multi-cuda" and torch.cuda.device_count() < 2: - raise ValueError( - "Requested device 'multi-gpu' or 'multi-cuda', but found only one CUDA " - "device. If you want to run on a single GPU, please use 'gpu' or " - "'cuda' instead." + "Requested 'gpu' device, but found no GPU (CUDA or MPS) devices." ) + elif desired_device == "cuda" and not torch.cuda.is_available(): + raise ValueError("Requested 'cuda' device, but cuda is not available.") + elif desired_device == "mps" and not _mps_is_available(): + raise ValueError("Requested 'mps' device, but mps is not available.") - if possible_devices.index(desired_device) > 0: - warnings.warn( - f"Device {desired_device!r} requested, but {possible_devices[0]!r} is " - "prefferred by the architecture and available on current system.", - stacklevel=2, - ) + if desired_device == "multi-gpu": + desired_device = "multi-cuda" + + if desired_device not in architecture_devices: + raise ValueError( + f"Desired device {desired_device!r} is not supported by the selected " + f"architecture. Please choose from {', '.join(possible_devices)}." + ) + + if desired_device not in available_devices: + raise ValueError( + f"Desired device {desired_device!r} is not supported on " + f"your current system. Please choose from {', '.join(possible_devices)}." + ) + + if possible_devices.index(desired_device) > 0: + warnings.warn( + f"Device {desired_device!r} requested, but {possible_devices[0]!r} is " + "prefferred by the architecture and available on current system.", + stacklevel=2, + ) + + if ( + desired_device == "cuda" + and torch.cuda.device_count() > 1 + and any(d in possible_devices for d in ["multi-cuda", "multi_gpu"]) + ): + warnings.warn( + "Requested single 'cuda' device but current system has " + f"{torch.cuda.device_count()} cuda devices and architecture supports " + "multi-gpu training. Consider using 'multi-gpu' to accelerate " + "training.", + stacklevel=2, + ) # convert the requested device to a list of torch devices if desired_device == "multi-cuda": diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py index c87fee59b..8b174a569 100644 --- a/tests/utils/test_device.py +++ b/tests/utils/test_device.py @@ -7,15 +7,20 @@ file. """ -from typing import List - import pytest import torch -from metatrain.utils import devices from metatrain.utils.devices import pick_devices +def is_true() -> bool: + return True + + +def is_false() -> bool: + return False + + @pytest.mark.parametrize("desired_device", ["cpu", None]) def test_pick_devices(desired_device): picked_devices = pick_devices(["cpu"], desired_device) @@ -24,10 +29,7 @@ def test_pick_devices(desired_device): @pytest.mark.parametrize("desired_device", ["cuda", None]) def test_pick_devices_cuda(desired_device, monkeypatch): - def _get_available_devices() -> List[str]: - return ["cuda", "cpu"] - - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) + monkeypatch.setattr(torch.cuda, "is_available", is_true) picked_devices = pick_devices(["cuda", "cpu"], desired_device) @@ -36,11 +38,9 @@ def _get_available_devices() -> List[str]: def test_pick_devices_prefer_architecture(monkeypatch): """Use architecture's preferred device if several matching devices are available.""" - - def _get_available_devices() -> List[str]: - return ["mps", "cpu", "cuda"] - - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) + monkeypatch.setattr(torch.cuda, "is_available", is_true) + monkeypatch.setattr(torch.backends.mps, "is_built", is_true) + monkeypatch.setattr(torch.backends.mps, "is_available", is_true) picked_devices = pick_devices(["cuda", "cpu"]) @@ -49,10 +49,8 @@ def _get_available_devices() -> List[str]: @pytest.mark.parametrize("desired_device", ["mps", None]) def test_pick_devices_mps(desired_device, monkeypatch): - def _get_available_devices() -> List[str]: - return ["mps", "cpu"] - - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) + monkeypatch.setattr(torch.backends.mps, "is_built", is_true) + monkeypatch.setattr(torch.backends.mps, "is_available", is_true) picked_devices = pick_devices(["mps", "cpu"], desired_device) @@ -60,10 +58,8 @@ def _get_available_devices() -> List[str]: def test_no_matching_device(monkeypatch): - def _get_available_devices() -> List[str]: - return ["cpu"] - - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) + monkeypatch.setattr(torch.backends.mps, "is_built", is_false) + monkeypatch.setattr(torch.backends.mps, "is_available", is_false) match = ( "No matching device found! The architecture requires cuda, mps; but your " @@ -73,64 +69,122 @@ def _get_available_devices() -> List[str]: pick_devices(["cuda", "mps"]) -def test_pick_devices_unsoprted(): - match = "Unsupported desired device 'cuda'. Please choose from cpu." +def test_pick_devices_unsupported_by_architecture(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", is_true) + match = ( + "Desired device 'cuda' is not supported by the selected architecture. " + "Please choose from cpu." + ) with pytest.raises(ValueError, match=match): pick_devices(["cpu"], "cuda") -def test_pick_devices_preferred_warning(monkeypatch): - def _get_available_devices() -> List[str]: - return ["mps", "cpu"] +@pytest.mark.parametrize("desired_device", ["multi-cuda", "multi-gpu"]) +def test_pick_devices_multi_error(desired_device, monkeypatch): + def device_count() -> int: + return 1 + + monkeypatch.setattr(torch.cuda, "is_available", is_true) + monkeypatch.setattr(torch.cuda, "device_count", device_count) + + match = ( + "Desired device 'multi-cuda' is not supported on your current system. " + "Please choose from cpu." + ) + with pytest.raises(ValueError, match=match): + pick_devices(["multi-cuda", "cpu"], desired_device=desired_device) + - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) +def test_pick_devices_preferred_warning(monkeypatch): + monkeypatch.setattr(torch.backends.mps, "is_built", is_true) + monkeypatch.setattr(torch.backends.mps, "is_available", is_true) match = "Device 'cpu' requested, but 'mps' is prefferred" with pytest.warns(UserWarning, match=match): pick_devices(["mps", "cpu", "cuda"], desired_device="cpu") -@pytest.mark.parametrize("desired_device", ["multi-cuda", "multi-gpu"]) -def test_pick_devices_multi_error(desired_device, monkeypatch): - def _get_available_devices() -> List[str]: - return ["multi-cuda", "cuda", "cpu"] +def test_pick_devices_gpu_cuda_map(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", is_true) - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) + picked_devices = pick_devices(["cuda", "cpu"], "gpu") + assert picked_devices == [torch.device("cuda")] - with pytest.raises(ValueError, match="Requested device 'multi-gpu'"): - pick_devices(["multi-cuda", "cpu"], desired_device=desired_device) +def test_pick_devices_no_cuda(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", is_false) -# Below tests that require specific devices to be present -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_pick_devices_gpu_cuda_map(): - picked_devices = pick_devices(["cuda", "cpu"], "gpu") - assert picked_devices == [torch.device("cuda")] + match = "Requested 'cuda' device, but cuda is not available." + with pytest.raises(ValueError, match=match): + pick_devices(["cuda", "cpu"], "cuda") -@pytest.mark.skipif( - not (torch.backends.mps.is_built() and torch.backends.mps.is_available()), - reason="MPS is not available", -) -def test_pick_devices_gpu_mps_map(): +def test_pick_devices_gpu_mps_map(monkeypatch): + monkeypatch.setattr(torch.backends.mps, "is_built", is_true) + monkeypatch.setattr(torch.backends.mps, "is_available", is_true) + picked_devices = pick_devices(["mps", "cpu"], "gpu") assert picked_devices == [torch.device("mps")] -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="less than 2 CUDA devices") +@pytest.mark.parametrize( + "is_built, is_available", [(is_true, is_false), (is_false, is_true)] +) +def test_pick_devices_no_mps(monkeypatch, is_built, is_available): + monkeypatch.setattr(torch.backends.mps, "is_built", is_built) + monkeypatch.setattr(torch.backends.mps, "is_available", is_available) + + match = "Requested 'mps' device, but mps is not available." + with pytest.raises(ValueError, match=match): + pick_devices(["mps", "cpu"], "mps") + + @pytest.mark.parametrize("desired_device", ["multi-cuda", "multi-gpu"]) -def test_pick_devices_multi_cuda(desired_device): - picked_devices = pick_devices(["cpu", "cuda", "multi-cuda"], desired_device) +def test_pick_devices_multi_cuda(desired_device, monkeypatch): + def device_count() -> int: + return 2 + + monkeypatch.setattr(torch.cuda, "is_available", is_true) + monkeypatch.setattr(torch.cuda, "device_count", device_count) + + picked_devices = pick_devices(["multi-cuda", "cpu", "cuda"], desired_device) assert picked_devices == [ torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count()) ] -@pytest.mark.skipif( - torch.cuda.is_available() - or (torch.backends.mps.is_built() and torch.backends.mps.is_available()), - reason="GPU device available", +@pytest.mark.parametrize( + "cuda_is_available, mps_is_build, mps_is_available", + [ + (is_false, is_false, is_false), + (is_false, is_true, is_false), + (is_false, is_false, is_true), + ], ) -def test_pick_devices_gpu_not_available(): +def test_pick_devices_gpu_not_available( + cuda_is_available, mps_is_build, mps_is_available, monkeypatch +): + monkeypatch.setattr(torch.cuda, "is_available", cuda_is_available) + monkeypatch.setattr(torch.backends.mps, "is_built", mps_is_build) + monkeypatch.setattr(torch.backends.mps, "is_available", mps_is_available) + with pytest.raises(ValueError, match="Requested 'gpu' device, but found no GPU"): - pick_devices(["cuda", "cpu"], "gpu") + pick_devices(["mps", "cpu"], "gpu") + + +def test_multi_gpu_warning(monkeypatch): + def device_count() -> int: + return 2 + + monkeypatch.setattr(torch.cuda, "is_available", is_true) + monkeypatch.setattr(torch.cuda, "device_count", device_count) + + match = ( + "Requested single 'cuda' device but current system has 2 cuda devices and " + "architecture supports multi-gpu training. Consider using 'multi-gpu' to " + "accelerate training." + ) + with pytest.warns(UserWarning, match=match): + picked_devices = pick_devices(["cuda", "multi-cuda", "cpu"], "cuda") + + assert picked_devices == [torch.device("cuda")]