From 0dd3f288a85038f164949a5bb1bcb11e2671d91e Mon Sep 17 00:00:00 2001 From: jchu Date: Mon, 5 Aug 2024 19:03:56 +0000 Subject: [PATCH] #10755: visualize DeviceMesh via `ttnn.visualize(device_mesh)` api --- .../test_multidevice_TG.py | 5 +++ ttnn/ttnn/__init__.py | 1 + ttnn/ttnn/multi_device.py | 39 ++++++++++++++++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py index 7680dcac764c..a11fa5d7dc93 100644 --- a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py +++ b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py @@ -1331,3 +1331,8 @@ def test_device_line_all_gather_8x4_data(device_mesh, cluster_axis: int, dim: in expected = full_tensor[..., row_index * tile_size : (row_index + 1) * tile_size, :] assert torch.allclose(device_tensor_torch, expected, atol=1e-3) + + +@pytest.mark.parametrize("device_mesh", [pytest.param((8, 4), id="8x4_grid")], indirect=True) +def test_visualize_device_mesh(device_mesh): + ttnn.visualize_device_mesh(device_mesh) diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index c66efebc6bbc..662c27cfd25e 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -166,6 +166,7 @@ def manage_config(name, value): MeshToTensor, ConcatMeshToTensor, ListMeshToTensor, + visualize_device_mesh, ) from ttnn.core import ( diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py index 8f52d65b60e6..64ae097edac5 100644 --- a/ttnn/ttnn/multi_device.py +++ b/ttnn/ttnn/multi_device.py @@ -18,6 +18,40 @@ def get_device_mesh_core_grid(device_mesh): DeviceMesh.core_grid = property(get_device_mesh_core_grid) +def visualize_device_mesh(device_mesh): + from rich import box, padding + from rich.align import Align + from rich.console import Console + from rich.table import Table + + # Setup rich table + rows, cols = device_mesh.shape + mesh_table = Table( + title=f"DeviceMesh(rows={rows}, cols={cols}):", + show_header=False, + show_footer=False, + box=box.SQUARE, + expand=False, + show_lines=True, + padding=(0, 0), + ) + + for _ in range(cols): + mesh_table.add_column(justify="center", vertical="middle") + + # Populate table + for row_idx in range(rows): + row_cells = [] + for col_idx in range(cols): + device = device_mesh.get_device(row_idx, col_idx) + cell_content = f"Dev. ID: {device.id()}\n ({row_idx}, {col_idx})" if device else "Empty" + cell = padding.Padding(Align(cell_content, "center", vertical="middle"), (0, 0)) + row_cells.append(cell) + mesh_table.add_row(*row_cells) + + Console().print(mesh_table) + + def get_num_devices() -> List[int]: return ttnn._ttnn.deprecated.device.GetNumAvailableDevices() @@ -163,7 +197,7 @@ def __init__(self, device_mesh, shard_grid, shard_dimensions): def map(self, tensor): import torch - Y, X = self.shard_dimensions + Y, X = 0, 1 # Returns list of tensors to map to row-major ordering of chips in shard grid if self.shard_dimensions[Y] is None: row_tensors = [tensor.clone() for _ in range(self.shard_grid[Y])] @@ -182,6 +216,9 @@ def config(self): return { "strategy": "shard", "shard_dim": f"{self.shard_dimensions[0] if self.shard_dimensions[0] else self.shard_dimensions[1]}", + # "strategy": "shard_2d", + # "shard_grid_y": f"{self.shard_grid[0]}", + # "shard_grid_x": f"{self.shard_grid[1]}", }