From 8f3fdfca0a508dfdf2e644d1cd0b421ccd6917e0 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Sun, 26 Nov 2023 15:25:40 -0800
Subject: [PATCH] test cpu

---
 test/spmd/test_spmd_debugging.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index c67edc55e1b..68ab7b99882 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -30,8 +30,8 @@ def setUpClass(cls):
 
   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
+      f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
   def test_debugging_spmd_single_host_tiled(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
@@ -108,8 +108,8 @@ def test_debugging_spmd_single_host_tiled(self):
 
   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
+      f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
   def test_single_host_partial_replication(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
@@ -158,8 +158,8 @@ def test_single_host_partial_replication(self):
 
   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
+      f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
   def test_single_host_replicated(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
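
Note for reviewers (not part of the patch): the sketch below shows how the relaxed skip guard evaluates once PJRT_DEVICE is set to CPU. It reuses the xr/xu/xenv aliases these tests conventionally import; the exact import paths and the CPU device choice are assumptions for illustration, so adjust them to your checkout.

# Minimal sketch, assuming a local torch_xla install and the usual
# torch_xla module aliases; evaluates the updated skip condition
# with an illustrative PJRT_DEVICE=CPU setting.
import os
os.environ["PJRT_DEVICE"] = "CPU"

import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv

should_skip = (not xr.using_pjrt() or
               xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'))
# With this patch applied, CPU is no longer in the exclusion list, so the
# visualize_tensor_sharding tests are expected to run rather than be skipped.
print("skip:", should_skip)

In practice, running the suite as `PJRT_DEVICE=CPU python test/spmd/test_spmd_debugging.py` should now execute these three tests instead of skipping them, while GPU/CUDA/ROCM runs continue to skip.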