
Commit

Update test_spmd_debugging.py
ManfeiBai authored Jan 18, 2024
1 parent 01488e3 commit 7eb80d8
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions test/spmd/test_spmd_debugging.py
@@ -29,7 +29,7 @@ def setUpClass(cls):
super().setUpClass()

@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, or `CUDA`.")
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
def test_debugging_spmd_single_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -107,8 +107,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

-@unittest.skipIf(xr.device_type() != 'TPU',
-                 f"Requires PJRT_DEVICE set to `TPU`.")
+@unittest.skipIf(xr.device_type() == 'CPU',
+                 f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
def test_single_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
@@ -162,8 +162,8 @@ def test_single_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

-@unittest.skipIf(xr.device_type() != 'TPU',
-                 f"Requires PJRT_DEVICE set to `TPU`.")
+@unittest.skipIf(xr.device_type() == 'CPU',
+                 f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
def test_single_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{replicated}'
@@ -335,8 +335,8 @@ def test_single_host_replicated_cpu(self):
# e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
# e.g.: sharding={replicated}

-@unittest.skipIf(xr.device_type() != 'TPU',
-                 f"Requires PJRT_DEVICE set to `TPU`.")
+@unittest.skipIf(xr.device_type() == 'CPU',
+                 f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
def test_debugging_spmd_multi_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -446,8 +446,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

-@unittest.skipIf(xr.device_type() != 'TPU',
-                 f"Requires PJRT_DEVICE set to `TPU`.")
+@unittest.skipIf(xr.device_type() == 'CPU',
+                 f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
def test_multi_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -529,8 +529,8 @@ def test_multi_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

-@unittest.skipIf(xr.device_type() != 'TPU',
-                 f"Requires PJRT_DEVICE set to `TPU`.")
+@unittest.skipIf(xr.device_type() == 'CPU',
+                 f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.global_runtime_device_count() != 8,
f"Limit test num_devices to 8 for function consistency")
def test_multi_host_replicated_tpu(self):
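Every decorator changed in this diff follows the same pattern: skip only when xr.device_type() reports 'CPU', so the sharding-visualization tests run under GPU, CUDA, and ROCm backends as well as TPU, rather than TPU alone. Below is a minimal standalone sketch of that guard; it assumes the `import torch_xla.runtime as xr` that test_spmd_debugging.py already uses, and the class and test names are hypothetical placeholders, not names from the real file.

# Sketch of the skip guard this commit adopts; class/test names are hypothetical.
import unittest

import torch_xla.runtime as xr


class DebuggingSpmdSketch(unittest.TestCase):

  @unittest.skipIf(xr.device_type() == 'CPU',
                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
  def test_runs_only_on_accelerators(self):
    # The real tests call visualize_sharding() on a sharding spec string and
    # compare the captured console output against an expected table.
    pass


if __name__ == '__main__':
  unittest.main()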
