Commit: Add all gather perf to TG

Aswinmcw committed Nov 13, 2024
1 parent 16123a1 · commit f38e3ba
Showing 3 changed files with 141 additions and 9 deletions.
@@ -11,7 +11,7 @@ show_help() {
echo
echo "Options:"
echo " -d, --debug Enable debug mode to show real-time output."
-echo " -t, --target Specify the target configuration (t3000 or n300). Default is n300."
+echo " -t, --target Specify the target configuration (t3000 or n300 or tg). Default is n300."
echo " -h, --help Display this help message."
echo
echo "Example:"
@@ -42,8 +42,8 @@ while [ $# -gt 0 ]; do
shift 2

# Validate the target value
-if [ "$TARGET" != "t3000" ] && [ "$TARGET" != "n300" ]; then
-    echo "Error: Invalid target configuration: $TARGET. Must be either 't3000' or 'n300'."
+if [ "$TARGET" != "t3000" ] && [ "$TARGET" != "tg" ] && [ "$TARGET" != "n300" ]; then
+    echo "Error: Invalid target configuration: $TARGET. Must be 't3000' or 'n300' or 'tg'."
    exit 1
fi
;;
tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py (65 additions, 0 deletions)
@@ -12,6 +12,9 @@
from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_post_commit import (
    run_reduce_scatter_test,
)
+from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
+)


@skip_for_grayskull("Requires eth connected devices to run")
@@ -266,3 +269,65 @@ def test_reduce_scatter_on_n300(
        enable_async=enable_async,
        trace_mode=True,
    )


+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (4, 3, [4, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (4, 3, [1, 1, 32, 16384 * 4], 3, ttnn.TILE_LAYOUT),
+        (4, 3, [1, 4, 32, 6656], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+        ttnn.BufferType.L1,
+    ],
+)
+@pytest.mark.parametrize("replication_factor", [8])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+@pytest.mark.parametrize("device_params", [{"trace_region_size": 266240}], indirect=True)
+def test_all_gather_on_tg(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=1,
+):
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_all_gather_instances=replication_factor,
+        cluster_axis=1,
+        trace_mode=True,
+    )
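
For orientation (an editorial note, not part of the commit): on the 8x4 TG mesh, cluster_axis=1 with num_devices=4 runs one line all-gather along each row of four devices, and replication_factor=8 replicates that gather across the eight rows. The parametrized shapes are per-chip output shapes, so each chip's input shard should be the output divided by num_devices along the gather dim, assuming standard all-gather semantics. A small sketch of that arithmetic, with a hypothetical helper name:

# Hypothetical helper, not from the commit: derive the per-chip input shard
# from a per-chip output shape, assuming the gather concatenates num_devices
# shards along `dim`.
def per_chip_input_shape(per_chip_output_shape, dim, num_devices):
    shape = list(per_chip_output_shape)
    assert shape[dim] % num_devices == 0, "output dim must split evenly"
    shape[dim] //= num_devices
    return shape

print(per_chip_input_shape([4, 1, 32, 1280], dim=0, num_devices=4))
# -> [1, 1, 32, 1280]: the shard each of the 4 devices in a row contributes

Assuming a standard pytest setup, the new cases can be selected with `pytest tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py -k test_all_gather_on_tg`.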
tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py (path inferred from the import added above)
@@ -45,6 +45,59 @@ def print_tile_corners_of_tensor(t):
print(f"{str_vals}")


+def run_with_trace(
+    mesh_device,
+    all_gather_topology,
+    input_tensor,
+    dim,
+    num_links,
+    cluster_axis,
+    output_mem_config,
+    n_worker=None,
+    n_buffer=None,
+    num_iter=20,
+):
+    # Compile Run
+    logger.info("Compiling model")
+    tt_out_tensor = ttnn.all_gather(
+        input_tensor,
+        dim=dim,
+        cluster_axis=cluster_axis,
+        mesh_device=mesh_device,
+        num_links=num_links,
+        memory_config=output_mem_config,
+        topology=all_gather_topology,
+    )
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    # Capture trace
+    logger.info("Capturing trace")
+    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
+    for i in range(num_iter):
+        tt_out_tensor = ttnn.all_gather(
+            input_tensor,
+            dim=dim,
+            cluster_axis=cluster_axis,
+            mesh_device=mesh_device,
+            num_links=num_links,
+            memory_config=output_mem_config,
+            topology=all_gather_topology,
+        )
+    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    # Run the op
+    logger.info("Starting Trace perf test...")
+    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
+    ttnn.release_trace(mesh_device, trace_id)
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    return tt_out_tensor
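
For context (an editorial sketch, not part of the commit): run_with_trace follows the usual ttnn trace lifecycle of one eager run to compile kernels and warm the program cache, a capture pass that records the op num_iter times, and a single non-blocking replay. Stripped to its skeleton, with some_op standing in for the traced workload (a placeholder, not a commit API):

# Minimal sketch of the ttnn trace lifecycle used above; `some_op` is assumed
# to enqueue ttnn work on `mesh_device` and return an output tensor.
def run_traced(mesh_device, some_op, num_iter=20):
    out = some_op()  # eager run: compiles kernels before anything is recorded
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    for _ in range(num_iter):
        out = some_op()  # recorded into the trace region rather than run eagerly
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    ttnn.execute_trace(mesh_device, trace_id, blocking=False)  # replay all iterations
    ttnn.release_trace(mesh_device, trace_id)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)  # wait for the replay to drain
    return out

This is presumably also why test_all_gather_on_tg passes a device_params trace_region_size: the captured commands have to fit in the reserved trace region.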


def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
    mesh_device,
    num_devices_per_line,
@@ -63,6 +116,8 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
    num_iters: int = 1,
    cluster_axis: int = 0,
    tile=(32, 32),
+    trace_mode=False,
+    debug=False,
):
    if len(mesh_device.get_devices()) != 32:
        pytest.skip("Not TG!")
@@ -120,16 +175,28 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
    ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

    # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
-    for _ in range(num_iters):
-        ttnn_tensor_out = ttnn.all_gather(
-            ttnn_tensor,
-            dim=dim,
-            cluster_axis=cluster_axis,
-            mesh_device=mesh_device,
-            num_links=num_links,
-            memory_config=output_mem_config,
-            topology=ttnn.Topology.Linear,
-        )
+    if trace_mode:
+        ttnn_tensor_out = run_with_trace(
+            input_tensor=ttnn_tensor,
+            dim=dim,
+            cluster_axis=cluster_axis,
+            mesh_device=mesh_device,
+            num_links=num_links,
+            output_mem_config=output_mem_config,
+            all_gather_topology=ttnn.Topology.Linear,
+            num_iter=num_iters,
+        )
+    else:
+        for _ in range(num_iters):
+            ttnn_tensor_out = ttnn.all_gather(
+                ttnn_tensor,
+                dim=dim,
+                cluster_axis=cluster_axis,
+                mesh_device=mesh_device,
+                num_links=num_links,
+                memory_config=output_mem_config,
+                topology=ttnn.Topology.Linear,
+            )

    # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out)
    tt_output_tensor = ttnn.to_torch(
@@ -150,7 +217,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
        if not eq and debug is True:
            logger.error(f"found mismatches")
            report_mismatches(tt_output_tensor, output_golden, 100)
-            print_tile_corners_of_tensor(output_tensor)
+            print_tile_corners_of_tensor(tt_output_tensor)
        else:
            eq, output = comp_pcc(tt_output_tensor, output_golden)
            if not eq: