From df097d75b9b508ab36c610db49c353dd3ada107b Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Fri, 19 Jan 2024 15:27:26 -0800 Subject: [PATCH] Update test_spmd_debugging.py to avoid code test code self (#6263) --- test/spmd/test_spmd_debugging.py | 310 ++++++++++++++++++------------- 1 file changed, 180 insertions(+), 130 deletions(-) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index c65613837fc..20ae3a3f71f 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -28,20 +28,13 @@ def setUpClass(cls): xr.use_spmd() super().setUpClass() - @unittest.skipIf(xr.device_type() != 'TPU', - f"Requires PJRT_DEVICE set to `TPU`.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_debugging_spmd_single_host_tiled_tpu(self): - from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding - device = xm.xla_device() - num_devices = self.n_devices - mesh_shape = (2, num_devices // 2) - device_ids = np.array(range(num_devices)) - mesh = self._get_mesh(mesh_shape) - t = torch.randn(8, 4, device=device) - partition_spec = (0, 1) - xs.mark_sharding(t, mesh, partition_spec) - sharding = torch_xla._XLAC._get_xla_sharding_spec(t) - generated_table = visualize_tensor_sharding(t) + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}' + generated_table = visualize_sharding(sharding) console = rich.console.Console() with console.capture() as capture: console.print(generated_table) @@ -49,54 +42,63 @@ def test_debugging_spmd_single_host_tiled_tpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU 0', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 0', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 1', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 1', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 2', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 2', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 3', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 3', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU 4', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 4', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 5', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 5', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 6', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 6', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 7', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 7', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -106,21 +108,13 @@ def test_debugging_spmd_single_host_tiled_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() != 'TPU', - f"Requires PJRT_DEVICE set to `TPU`.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_single_host_partial_replication_tpu(self): - from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding - device = xm.xla_device() - num_devices = self.n_devices - mesh_shape = (2, num_devices // 2) - device_ids = np.array(range(num_devices)) - mesh = self._get_mesh(mesh_shape) - - partition_spec = (0, None) - t = torch.randn(8, 32, device=device) - xs.mark_sharding(t, mesh, (0, None)) - sharding = torch_xla._XLAC._get_xla_sharding_spec(t) - generated_table = visualize_tensor_sharding(t) + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}' + generated_table = visualize_sharding(sharding) console = rich.console.Console() with console.capture() as capture: console.print(generated_table) @@ -128,24 +122,43 @@ def test_single_host_partial_replication_tpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align( + xr.device_type() + ' [0, 1]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align( + xr.device_type() + ' [2, 3]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [4, 5]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [6, 7]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -155,21 +168,13 @@ def test_single_host_partial_replication_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() != 'TPU', - f"Requires PJRT_DEVICE set to `TPU`.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_single_host_replicated_tpu(self): - from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding - device = xm.xla_device() - num_devices = self.n_devices - mesh_shape = (2, num_devices // 2) - device_ids = np.array(range(num_devices)) - mesh = self._get_mesh(mesh_shape) - - partition_spec_replicated = (None, None) - t = torch.randn(8, 32, device=device) - xs.mark_sharding(t, mesh, partition_spec_replicated) - sharding = torch_xla._XLAC._get_xla_sharding_spec(t) - generated_table = visualize_tensor_sharding(t) + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{replicated}' + generated_table = visualize_sharding(sharding) console = rich.console.Console() with console.capture() as capture: console.print(generated_table) @@ -177,18 +182,22 @@ def test_single_host_replicated_tpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] + alltpus = xr.device_type() + ' [0' + for i in range(xr.global_runtime_device_count() - 1): + alltpus = alltpus + ',' + str(i + 1) + alltpus = alltpus + ']' col.append( rich.padding.Padding( - rich.align.Align( - 'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), + rich.align.Align(alltpus, "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -201,17 +210,18 @@ def test_single_host_replicated_tpu(self): @unittest.skipIf(xr.device_type() != 'CPU', f"Requires PJRT_DEVICE set to `CPU`.") def test_debugging_spmd_single_host_tiled_cpu(self): - from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + from torch_xla.distributed.spmd.debugging import visualize_sharding device = xm.xla_device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) mesh = self._get_mesh(mesh_shape) - t = torch.randn(8, 4, device=device) - partition_spec = (0, 1) - xs.mark_sharding(t, mesh, partition_spec) + + partition_spec = (0, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, (0, None)) sharding = torch_xla._XLAC._get_xla_sharding_spec(t) - generated_table = visualize_tensor_sharding(t) + generated_table = visualize_sharding(sharding) console = rich.console.Console() with console.capture() as capture: console.print(generated_table) @@ -219,13 +229,14 @@ def test_debugging_spmd_single_host_tiled_cpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( @@ -242,7 +253,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self): @unittest.skipIf(xr.device_type() != 'CPU', f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_partial_replication_cpu(self): - from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + from torch_xla.distributed.spmd.debugging import visualize_sharding device = xm.xla_device() num_devices = self.n_devices mesh_shape = (1, num_devices) @@ -253,7 +264,7 @@ def test_single_host_partial_replication_cpu(self): t = torch.randn(8, 32, device=device) xs.mark_sharding(t, mesh, (0, None)) sharding = torch_xla._XLAC._get_xla_sharding_spec(t) - generated_table = visualize_tensor_sharding(t) + generated_table = visualize_sharding(sharding) console = rich.console.Console() with console.capture() as capture: console.print(generated_table) @@ -261,13 +272,14 @@ def test_single_host_partial_replication_cpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( @@ -284,7 +296,7 @@ def test_single_host_partial_replication_cpu(self): @unittest.skipIf(xr.device_type() != 'CPU', f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_replicated_cpu(self): - from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + from torch_xla.distributed.spmd.debugging import visualize_sharding device = xm.xla_device() num_devices = self.n_devices mesh_shape = (1, num_devices) @@ -295,7 +307,7 @@ def test_single_host_replicated_cpu(self): t = torch.randn(8, 32, device=device) xs.mark_sharding(t, mesh, partition_spec_replicated) sharding = torch_xla._XLAC._get_xla_sharding_spec(t) - generated_table = visualize_tensor_sharding(t) + generated_table = visualize_sharding(sharding) console = rich.console.Console() with console.capture() as capture: console.print(generated_table) @@ -303,13 +315,14 @@ def test_single_host_replicated_cpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( @@ -329,8 +342,9 @@ def test_single_host_replicated_cpu(self): # e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate} # e.g.: sharding={replicated} - @unittest.skipIf(xr.device_type() != 'TPU', - f"Requires PJRT_DEVICE set to `TPU`.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_debugging_spmd_multi_host_tiled_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}' @@ -342,94 +356,111 @@ def test_debugging_spmd_multi_host_tiled_tpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU 0', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 0', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 4', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 4', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 8', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 8', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 12', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 12', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 2', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 2', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 6', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 6', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 10', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 10', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 14', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 14', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU 1', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 1', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 5', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 5', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 9', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 9', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 13', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 13', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 3', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 3', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 7', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 7', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 11', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 11', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 15', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 15', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -439,8 +470,9 @@ def test_debugging_spmd_multi_host_tiled_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() != 'TPU', - f"Requires PJRT_DEVICE set to `TPU`.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_multi_host_partial_replication_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}' @@ -452,66 +484,75 @@ def test_multi_host_partial_replication_tpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [0, 1]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [0, 1]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [4, 5]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [4, 5]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [8, 9]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [8, 9]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [12, 13]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [12, 13]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [2, 3]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [2, 3]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [6, 7]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [6, 7]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [10, 11]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [10, 11]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [14, 15]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [14, 15]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -521,8 +562,11 @@ def test_multi_host_partial_replication_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() != 'TPU', - f"Requires PJRT_DEVICE set to `TPU`.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.global_runtime_device_count() != 8, + f"Limit test num_devices to 8 for function consistency") def test_multi_host_replicated_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{replicated}' @@ -534,19 +578,21 @@ def test_multi_host_replicated_tpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( rich.align.Align( - 'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), - (1, 1, 1, 1), + xr.device_type() + ' [0, 1, 2, 3, 4, 5, 6, 7]', + "center", + vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) fake_console = rich.console.Console() @@ -568,13 +614,14 @@ def test_debugging_spmd_multi_host_tiled_cpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( @@ -678,13 +725,14 @@ def test_multi_host_partial_replication_cpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] col.append( rich.padding.Padding( @@ -760,14 +808,16 @@ def test_multi_host_replicated_cpu(self): color = None text_color = None + use_color = True if rich.console.Console().color_system else False fake_table = rich.table.Table( show_header=False, - show_lines=True, + show_lines=not use_color, padding=0, - highlight=True, + highlight=not use_color, pad_edge=False, - box=rich.box.SQUARE) + box=rich.box.SQUARE if not use_color else None) col = [] + # PJRT_DEVICE=CPU will only has one CPU, please update once situation change col.append( rich.padding.Padding( rich.align.Align('CPU [0]', "center", vertical="middle"),