Skip to content

Commit

Permalink
#10755: visualize DeviceMesh via ttnn.visualize(device_mesh) api (#…
Browse files Browse the repository at this point in the history
…11094)

Co-authored-by: jchu <[email protected]>
  • Loading branch information
cfjchu and cfjchu authored Aug 5, 2024
1 parent 0a11fb8 commit 12c0051
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,3 +1331,8 @@ def test_device_line_all_gather_8x4_data(device_mesh, cluster_axis: int, dim: in
expected = full_tensor[..., row_index * tile_size : (row_index + 1) * tile_size, :]

assert torch.allclose(device_tensor_torch, expected, atol=1e-3)


@pytest.mark.parametrize("device_mesh", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_visualize_device_mesh(device_mesh):
ttnn.visualize_device_mesh(device_mesh)
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def manage_config(name, value):
MeshToTensor,
ConcatMeshToTensor,
ListMeshToTensor,
visualize_device_mesh,
)

from ttnn.core import (
Expand Down
34 changes: 34 additions & 0 deletions ttnn/ttnn/multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,40 @@ def get_device_mesh_core_grid(device_mesh):
DeviceMesh.core_grid = property(get_device_mesh_core_grid)


def visualize_device_mesh(device_mesh):
from rich import box, padding
from rich.align import Align
from rich.console import Console
from rich.table import Table

# Setup rich table
rows, cols = device_mesh.shape
mesh_table = Table(
title=f"DeviceMesh(rows={rows}, cols={cols}):",
show_header=False,
show_footer=False,
box=box.SQUARE,
expand=False,
show_lines=True,
padding=(0, 0),
)

for _ in range(cols):
mesh_table.add_column(justify="center", vertical="middle")

# Populate table
for row_idx in range(rows):
row_cells = []
for col_idx in range(cols):
device = device_mesh.get_device(row_idx, col_idx)
cell_content = f"Dev. ID: {device.id()}\n ({row_idx}, {col_idx})" if device else "Empty"
cell = padding.Padding(Align(cell_content, "center", vertical="middle"), (0, 0))
row_cells.append(cell)
mesh_table.add_row(*row_cells)

Console().print(mesh_table)


def get_num_devices() -> List[int]:
return ttnn._ttnn.deprecated.device.GetNumAvailableDevices()

Expand Down

0 comments on commit 12c0051

Please sign in to comment.