Skip to content

Commit

Permalink
Update test_spmd_debugging.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 29, 2023
1 parent 6b3072b commit 9c66fdb
Showing 1 changed file with 54 additions and 2 deletions.
56 changes: 54 additions & 2 deletions test/spmd/test_spmd_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()


@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
f"Requires PJRT_DEVICE set to `TPU`, `CPU`.")
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
# TPU has more than 1 core here, so setup tesst mesh_shape with 2*x
mesh_shape = (2, num_devices / 2)
print("num_devices: ")
print(num_devices)
Expand Down Expand Up @@ -209,6 +211,56 @@ def test_single_host_replicated(self):
assert output == fake_output


@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
def test_debugging_spmd_single_host_tiled_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
# 1 CPU testing env
mesh_shape = (1, num_devices)
print("num_devices: ")
print(num_devices)
print("device: ")
print(device)
device_ids = np.array(range(num_devices))
print("device_ids: ")
print(device_ids)
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)
t = torch.randn(8, 4, device=device)
partition_spec = (0, 1)
xs.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

fake_console = rich.console.Console(file=io.StringIO(), width=120)
color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 0', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)

0 comments on commit 9c66fdb

Please sign in to comment.