diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 6a960699628..61af77932af 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -29,7 +29,7 @@ def setUpClass(cls):
     super().setUpClass()

   @unittest.skipIf(xr.device_type() == 'CPU',
-                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, or `CUDA`.")
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -107,8 +107,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output

-  @unittest.skipIf(xr.device_type() != 'TPU',
-                   f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
@@ -162,8 +162,8 @@ def test_single_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output

-  @unittest.skipIf(xr.device_type() != 'TPU',
-                   f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
@@ -335,8 +335,8 @@ def test_single_host_replicated_cpu(self):
   # e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
   # e.g.: sharding={replicated}

-  @unittest.skipIf(xr.device_type() != 'TPU',
-                   f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")
   def test_debugging_spmd_multi_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -446,8 +446,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output

-  @unittest.skipIf(xr.device_type() != 'TPU',
-                   f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")
   def test_multi_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -529,8 +529,8 @@ def test_multi_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output

-  @unittest.skipIf(xr.device_type() != 'TPU',
-                   f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")
   @unittest.skipIf(xr.global_runtime_device_count() != 8,
                    f"Limit test num_devices to 8 for function consistency")
   def test_multi_host_replicated_tpu(self):
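
Side note, not part of the patch: the same three-line skipIf guard is now repeated on every test touched above. A minimal sketch of factoring it into one shared decorator follows; `skip_unless_spmd_device` is a hypothetical name, not an existing torch_xla or test helper, and it only wraps the unittest.skipIf call already shown in the diff.

import unittest

import torch_xla.runtime as xr


def skip_unless_spmd_device(test_fn):
  # Skip on CPU; any non-CPU PJRT device (TPU, GPU, CUDA, ROCM) runs the test.
  return unittest.skipIf(
      xr.device_type() == 'CPU',
      "Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or `ROCM`.")(test_fn)

Each test would then carry a single `@skip_unless_spmd_device` line instead of repeating the condition and skip message.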