From 9c66fdbe846924c151218577289d82fcca6633cc Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Wed, 29 Nov 2023 13:18:03 -0800
Subject: [PATCH] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 56 ++++++++++++++++++++++++++++++--
 1 file changed, 54 insertions(+), 2 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 9d4edb7153d..4365ddaea6b 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,14 +28,16 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
+
   @unittest.skipIf(
       not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
-      f"Requires PJRT_DEVICE set to `TPU`, `CPU`.")
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+      f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = xr.global_runtime_device_count()
+    # TPU has more than one core here, so set up the test mesh_shape as (2, num_devices / 2)
     mesh_shape = (2, num_devices / 2)
     print("num_devices: ")
     print(num_devices)
@@ -209,6 +211,56 @@ def test_single_host_replicated(self):
     assert output == fake_output
 
+  @unittest.skipIf(
+      not xr.using_pjrt() or
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
+      f"Requires PJRT_DEVICE set to `CPU`.")
+  def test_debugging_spmd_single_host_tiled_cpu(self):
+    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
+    device = xm.xla_device()
+    num_devices = xr.global_runtime_device_count()
+    # the CPU testing env has a single device, so use a (1, num_devices) mesh
+    mesh_shape = (1, num_devices)
+    print("num_devices: ")
+    print(num_devices)
+    print("device: ")
+    print(device)
+    device_ids = np.array(range(num_devices))
+    print("device_ids: ")
+    print(device_ids)
+    # mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    mesh = self._get_mesh(mesh_shape)
+    t = torch.randn(8, 4, device=device)
+    partition_spec = (0, 1)
+    xs.mark_sharding(t, mesh, partition_spec)
+    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console(file=io.StringIO(), width=120)
+    console.print(generated_table)
+    output = console.file.getvalue()
+
+    fake_console = rich.console.Console(file=io.StringIO(), width=120)
+    color = None
+    text_color = None
+    fake_table = rich.table.Table(
+        show_header=False,
+        show_lines=True,
+        padding=0,
+        highlight=True,
+        pad_edge=False,
+        box=rich.box.SQUARE)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('CPU [0]', "center", vertical="middle"),  # expected label for the single CPU device
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    fake_console.print(fake_table)
+    fake_output = fake_console.file.getvalue()
+    assert output == fake_output
+
+
 
 if __name__ == '__main__':
   test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
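
Note for reviewers: the new CPU test asserts exact string equality between
the rich table returned by visualize_tensor_sharding and a hand-built
expected table, both rendered through a StringIO-backed console. Below is a
minimal standalone sketch of that render-and-compare pattern; it uses only
the `rich` library (no torch_xla), and the 'CPU [0]' cell label is an
assumption mirroring what the test expects for a single CPU device.

import io

import rich.align
import rich.box
import rich.console
import rich.padding
import rich.style
import rich.table


def render(renderable):
  # Render into an in-memory console and return the raw string,
  # mirroring `console.file.getvalue()` in the test.
  console = rich.console.Console(file=io.StringIO(), width=120)
  console.print(renderable)
  return console.file.getvalue()


def single_cell_table(label):
  # Build a one-cell table in the same style the test expects
  # visualize_tensor_sharding to produce for a single-device tensor.
  table = rich.table.Table(
      show_header=False,
      show_lines=True,
      padding=0,
      highlight=True,
      pad_edge=False,
      box=rich.box.SQUARE)
  table.add_row(
      rich.padding.Padding(
          rich.align.Align(label, "center", vertical="middle"),
          (2, 1, 2, 1),
          style=rich.style.Style(bgcolor=None, color=None)))
  return table


# Two tables built the same way render to identical strings, which is
# exactly the equality the test's final assert relies on.
assert render(single_cell_table('CPU [0]')) == render(
    single_cell_table('CPU [0]'))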