From 2c15e57fdb4e42a175c08abd49daf0e1e4e5d413 Mon Sep 17 00:00:00 2001 From: jchu Date: Mon, 5 Aug 2024 19:03:56 +0000 Subject: [PATCH] #10755: visualize DeviceMesh via `ttnn.visualize(device_mesh)` api --- .../test_multidevice_TG.py | 5 +++ ttnn/ttnn/__init__.py | 1 + ttnn/ttnn/multi_device.py | 34 +++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py index 7680dcac764..a11fa5d7dc9 100644 --- a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py +++ b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py @@ -1331,3 +1331,8 @@ def test_device_line_all_gather_8x4_data(device_mesh, cluster_axis: int, dim: in expected = full_tensor[..., row_index * tile_size : (row_index + 1) * tile_size, :] assert torch.allclose(device_tensor_torch, expected, atol=1e-3) + + +@pytest.mark.parametrize("device_mesh", [pytest.param((8, 4), id="8x4_grid")], indirect=True) +def test_visualize_device_mesh(device_mesh): + ttnn.visualize_device_mesh(device_mesh) diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index c66efebc6bb..662c27cfd25 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -166,6 +166,7 @@ def manage_config(name, value): MeshToTensor, ConcatMeshToTensor, ListMeshToTensor, + visualize_device_mesh, ) from ttnn.core import ( diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py index 8f52d65b60e..2fb7a1fc1a3 100644 --- a/ttnn/ttnn/multi_device.py +++ b/ttnn/ttnn/multi_device.py @@ -18,6 +18,40 @@ def get_device_mesh_core_grid(device_mesh): DeviceMesh.core_grid = property(get_device_mesh_core_grid) +def visualize_device_mesh(device_mesh): + from rich import box, padding + from rich.align import Align + from rich.console import Console + from rich.table import Table + + # Setup rich table + rows, cols = device_mesh.shape + mesh_table = Table( + title=f"DeviceMesh(rows={rows}, cols={cols}):", + show_header=False, + show_footer=False, + box=box.SQUARE, + expand=False, + show_lines=True, + padding=(0, 0), + ) + + for _ in range(cols): + mesh_table.add_column(justify="center", vertical="middle") + + # Populate table + for row_idx in range(rows): + row_cells = [] + for col_idx in range(cols): + device = device_mesh.get_device(row_idx, col_idx) + cell_content = f"Dev. ID: {device.id()}\n ({row_idx}, {col_idx})" if device else "Empty" + cell = padding.Padding(Align(cell_content, "center", vertical="middle"), (0, 0)) + row_cells.append(cell) + mesh_table.add_row(*row_cells) + + Console().print(mesh_table) + + def get_num_devices() -> List[int]: return ttnn._ttnn.deprecated.device.GetNumAvailableDevices()