diff --git a/tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh b/tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh
index ed7f5828585a..5c18e0de4b2a 100755
--- a/tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh
+++ b/tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh
@@ -11,7 +11,7 @@ show_help() {
     echo
     echo "Options:"
     echo "  -d, --debug    Enable debug mode to show real-time output."
-    echo "  -t, --target   Specify the target configuration (t3000 or n300). Default is n300."
+    echo "  -t, --target   Specify the target configuration (t3000, n300, or tg). Default is n300."
     echo "  -h, --help     Display this help message."
     echo
     echo "Example:"
@@ -42,8 +42,8 @@ while [ $# -gt 0 ]; do
             shift 2

             # Validate the target value
-            if [ "$TARGET" != "t3000" ] && [ "$TARGET" != "n300" ]; then
-                echo "Error: Invalid target configuration: $TARGET. Must be either 't3000' or 'n300'."
+            if [ "$TARGET" != "t3000" ] && [ "$TARGET" != "tg" ] && [ "$TARGET" != "n300" ]; then
+                echo "Error: Invalid target configuration: $TARGET. Must be one of 't3000', 'n300', or 'tg'."
                 exit 1
             fi
             ;;
diff --git a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
index 0a729b88f846..01df9cbf6f62 100644
--- a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
+++ b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
@@ -15,6 +15,9 @@
 from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
     run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
 )
+from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_TG_nightly import (
+    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows,
+)


 @skip_for_grayskull("Requires eth connected devices to run")
@@ -332,3 +335,68 @@ def test_all_gather_on_tg(
         cluster_axis=1,
         trace_mode=True,
     )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (4, 2, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        (4, 2, [1, 4, 64, 2304], 1, ttnn.TILE_LAYOUT),
+        (4, 2, [1, 4, 64, 6656], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+        ttnn.BufferType.L1,
+    ],
+)
+@pytest.mark.parametrize("replication_factor", [8])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("num_iters", [20])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
+@pytest.mark.parametrize("device_params", [{"trace_region_size": 10281600}], indirect=True)
+def test_line_reduce_scatter_on_TG_rows_post_commit(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    math_op,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters,
+):
+    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        math_op,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_reduce_scatter_instances=replication_factor,
+        cluster_axis=1,
+        trace_mode=True,
+    )
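A minimal usage sketch for the updated profile script, assuming it is run from the tt-metal repository root; only the -t/-d flags documented in the script's own help text are used:

    # Profile reduce-scatter on a TG system via the newly added target
    ./tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh -t tg

    # Existing targets are unaffected; -d streams real-time output
    ./tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh -t n300 -d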
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
index 1b5bfe8f6724..5dc6a377e9f4 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
@@ -45,6 +45,62 @@ def print_tile_corners_of_tensor(t):
         print(f"{str_vals}")


+def run_with_trace(
+    mesh_device,
+    all_gather_topology,
+    input_tensor,
+    scatter_dim,
+    num_links,
+    math_op,
+    cluster_axis,
+    output_mem_config,
+    n_worker=None,
+    n_buffer=None,
+    num_iter=20,
+):
+    # Compile run: execute the op once so the program is compiled before tracing
+    logger.info("Compiling model")
+    tt_out_tensor = ttnn.reduce_scatter(
+        input_tensor,
+        scatter_dim=scatter_dim,
+        cluster_axis=cluster_axis,
+        mesh_device=mesh_device,
+        math_op=math_op,
+        num_links=num_links,
+        memory_config=output_mem_config,
+        topology=all_gather_topology,
+    )
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    # Capture trace
+    logger.info("Capturing trace")
+    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
+    for i in range(num_iter):
+        tt_out_tensor = ttnn.reduce_scatter(
+            input_tensor,
+            scatter_dim=scatter_dim,
+            cluster_axis=cluster_axis,
+            mesh_device=mesh_device,
+            math_op=math_op,
+            num_links=num_links,
+            memory_config=output_mem_config,
+            topology=all_gather_topology,
+        )
+    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    # Run the op
+    logger.info("Starting Trace perf test...")
+    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
+    ttnn.release_trace(mesh_device, trace_id)
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    return tt_out_tensor
+
+
 def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
     mesh_device,
     num_devices_per_line,
@@ -63,6 +119,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
     num_reduce_scatter_instances: int = 1,
     num_iters: int = 1,
     cluster_axis: int = 0,
+    trace_mode=False,
 ):
     if len(mesh_device.get_devices()) != 32:
         pytest.skip("Not TG!")
@@ -163,18 +220,24 @@
             mesh_device=mesh_device,
             math_op=math_op,
             num_links=num_links,
-            memory_config=output_mem_config,
-            topology=ttnn.Topology.Linear,
+            output_mem_config=output_mem_config,
+            all_gather_topology=ttnn.Topology.Linear,
+            num_iter=num_iters,
         )
-        ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
-        for d in mesh_device.get_devices():
-            ttnn.synchronize_device(d)
-
-        logger.info("Starting Trace perf test...")
-        ttnn.execute_trace(mesh_device, trace_id, blocking=False)
-        ttnn.release_trace(mesh_device, trace_id)
-        for d in mesh_device.get_devices():
-            ttnn.synchronize_device(d)
+    else:
+        for _ in range(num_iters):
+            ttnn_tensor_out = ttnn.reduce_scatter(
+                ttnn_tensor,
+                scatter_dim=dim,
+                cluster_axis=cluster_axis,
+                mesh_device=mesh_device,
+                math_op=math_op,
+                num_links=num_links,
+                memory_config=output_mem_config,
+                topology=ttnn.Topology.Linear,
+            )
+            for d in mesh_device.get_devices():
+                ttnn.synchronize_device(d)

     # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out)
     tt_output_tensor = ttnn.to_torch(
@@ -290,7 +353,6 @@ def test_line_reduce_scatter_on_TG_rows_post_commit(
 @pytest.mark.parametrize("replication_factor", [4])
 @pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
 @pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
-@pytest.mark.parametrize("device_params", [{"trace_region_size": 10281600}], indirect=True)
[{"trace_region_size": 10281600}], indirect=True) def test_line_reduce_scatter_on_TG_cols_post_commit( mesh_device, num_devices,