test cpu
ManfeiBai authored Nov 26, 2023
1 parent 9666e5f commit 8f3fdfc
Showing 1 changed file with 6 additions and 6 deletions.
test/spmd/test_spmd_debugging.py (6 additions, 6 deletions)
@@ -30,8 +30,8 @@ def setUpClass(cls):

   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
+      f"Requires PJRT_DEVICE set to `TPU`, `CPU`.")
   def test_debugging_spmd_single_host_tiled(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
@@ -108,8 +108,8 @@ def test_debugging_spmd_single_host_tiled(self):

   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
+      f"Requires PJRT_DEVICE set to `TPU`, 'CPU'.")
   def test_single_host_partial_replication(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
@@ -158,8 +158,8 @@ def test_single_host_partial_replication(self):

   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
+      f"Requires PJRT_DEVICE set to `TPU`, 'CPU'.")
   def test_single_host_replicated(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
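Each hunk makes the same change: 'CPU' is dropped from the list of PJRT_DEVICE values that trigger a skip, so these sharding-visualization tests now run on CPU as well as TPU, matching the commit title. Below is a minimal, self-contained sketch of the resulting gate; getenv_as here is a hypothetical stand-in for torch_xla's xu.getenv_as(xenv.PJRT_DEVICE, str), and the real tests additionally require xr.using_pjrt() to be true.

import os
import unittest


def getenv_as(name, typ, default=None):
  # Hypothetical simplified stand-in for xu.getenv_as:
  # read an env var and cast it, falling back to a default.
  val = os.environ.get(name)
  return typ(val) if val is not None else default


class SkipGateSketch(unittest.TestCase):

  @unittest.skipIf(
      getenv_as('PJRT_DEVICE', str) in ('GPU', 'CUDA', 'ROCM'),
      "Requires PJRT_DEVICE set to `TPU` or `CPU`.")
  def test_runs_on_tpu_or_cpu(self):
    # Before this commit, 'CPU' was also in the skip list above, so
    # this body never ran with PJRT_DEVICE=CPU; after it, it does.
    self.assertNotIn(getenv_as('PJRT_DEVICE', str), ('GPU', 'CUDA', 'ROCM'))


if __name__ == '__main__':
  unittest.main()

Running this sketch with PJRT_DEVICE=CPU executes the test, while PJRT_DEVICE=GPU (or CUDA/ROCM) skips it, mirroring the new behavior of the three tests above.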
