diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index b61870692de..3fff925c7a0 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -45,11 +45,12 @@ def test_debugging_spmd_single_host_tiled(self): Mesh.mark_sharding(t, mesh, partition_spec) sharding = torch_xla._XLAC._get_xla_sharding_spec(t) generated_table = visualize_tensor_sharding(t) - console = rich.console.Console(file=io.StringIO(), width=120) - # console.print(generated_table) #TODO - output = console.file.getvalue() + console = Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() - fake_console = rich.console.Console(file=io.StringIO(), width=120) + # fake_console = rich.console.Console(file=io.StringIO(), width=120) color = None text_color = None fask_table = rich.table.Table( @@ -103,8 +104,10 @@ def test_debugging_spmd_single_host_tiled(self): (2, 1, 2, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fask_table.add_row(*col) - fake_console.print(fask_table) - fake_output = fake_console.file.getvalue() + fake_console = Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() assert output == fake_output @unittest.skipIf( @@ -177,9 +180,10 @@ def test_single_host_replicated(self): xs.mark_sharding(t, mesh, partition_spec_replicated) sharding = torch_xla._XLAC._get_xla_sharding_spec(t) generated_table = visualize_tensor_sharding(t) - console = rich.console.Console(file=io.StringIO(), width=120) - # console.print(generated_table) #TODO - output = console.file.getvalue() + console = Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() color = None text_color = None @@ -198,9 +202,10 @@ def test_single_host_replicated(self): (0, 0, 1, 0), style=rich.style.Style(bgcolor=color, color=text_color))) fask_table.add_row(*col) - fake_console = rich.console.Console(file=io.StringIO(), width=120) - fake_console.print(fask_table) - fake_output = fake_console.file.getvalue() + fake_console = Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() assert output == fake_output