Skip to content

Commit

Permalink
Move MPS device check to cheetah.utils
Browse files Browse the repository at this point in the history
  • Loading branch information
cr-xu committed Jul 24, 2024
1 parent 17e2fb9 commit 03aed92
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 17 deletions.
1 change: 1 addition & 0 deletions cheetah/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .device import is_mps_available_and_functional # noqa: F401
from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401
from .unique_name_generator import UniqueNameGenerator # noqa: F401
13 changes: 13 additions & 0 deletions cheetah/utils/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch


def is_mps_available_and_functional():
"""Check if MPS is available and functional (for GitHub Actions)."""
if not torch.backends.mps.is_available():
return False
try:
# Try to allocate a small tensor on the MPS device
torch.tensor([1.0], device="mps")
return True
except RuntimeError:
return False
13 changes: 0 additions & 13 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +0,0 @@
import torch


def is_mps_available_and_functional():
"""Check if MPS is available and functional (for GitHub Actions)."""
if not torch.backends.mps.is_available():
return False
try:
# Try to allocate a small tensor on the MPS device
torch.tensor([1.0], device="mps")
return True
except RuntimeError:
return False
3 changes: 1 addition & 2 deletions tests/test_bmad_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import torch

import cheetah

from . import is_mps_available_and_functional
from cheetah.utils import is_mps_available_and_functional


def test_bmad_tutorial():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_device_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import torch

import cheetah

from . import is_mps_available_and_functional
from cheetah.utils import is_mps_available_and_functional


@pytest.mark.parametrize(
Expand Down

0 comments on commit 03aed92

Please sign in to comment.