Skip to content

Commit

Permalink
Update test_spmd_debugging.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 28, 2023
1 parent 24ae8f6 commit e58863a
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions test/spmd/test_spmd_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def test_debugging_spmd_single_host_tiled(self):
Mesh.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
generated_table = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
# console.print(generated_table) #TODO
output = console.file.getvalue()
console = Console()
with console.capture() as capture:
console.print(generated_table)
output = capture.get()

fake_console = rich.console.Console(file=io.StringIO(), width=120)
# fake_console = rich.console.Console(file=io.StringIO(), width=120)
color = None
text_color = None
fask_table = rich.table.Table(
Expand Down Expand Up @@ -103,8 +104,10 @@ def test_debugging_spmd_single_host_tiled(self):
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
fake_console = Console()
with fake_console.capture() as fake_capture:
fake_console.print(fake_table)
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
Expand Down Expand Up @@ -177,9 +180,10 @@ def test_single_host_replicated(self):
xs.mark_sharding(t, mesh, partition_spec_replicated)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
generated_table = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
# console.print(generated_table) #TODO
output = console.file.getvalue()
console = Console()
with console.capture() as capture:
console.print(generated_table)
output = capture.get()

color = None
text_color = None
Expand All @@ -198,9 +202,10 @@ def test_single_host_replicated(self):
(0, 0, 1, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console = rich.console.Console(file=io.StringIO(), width=120)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
fake_console = Console()
with fake_console.capture() as fake_capture:
fake_console.print(fake_table)
fake_output = fake_capture.get()
assert output == fake_output


Expand Down

0 comments on commit e58863a

Please sign in to comment.