Skip to content

Commit

Permalink
#10755: visualize DeviceMesh via ttnn.visualize(device_mesh) api
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Aug 5, 2024
1 parent b8dde20 commit 0dd3f28
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
5 changes: 5 additions & 0 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,3 +1331,8 @@ def test_device_line_all_gather_8x4_data(device_mesh, cluster_axis: int, dim: in
expected = full_tensor[..., row_index * tile_size : (row_index + 1) * tile_size, :]

assert torch.allclose(device_tensor_torch, expected, atol=1e-3)


@pytest.mark.parametrize("device_mesh", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_visualize_device_mesh(device_mesh):
ttnn.visualize_device_mesh(device_mesh)
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def manage_config(name, value):
MeshToTensor,
ConcatMeshToTensor,
ListMeshToTensor,
visualize_device_mesh,
)

from ttnn.core import (
Expand Down
39 changes: 38 additions & 1 deletion ttnn/ttnn/multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,40 @@ def get_device_mesh_core_grid(device_mesh):
DeviceMesh.core_grid = property(get_device_mesh_core_grid)


def visualize_device_mesh(device_mesh):
from rich import box, padding
from rich.align import Align
from rich.console import Console
from rich.table import Table

# Setup rich table
rows, cols = device_mesh.shape
mesh_table = Table(
title=f"DeviceMesh(rows={rows}, cols={cols}):",
show_header=False,
show_footer=False,
box=box.SQUARE,
expand=False,
show_lines=True,
padding=(0, 0),
)

for _ in range(cols):
mesh_table.add_column(justify="center", vertical="middle")

# Populate table
for row_idx in range(rows):
row_cells = []
for col_idx in range(cols):
device = device_mesh.get_device(row_idx, col_idx)
cell_content = f"Dev. ID: {device.id()}\n ({row_idx}, {col_idx})" if device else "Empty"
cell = padding.Padding(Align(cell_content, "center", vertical="middle"), (0, 0))
row_cells.append(cell)
mesh_table.add_row(*row_cells)

Console().print(mesh_table)


def get_num_devices() -> List[int]:
return ttnn._ttnn.deprecated.device.GetNumAvailableDevices()

Expand Down Expand Up @@ -163,7 +197,7 @@ def __init__(self, device_mesh, shard_grid, shard_dimensions):
def map(self, tensor):
import torch

Y, X = self.shard_dimensions
Y, X = 0, 1
# Returns list of tensors to map to row-major ordering of chips in shard grid
if self.shard_dimensions[Y] is None:
row_tensors = [tensor.clone() for _ in range(self.shard_grid[Y])]
Expand All @@ -182,6 +216,9 @@ def config(self):
return {
"strategy": "shard",
"shard_dim": f"{self.shard_dimensions[0] if self.shard_dimensions[0] else self.shard_dimensions[1]}",
# "strategy": "shard_2d",
# "shard_grid_y": f"{self.shard_grid[0]}",
# "shard_grid_x": f"{self.shard_grid[1]}",
}


Expand Down

0 comments on commit 0dd3f28

Please sign in to comment.