diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index 4bf5ed6b..06bb63f0 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -1,2 +1,3 @@ +from .device import is_mps_available_and_functional # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 from .unique_name_generator import UniqueNameGenerator # noqa: F401 diff --git a/cheetah/utils/device.py b/cheetah/utils/device.py new file mode 100644 index 00000000..e6ac629d --- /dev/null +++ b/cheetah/utils/device.py @@ -0,0 +1,13 @@ +import torch + + +def is_mps_available_and_functional(): + """Check if MPS is available and functional (for GitHub Actions).""" + if not torch.backends.mps.is_available(): + return False + try: + # Try to allocate a small tensor on the MPS device + torch.tensor([1.0], device="mps") + return True + except RuntimeError: + return False diff --git a/tests/__init__.py b/tests/__init__.py index e6ac629d..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,13 +0,0 @@ -import torch - - -def is_mps_available_and_functional(): - """Check if MPS is available and functional (for GitHub Actions).""" - if not torch.backends.mps.is_available(): - return False - try: - # Try to allocate a small tensor on the MPS device - torch.tensor([1.0], device="mps") - return True - except RuntimeError: - return False diff --git a/tests/test_bmad_conversion.py b/tests/test_bmad_conversion.py index 63db4329..9dc78849 100644 --- a/tests/test_bmad_conversion.py +++ b/tests/test_bmad_conversion.py @@ -2,8 +2,7 @@ import torch import cheetah - -from . import is_mps_available_and_functional +from cheetah.utils import is_mps_available_and_functional def test_bmad_tutorial(): diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py index b2e0fa9f..67e61205 100644 --- a/tests/test_device_dtype.py +++ b/tests/test_device_dtype.py @@ -2,8 +2,7 @@ import torch import cheetah - -from . import is_mps_available_and_functional +from cheetah.utils import is_mps_available_and_functional @pytest.mark.parametrize(