diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_matmul.py b/tests/ttnn/unit_tests/operations/test_all_gather_matmul.py index 1a68b5f72b1..57447684254 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather_matmul.py @@ -9,132 +9,144 @@ import ttnn from ttnn import ShardTensorToMesh from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 -import itertools +from models.utility_functions import skip_for_grayskull, skip_for_wormhole_b0 from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case def run_all_gather_matmul_on_t3000_impl( t3k_mesh_device, num_devices, - input_shape, + ag_output_shape, dim, num_links, - input_dtype, + ag_input_dtype, layout, + # Matmul params matmul_output_dim, - mem_config, - function_level_defaults, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + # Memory configs + mem_config_input, + mem_config_ag, + mem_config_mm, + mem_config_weights=None, num_iters=1, ): + # Set the default config + if mem_config_weights is None: + mem_config_weights = mem_config_ag + # Skip unsupported cases (is_known_failure, message) = is_unsupported_case( - input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout + ag_output_shape, dim, mem_config_ag, num_devices, num_links, ag_input_dtype, layout ) if is_known_failure: pytest.skip(f"Skipping unsupported case {message}.") devices = t3k_mesh_device.get_devices() - logger.info(f"Input shape: {input_shape}") + logger.info(f"All Gather output shape: {ag_output_shape}") logger.info(f"dim: {dim}") - # Create input tensor for the all gather - _, _, _, hidden_dim = input_shape - input_tensor = torch.rand(input_shape).bfloat16() + ##### Create input tensor for the all gather ##### + _, _, _, hidden_dim = ag_output_shape + input_tensor = torch.randn(ag_output_shape).float() input_tensors = torch.chunk(input_tensor, num_devices, dim) tt_input_tensors = [] for i, t in enumerate(input_tensors): - tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config)) + tt_input_tensors.append(ttl.tensor.Tensor(t, ag_input_dtype).to(layout).to(devices[i], mem_config_input)) input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) - # Create the weight matrix for the matmul - weights_tensor = torch.rand([1, 1, hidden_dim, matmul_output_dim * num_devices]).bfloat16() + ##### Create the weight matrix for the matmul ##### + weights_tensor = torch.randn([1, 1, hidden_dim, matmul_output_dim * num_devices]).float() weight_tt = ttnn.as_tensor( weights_tensor, - dtype=input_dtype, - layout=ttnn.TILE_LAYOUT, + dtype=matmul_weights_dtype, + layout=layout, device=t3k_mesh_device, - memory_config=mem_config, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3), + memory_config=mem_config_weights, + mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim), ) - # torch matmul output - matmul_output = torch.chunk(torch.matmul(input_tensor, weights_tensor), num_devices, 3) - - # Configs for ttnn.matmul - program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=(1, 1), - in0_block_w=16, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size - out_subblock_h=1, # Must be divisible by per_core_M - out_subblock_w=4, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4 - per_core_M=1, # M / 
TILE_HEIGHT = 32 / 32 - per_core_N=32, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size, N = 4096 for num_device=8 - fused_activation=None, - fuse_batch=True, - mcast_in0=True, - ) + ##### Configs for ttnn.matmul ##### + if matmul_config == "matmul_1d": + core_grid = (8, 4) + program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size=core_grid, + in0_block_w=min(max_in0_block_w, hidden_dim // 32 // core_grid[0]), # how much inner dim you take each time + out_subblock_h=1, # Must be divisible by per_core_M + out_subblock_w=1, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4 + per_core_M=max(1, ag_output_shape[2] // 32 // core_grid[1]), # M / TILE_HEIGHT / Grid_Size + per_core_N=max(1, matmul_output_dim // 32 // core_grid[0]), # N / TILE_WIDTH / Grid_Size + mcast_in0=True, + fused_activation=None, # ttnn.UnaryOpType.SILU, + fuse_batch=True, + ) + elif matmul_config == "matmul_2d": + core_grid = (8, 4) + program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=core_grid, + in0_block_w=min(max_in0_block_w, hidden_dim // 32 // core_grid[0]), # how much inner dim you take each time + out_subblock_h=1, # Must be divisible by per_core_M + out_subblock_w=1, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4 + per_core_M=max(1, ag_output_shape[2] // 32 // core_grid[1]), # M / TILE_HEIGHT / Grid_Size + per_core_N=max(1, matmul_output_dim // 32 // core_grid[0]), # N / TILE_WIDTH / Grid_Size + transpose_mcast=False, + fused_activation=None, # ttnn.UnaryOpType.SILU, + fuse_batch=False, + ) + else: + raise ValueError(f"Unsupported matmul_config: {matmul_config}") + compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttl.tensor.MathFidelity.HiFi4, + math_fidelity=ttl.tensor.MathFidelity.HiFi2, math_approx_mode=True, fp32_dest_acc_en=True, packer_l1_acc=True, ) - # Perform the ops + ##### Perform the torch ops ##### + matmul_output = torch.chunk(torch.matmul(input_tensor, weights_tensor), num_devices, dim) + + ##### Perform the TT ops ##### for i in range(num_iters): - # all_gather - # tt_out_tensor = ttnn.all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config) + # # all_gather + # tt_all_gather_out_tensor = ttnn.all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config_ag) - # matmul + # # matmul # tt_matmul_output = ttnn.matmul( - # tt_out_tensor, + # tt_all_gather_out_tensor, # weight_tt, - # memory_config=mem_config, + # memory_config=mem_config_mm, # program_config=program_config, # compute_kernel_config=compute_kernel_config, # ) # Test ttnn all_gather_matmul - tt_all_gather_out_tensor, _, tt_datacopy_out_tensor = ttl.all_gather_matmul( + tt_all_gather_out_tensor, tt_matmul_output, tt_datacopy_out_tensor = ttl.all_gather_matmul( input_tensor_mesh, weight_tt, dim, - (0, 1), + (0, 4), num_links=num_links, - memory_config=mem_config, + memory_config_ag=mem_config_ag, + memory_config_mm=mem_config_mm, program_config=program_config, compute_kernel_config=compute_kernel_config, ) logger.info(f"Done iteration {i}") - # print("Checking outputs for All Gather") - # for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): - # tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - # if input_dtype == ttl.tensor.DataType.BFLOAT16: - # eq, output = comp_equal(tt_output_tensor, input_tensor) - # else: - # eq, output = comp_pcc(tt_output_tensor, input_tensor) - # 
logger.info(f"Output {i}: {output}") - # if not eq: - # logger.error(f"output mismatch for tensor {i}") - # assert eq, f"{i} FAILED: {output}" - - # print("Checking outputs for Matmul") - # for i, t in enumerate(ttnn.get_device_tensors(tt_matmul_output)): - # tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - - # eq, output = comp_pcc(tt_output_tensor, matmul_output[i]) - # logger.info(f"Output {i}: {output}") - # if not eq: - # logger.error(f"output mismatch for tensor {i}") - # assert eq, f"{i} FAILED: {output}" + # Synchronize the devices + for d in devices: + ttnn.synchronize_device(d) + ##### Compare the outputs ##### print("Checking outputs for All Gather Matmul (All Gather)") for i, t in enumerate(ttnn.get_device_tensors(tt_all_gather_out_tensor)): tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - if input_dtype == ttl.tensor.DataType.BFLOAT16: + if ag_input_dtype == ttl.tensor.DataType.BFLOAT16: eq, output = comp_equal(tt_output_tensor, input_tensor) else: eq, output = comp_pcc(tt_output_tensor, input_tensor) @@ -143,25 +155,40 @@ def run_all_gather_matmul_on_t3000_impl( logger.error(f"output mismatch for tensor {i}") assert eq, f"{i} FAILED: {output}" - print("Checking outputs for All Gather Matmul (Datacopy)") - for i, t in enumerate(ttnn.get_device_tensors(tt_datacopy_out_tensor)): + # print("Checking outputs for All Gather Matmul (Datacopy)") + # for i, t in enumerate(ttnn.get_device_tensors(tt_datacopy_out_tensor)): + # tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + # if ag_input_dtype == ttl.tensor.DataType.BFLOAT16: + # eq, output = comp_equal(tt_output_tensor, input_tensor) + # else: + # eq, output = comp_pcc(tt_output_tensor, input_tensor) + # logger.info(f"Output {i}: {output}") + # if not eq: + # logger.error(f"output mismatch for tensor {i}") + # assert eq, f"{i} FAILED: {output}" + + print("Checking outputs for Matmul") + for i, t in enumerate(ttnn.get_device_tensors(tt_matmul_output)): tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - if input_dtype == ttl.tensor.DataType.BFLOAT16: - eq, output = comp_equal(tt_output_tensor, input_tensor) - else: - eq, output = comp_pcc(tt_output_tensor, input_tensor) + + eq, output = comp_pcc(tt_output_tensor, matmul_output[i]) logger.info(f"Output {i}: {output}") if not eq: logger.error(f"output mismatch for tensor {i}") - assert eq, f"{i} FAILED: {output}" + assert eq, f"{i} FAILED: {output}" -# Enumerate the post-commit cases explicitly +# @skip_for_wormhole_b0() # Used to disable test @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "num_devices, num_links, input_shape, dim, layout, matmul_output_dim", + "matmul_config", + [ + "matmul_2d", + ], +) +@pytest.mark.parametrize( + "num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype", [ - # (8, 1, [1, 1, 32, 512], 3, ttl.tensor.Layout.TILE, 1024), # https://github.com/tenstorrent/tt-metal/issues/9686 ( 8, 1, @@ -169,6 +196,8 @@ def run_all_gather_matmul_on_t3000_impl( 3, ttl.tensor.Layout.TILE, 1024, + 2, + ttl.tensor.DataType.BFLOAT16, ), ( 8, @@ -177,42 +206,325 @@ def run_all_gather_matmul_on_t3000_impl( 3, ttl.tensor.Layout.TILE, 1024, + 16, + ttl.tensor.DataType.BFLOAT16, + ), + ( + 8, + 1, + [1, 1, 32, 1024 * 16], + 3, + ttl.tensor.Layout.TILE, + 1024, + 16, # NOTE: 64 for some reason gives lower perf + ttl.tensor.DataType.BFLOAT16, + ), + ( + 8, + 1, + [1, 1, 1024, 1024 * 32], + 3, + 
ttl.tensor.Layout.TILE, + 1024, + 16, + ttl.tensor.DataType.BFLOAT16, + ), + ( # AllGather + Fused QKV Matmul llama 2k prefill + 8, + 1, + [1, 1, 2048, 8192], + 3, + ttl.tensor.Layout.TILE, + 1280, + 8, + ttl.tensor.DataType.BFLOAT16, + ), + ( # AllGather + FF1 Matmul llama 1k prefill + 8, + 1, + [1, 1, 1024, 8192], + 3, + ttl.tensor.Layout.TILE, + 4096, + 4, + ttl.tensor.DataType.BFLOAT16, ), - # ( # Removed due to unknown hang on CI, see issue # https://github.com/tenstorrent/tt-metal/issues/11617 - # 8, - # 1, - # [1, 1, 1024, 1024 * 32], - # 3, - # ttl.tensor.Layout.TILE, - # 1024, - # ), ], ) @pytest.mark.parametrize( - "input_dtype", + "ag_input_dtype", [ ttl.tensor.DataType.BFLOAT16, - # ttl.tensor.DataType.BFLOAT8_B, # https://github.com/tenstorrent/tt-metal/issues/9686 ], ) @pytest.mark.parametrize( - "mem_config", + "mem_config_input, mem_config_ag, mem_config_mm", + [ + ( + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ) + ], +) +@pytest.mark.parametrize( + "enable_async", [ - ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.DRAM), - # ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.L1), # https://github.com/tenstorrent/tt-metal/issues/9686 + True, + # False, ], ) -@pytest.mark.parametrize("enable_async", [True, False]) def test_all_gather_matmul_on_t3000_post_commit( t3k_mesh_device, num_devices, - input_shape, + ag_output_shape, + dim, + num_links, + ag_input_dtype, + layout, + matmul_output_dim, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + mem_config_input, + mem_config_ag, + mem_config_mm, + use_program_cache, + function_level_defaults, + enable_async, +): + run_all_gather_matmul_on_t3000_impl( + t3k_mesh_device, + num_devices, + ag_output_shape, + dim, + num_links, + ag_input_dtype, + layout, + matmul_output_dim, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + mem_config_input, + mem_config_ag, + mem_config_mm, + ) + + +# @skip_for_wormhole_b0() # Used to disable test +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "matmul_config", + [ + "matmul_1d", + ], +) +@pytest.mark.parametrize( + "num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype", + [ + ( + 8, + 1, + [1, 1, 32, 16 * 32], + 3, + ttl.tensor.Layout.TILE, + 1024, + 2, + ttl.tensor.DataType.BFLOAT16, + ), + ( # Llama decode FF1 + 8, + 1, + [1, 1, 32, 1024 * 8], + 3, + ttl.tensor.Layout.TILE, + 4096, + 2, # TODO: update + ttl.tensor.DataType.BFLOAT4_B, + ), + ( # Llama decode Fused QKV + 8, + 1, + [1, 1, 32, 1024 * 8], + 3, + ttl.tensor.Layout.TILE, + 1280, + 2, # TODO: update + ttl.tensor.DataType.BFLOAT4_B, + ), + ], +) +@pytest.mark.parametrize( + "ag_input_dtype", + [ + ttl.tensor.DataType.BFLOAT16, + ], +) +@pytest.mark.parametrize( + "mem_config_input, mem_config_ag, mem_config_mm", + [ + ( + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ) + ], +) +@pytest.mark.parametrize( + "enable_async", + [ + True, + False, + ], +) +def test_all_gather_matmul_1d_on_t3000_post_commit( + t3k_mesh_device, + num_devices, + ag_output_shape, + dim, + num_links, + ag_input_dtype, + layout, + 
matmul_output_dim, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + mem_config_input, + mem_config_ag, + mem_config_mm, + use_program_cache, + function_level_defaults, + enable_async, +): + run_all_gather_matmul_on_t3000_impl( + t3k_mesh_device, + num_devices, + ag_output_shape, + dim, + num_links, + ag_input_dtype, + layout, + matmul_output_dim, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + mem_config_input, + mem_config_ag, + mem_config_mm, + ) + + +# @skip_for_wormhole_b0() # Used to disable test +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "matmul_config", + [ + "matmul_1d", + ], +) +@pytest.mark.parametrize( + "num_devices, num_links, ag_output_shape, dim, layout, matmul_output_dim, max_in0_block_w, matmul_weights_dtype", + [ + ( # Llama decode Selfout + 8, + 1, + [1, 1, 32, 1024 * 8], + 3, + ttl.tensor.Layout.TILE, + 1024, + 8, + ttnn.bfloat8_b, + ), + ( + 8, + 1, + [1, 1, 32, 1024 * 8], + 3, + ttl.tensor.Layout.TILE, + 1024, + 32, + ttnn.bfloat8_b, + ), + ], +) +@pytest.mark.parametrize( + "ag_input_dtype", + [ + ttl.tensor.DataType.BFLOAT16, + ], +) +@pytest.mark.parametrize( + "mem_config_input, mem_config_ag, mem_config_mm, mem_config_weights", + [ + ( + ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 0), + ), + } + ), + [ + 32, # shard_height + 128, # shard width + ], + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ), + ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 0), + ), + } + ), + [ + 32, # shard_height + 8192 // 8, # shard_width_hidden_dim_across_8_cores + ], + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ) + ], +) +@pytest.mark.parametrize( + "enable_async", + [ + True, + False, + ], +) +def test_all_gather_matmul_1d_llama_selfout_on_t3000_post_commit( + t3k_mesh_device, + num_devices, + ag_output_shape, dim, num_links, - input_dtype, + ag_input_dtype, layout, matmul_output_dim, - mem_config, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + mem_config_input, + mem_config_ag, + mem_config_mm, + mem_config_weights, use_program_cache, function_level_defaults, enable_async, @@ -220,12 +532,17 @@ def test_all_gather_matmul_on_t3000_post_commit( run_all_gather_matmul_on_t3000_impl( t3k_mesh_device, num_devices, - input_shape, + ag_output_shape, dim, num_links, - input_dtype, + ag_input_dtype, layout, matmul_output_dim, - mem_config, - function_level_defaults, + matmul_config, + matmul_weights_dtype, + max_in0_block_w, + mem_config_input, + mem_config_ag, + mem_config_mm, + mem_config_weights, ) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 4b8975c6748..2e0599cc4f4 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -13,9 +13,11 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index 9fafe381060..373b6971c71 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -13,7 +13,7 @@ #include "tt_metal/host_api.hpp" #include "ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "ttnn/operations/ccl/ccl_common.hpp" -#include "ttnn/operations/experimental/ccl/ccl_op_fusion.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" #include "ttnn/run_operation.hpp" diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp index 8c66a6dff87..9f748d261e8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp @@ -130,12 +130,7 @@ void kernel_main() { OpSignaler op_signaler; if constexpr(fuse_op) { - op_signaler = OpSignaler( - get_compile_time_arg_val(25), - get_compile_time_arg_val(26), - get_compile_time_arg_val(27), - arg_idx - ); + op_signaler = OpSignaler(arg_idx); } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp index 114d104e852..d9ec210b87f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp @@ -7,6 +7,8 @@ #include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp" + void kernel_main() { uint32_t arg_idx = 0; @@ -45,18 +47,26 @@ void kernel_main() { constexpr uint32_t eth_sender_noc_y = get_compile_time_arg_val(19); constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(20); constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(21); + constexpr bool fuse_op = get_compile_time_arg_val(22); + + /* Args for overlapped all gather */ + OpSignaler op_signaler; + + if constexpr(fuse_op) { + op_signaler = 
OpSignaler(arg_idx); + } static_assert(half_cb_n_pages > rem_num_pages, "half_cb_n_pages must be greater than or equal to rem_num_pages"); #ifdef SHARDED_MEM_LAYOUT - constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast(get_compile_time_arg_val(22)); - constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(23); - constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(24); - constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(25); - constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(26); - constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(27); - constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(28); - constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(29) != 0; + constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast(get_compile_time_arg_val(23)); + constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(24); + constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(25); + constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(26); + constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(27); + constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(28); + constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(29); + constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(30) != 0; #endif constexpr uint32_t cb_id_in0 = tt::CB::c_in0; @@ -138,6 +148,11 @@ void kernel_main() { pop_filler_pages_from_cb(cb_id_in0, half_cb_n_pages - rem_num_pages); } + if constexpr(fuse_op) { + // Synchronize and signal that the local tensor slice is available + op_signaler.synchronize_workers_and_signal_op(); + } + // num_transfers = num_devices - 1 for (uint32_t i = 1; i < num_transfers; ++i) { if constexpr(num_full_chunks > 0) { diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index 79cd6050799..c7c27bd3ef6 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -20,7 +20,7 @@ #include #include -#include "ttnn/operations/experimental/ccl/ccl_op_fusion.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" using namespace tt::constants; @@ -207,8 +207,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( /* All gather fusion */ bool fuse_op = fused_op_signaler.has_value(); + + // Need a seperate signaler for the sender workers, to handle the first tensor slice that is locally available + std::optional fused_op_signaler_sender_workers; if (fuse_op) { - fused_op_signaler->init_fused_op(device); + fused_op_signaler_sender_workers = fused_op_signaler.value(); } auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology, num_edm_buffers_per_channel, fuse_op); auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); @@ -263,6 +266,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( constexpr 
uint32_t max_num_full_send_directions = 2; // number of worker cores is 2x this since there is 1 worker for the sender buffer and 1 worker for the receiver buffer uint32_t global_num_workers = num_links * all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions; + uint32_t global_num_workers_per_direction = global_num_workers / num_full_send_directions; uint32_t total_worker_core_pairs_used = global_num_workers; uint32_t num_input_pages = input_tensor.buffer()->size() / input_page_size; @@ -478,6 +482,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( /* All gather fusion */ if (fuse_op) { fused_op_signaler->init_all_gather(program, device, receiver_workers, receiver_worker_cores); + if (direction == 1) { + fused_op_signaler_sender_workers->init_all_gather(program, device, sender_workers, sender_worker_cores); + } } { @@ -726,12 +733,15 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( static_cast(device->ethernet_core_from_logical_core(worker_eth_sender_core).x), static_cast(device->ethernet_core_from_logical_core(worker_eth_sender_core).y), static_cast(cb_num_pages / 2), - static_cast(num_edm_buffers_per_channel) + static_cast(num_edm_buffers_per_channel), + + static_cast(fuse_op && direction == 1) }; if (is_sharded) { emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_writer_sender_ct_args, output_pages_per_shard_y, output_pages_per_shard_x); } + log_trace(tt::LogOp, "Worker {} SW CT args", b); log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram()); log_trace(tt::LogOp, "\tsender_num_transfers: {}", sender_worker_num_transfers.at(i).at(b)); @@ -774,6 +784,19 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( emit_sharded_tensor_kernel_rt_args(device, output_tensor, worker_writer_sender_rt_args); } + if (fuse_op && direction == 1) { + fused_op_signaler_sender_workers->push_all_gather_fused_op_rt_args( + worker_writer_sender_rt_args, + global_num_workers_per_direction, + b, + is_clockwise_direction ? 
0 : 1, + std::make_optional( + {fused_op_signaler->all_gather_worker_cores_noc[0], + fused_op_signaler->all_gather_worker_sync_semaphore} + ) + ); + } + log_trace(tt::LogOp, "Worker {} SW rt args", b); log_trace(tt::LogOp, "\toutput_buffer->address(): {}", output_buffer->address()); log_trace(tt::LogOp, "\tsender_eth_buffer_addrs: {}", sender_eth_buffer_addrs.at(b)); @@ -936,16 +959,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_writer_receiver_ct_args, output_pages_per_shard_y, output_pages_per_shard_x); } - if (fuse_op) { - uint32_t global_num_workers_per_direction = global_num_workers / num_full_send_directions; - fused_op_signaler->emit_all_gather_fused_op_ct_args(worker_writer_receiver_ct_args, global_num_workers_per_direction, b); - } else { - // Push dummy args so that kernel doesn't error out at compile time from the lack of args when fuse_op=false - for (uint32_t w = 0; w < experimental::ccl::AllGatherFusedOpSignaler::get_num_ct_args(); ++w) { - worker_writer_receiver_ct_args.push_back(static_cast(0)); - } - } - log_trace(tt::LogOp, "Worker {} RW ct args", b); log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram()); log_trace(tt::LogOp, "\treceiver_num_transfers: {}", receiver_worker_num_transfers.at(i).at(b)); @@ -994,7 +1007,12 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( /* All Gather fusion */ if (fuse_op) { - fused_op_signaler->emit_all_gather_fused_op_rt_args(worker_writer_receiver_rt_args, is_clockwise_direction ? 0 : 1); + fused_op_signaler->push_all_gather_fused_op_rt_args( + worker_writer_receiver_rt_args, + global_num_workers_per_direction, + b, + is_clockwise_direction ? 0 : 1 + ); } log_trace(tt::LogOp, "Worker {} RW rt args", b); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp new file mode 100644 index 00000000000..e20979f242e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp @@ -0,0 +1,179 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/program/program.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" + +namespace ttnn { +namespace experimental { +namespace ccl { + +void AllGatherFusedOpSignaler::init_fused_op( + const std::vector& fused_op_receiver_cores_noc, + const std::vector& fused_op_receiver_signal_semaphores +) { + this->fused_op_receiver_cores_noc = fused_op_receiver_cores_noc; + this->fused_op_receiver_signal_semaphores = fused_op_receiver_signal_semaphores; + this->num_fused_op_cores_to_signal = fused_op_receiver_cores_noc.size(); + + initialized_fused_op = true; +} + +void AllGatherFusedOpSignaler::init_all_gather( + Program& program, + Device const* device, + + CoreRangeSet const& all_gather_workers, + std::vector& all_gather_worker_cores +) { + // Create the sync semaphore for the all gather workers + this->all_gather_worker_sync_semaphore = CreateSemaphore(program, all_gather_workers, 0); + + // Get the noc coords for the all gather workers + this->all_gather_worker_cores_noc.clear(); + for (const auto& core : all_gather_worker_cores) { + this->all_gather_worker_cores_noc.push_back(device->worker_core_from_logical_core(core)); + } + initialized_all_gather = true; +} + +void AllGatherFusedOpSignaler::push_all_gather_fused_op_rt_args( + std::vector& out_rt_args, + + uint32_t num_workers_to_sync, + uint32_t curr_worker_index, + uint32_t all_gather_direction, + std::optional start_signal_core_sem_pair +) { + TT_ASSERT(initialized_fused_op && initialized_all_gather, "AllGatherFusedOpSignaler not initialized fully."); + + out_rt_args.push_back(static_cast(num_workers_to_sync)); + out_rt_args.push_back(static_cast(curr_worker_index)); + out_rt_args.push_back(static_cast(this->all_gather_worker_sync_semaphore)); + + // Push the worker core noc coords + for (const auto& core : this->all_gather_worker_cores_noc) { + out_rt_args.push_back(static_cast(core.x)); + out_rt_args.push_back(static_cast(core.y)); + } + + // Push the number of fused op cores to signal + out_rt_args.push_back(static_cast(this->num_fused_op_cores_to_signal)); + + // Push the fused op receiver core noc coords + for (const auto& core : this->fused_op_receiver_cores_noc) { + out_rt_args.push_back(static_cast(core.x)); + out_rt_args.push_back(static_cast(core.y)); + } + + // Push the fused op signal semaphore addrs. Direction 0: clockwise, Direction 1: counter-clockwise + out_rt_args.push_back( + static_cast(this->fused_op_receiver_signal_semaphores[all_gather_direction]) + ); + + // Push the params for the start signal. 
Only wait for/send start signal if all_gather direction is counter clockwise + bool wait_for_start_signal = !start_signal_core_sem_pair.has_value() && all_gather_direction == 1; + bool send_start_signal = start_signal_core_sem_pair.has_value() && all_gather_direction == 1; + + out_rt_args.push_back(static_cast(wait_for_start_signal)); + out_rt_args.push_back(static_cast(send_start_signal)); + + if (send_start_signal) { + out_rt_args.push_back(static_cast(start_signal_core_sem_pair->core.x)); + out_rt_args.push_back(static_cast(start_signal_core_sem_pair->core.y)); + out_rt_args.push_back(static_cast(start_signal_core_sem_pair->sem_id)); + } + +} + + +// Used to propagate semaphore information from matmul to all_gather in all_gather_matmul op +void MatmulFusedOpSignaler::init_all_gather( + uint32_t num_transfers, + uint32_t ring_size, + uint32_t start_ring_index, + uint32_t tensor_slice_shape_width, + uint32_t output_page_offset, + bool is_clockwise_direction, + + uint32_t weight_output_page_offset +) { + this->num_transfers = num_transfers; + this->ring_size = ring_size; + this->start_ring_index = start_ring_index; + this->tensor_slice_shape_width = tensor_slice_shape_width; + this->output_page_offset = output_page_offset; + this->is_clockwise_dir = is_clockwise_direction; + + this->weight_output_page_offset = weight_output_page_offset; + + initialized_all_gather = true; +} + +void MatmulFusedOpSignaler::init_fused_op( + Program& program, + Device const* device, + const std::variant& core_range_to_signal +) { + // Clear the existing receiver cores + this->fused_op_receiver_cores_noc.clear(); + + // Visit the variant to handle CoreRange and CoreRangeSet differently + std::visit([&](auto& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + // Handle CoreRange + const auto& cores = grid_to_cores(arg.start_coord, arg.end_coord, true); + for (auto& core : cores) { + this->fused_op_receiver_cores_noc.push_back(device->worker_core_from_logical_core(core)); + } + } else if constexpr (std::is_same_v) { + // Handle CoreRangeSet + for (const auto& range : arg.ranges()) { + const auto& cores = grid_to_cores(range.start_coord, range.end_coord, true); + for (auto& core : cores) { + this->fused_op_receiver_cores_noc.push_back(device->worker_core_from_logical_core(core)); + } + } + } + }, core_range_to_signal); + + // Create the semaphores + this->fused_op_receiver_signal_semaphores.push_back(CreateSemaphore(program, core_range_to_signal, 0)); + this->fused_op_receiver_signal_semaphores.push_back(CreateSemaphore(program, core_range_to_signal, 0)); + + // Set the number of fused op cores to signal + this->num_fused_op_cores_to_signal = this->fused_op_receiver_cores_noc.size(); + + initialized_fused_op = true; +} + +void MatmulFusedOpSignaler::push_matmul_fused_op_rt_args( + std::vector& out_rt_args, + bool use_in1_offset +) { + TT_ASSERT(initialized_all_gather && initialized_fused_op, "MatmulFusedOpSignaler not initialized fully."); + + out_rt_args.push_back(static_cast(this->num_transfers)); + out_rt_args.push_back(static_cast(this->ring_size)); + out_rt_args.push_back(static_cast(this->start_ring_index)); + out_rt_args.push_back(static_cast(this->tensor_slice_shape_width)); + if (use_in1_offset) { + out_rt_args.push_back(static_cast(this->weight_output_page_offset)); + out_rt_args.push_back(static_cast((this->ring_size - 1) * this->weight_output_page_offset)); + } else { + out_rt_args.push_back(static_cast(this->output_page_offset)); + out_rt_args.push_back(static_cast((this->ring_size - 
1) * this->output_page_offset)); + } + out_rt_args.push_back(static_cast(this->is_clockwise_dir)); + out_rt_args.push_back(static_cast(this->fused_op_receiver_signal_semaphores[0])); + out_rt_args.push_back(static_cast(this->fused_op_receiver_signal_semaphores[1])); +} + + + +} // namespace ccl +} // namespace experimental +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.hpp new file mode 100644 index 00000000000..4c1fab16978 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.hpp @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/program/program.hpp" + +namespace ttnn { +namespace experimental { +namespace ccl { + +struct CoreSemPair{ + CoreCoord core = {0, 0}; + uint32_t sem_id = 0; + + CoreSemPair() {} + CoreSemPair(CoreCoord core, uint32_t sem_id) : core(core), sem_id(sem_id) {} +}; + +struct AllGatherFusedOpSignaler { + uint32_t num_fused_op_cores_to_signal; + std::vector fused_op_receiver_cores_noc; + std::vector fused_op_receiver_signal_semaphores; + + /* All Gather specific */ + std::vector all_gather_worker_cores_noc; + uint32_t all_gather_worker_sync_semaphore; + + bool initialized_fused_op = false; + bool initialized_all_gather = false; + + AllGatherFusedOpSignaler() {} + + void init_fused_op( + const std::vector& fused_op_receiver_cores_noc, + const std::vector& fused_op_receiver_signal_semaphores + ); + + void init_all_gather( + Program& program, + Device const* device, + + CoreRangeSet const& all_gather_workers, + std::vector& all_gather_worker_cores + ); + + void push_all_gather_fused_op_rt_args( + std::vector& out_rt_args, + + uint32_t num_workers_to_sync, + uint32_t curr_worker_index, + uint32_t all_gather_direction, + std::optional start_signal_core_sem_pair = {} + ); + +}; + +// Used to propagate semaphore information from matmul to all_gather in all_gather_matmul op +struct MatmulFusedOpSignaler { + uint32_t num_fused_op_cores_to_signal; + std::vector fused_op_receiver_cores_noc; + std::vector fused_op_receiver_signal_semaphores; // [dir0, dir1] + + /* All Gather specs */ + uint32_t num_transfers; + uint32_t ring_size; + uint32_t start_ring_index; + uint32_t tensor_slice_shape_width; + uint32_t output_page_offset; + uint32_t last_output_page_offset; + bool is_clockwise_dir; + + uint32_t weight_output_page_offset; + + bool initialized_all_gather = false; + bool initialized_fused_op = false; + + MatmulFusedOpSignaler() {} + + void init_all_gather( + uint32_t num_transfers, + uint32_t ring_size, + uint32_t start_ring_index, + uint32_t tensor_slice_shape_width, + uint32_t output_page_offset, + bool is_clockwise_direction, + + uint32_t weight_tensor_width + ); + + void init_fused_op( + Program& program, + Device const* device, + const std::variant& core_range_to_signal + ); + + void push_matmul_fused_op_rt_args( + std::vector& out_rt_args, + bool use_in1_offset + ); +}; + + +} // namespace ccl +} // namespace experimental +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp index 5193b5d9a48..b82507920fa 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp @@ -8,24 +8,33 @@ #include "debug/assert.h" #include "debug/dprint.h" #include 
"ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include // Called by the master worker to synchronize with the slave workers FORCE_INLINE void master_sync_slaves( - const uint32_t num_workers_to_sync, /* Used to get slave worker's sem addrs */ + const uint32_t num_workers_to_sync, const uint32_t* worker_noc_coords, const uint32_t worker_sync_sem_addr, - const uint64_t remote_op_l1_semaphore_addr) { + /* Used to signal the remote op */ + const uint32_t num_fused_op_cores_to_signal, + const uint32_t* fused_op_cores_noc_coords, + const uint32_t fused_op_sem_addr, + + bool wait_for_start_signal) { // Wait for all the slaves to finish their work volatile tt_l1_ptr uint32_t* master_l1_semaphore_addr = reinterpret_cast(worker_sync_sem_addr); - noc_semaphore_wait(master_l1_semaphore_addr, num_workers_to_sync - 1); + noc_semaphore_wait(master_l1_semaphore_addr, num_workers_to_sync - 1 + (uint32_t)wait_for_start_signal); // DPRINT << "MASTER SYNCED WITH SLAVES" << ENDL(); // Send signal to op - noc_semaphore_inc(remote_op_l1_semaphore_addr, 1); + for (uint32_t i = 0; i < num_fused_op_cores_to_signal; i++) { + uint64_t remote_fused_op_l1_semaphore_addr = get_noc_addr(fused_op_cores_noc_coords[i * 2], fused_op_cores_noc_coords[i * 2 + 1], fused_op_sem_addr); + noc_semaphore_inc(remote_fused_op_l1_semaphore_addr, 1); + } // DPRINT << "MASTER SIGNALED REMOTE OP" << ENDL(); // Clear the master semaphore, so that it can be used again @@ -75,29 +84,41 @@ struct OpSignaler { uint32_t num_workers_to_sync; uint32_t* workers_noc_coords; // Worker NOC coordinates [x1, y1, x2, y2...], first one is for master uint32_t worker_sync_sem_addr; - uint64_t signal_op_sem_noc_addr; + + uint32_t num_fused_op_cores_to_signal; + uint32_t* signal_op_cores_noc_coords; + uint32_t signal_op_sem_addr; uint32_t curr_worker_is_master; + + // Params for start signal + bool wait_for_start_signal; + bool send_start_signal; + uint32_t* start_signal_receiver_core_noc; + uint32_t start_signal_receiver_sem_addr; + bool initialized = false; OpSignaler() {} - OpSignaler( - uint32_t num_workers_to_sync, - uint32_t curr_worker_index, - uint32_t worker_sync_sem_id, - uint32_t& rt_args_idx) : - num_workers_to_sync(num_workers_to_sync) { - - this-> worker_sync_sem_addr = get_semaphore(worker_sync_sem_id); + OpSignaler(uint32_t& rt_args_idx) { // Runtime args + this->num_workers_to_sync = get_arg_val(rt_args_idx++); + uint32_t curr_worker_index = get_arg_val(rt_args_idx++); + this-> worker_sync_sem_addr = get_semaphore(get_arg_val(rt_args_idx++)); this->workers_noc_coords = (uint32_t*)get_arg_addr(increment_arg_idx(rt_args_idx, this->num_workers_to_sync * 2)); // Skip over the number of workers - uint32_t op_worker_noc_x = get_arg_val(rt_args_idx++); - uint32_t op_worker_noc_y = get_arg_val(rt_args_idx++); - uint32_t signal_op_sem_addr = get_semaphore(get_arg_val(rt_args_idx++)); - // Get the remote sem addresses to signal the op - this->signal_op_sem_noc_addr = get_noc_addr(op_worker_noc_x, op_worker_noc_y, signal_op_sem_addr); + this->num_fused_op_cores_to_signal = get_arg_val(rt_args_idx++); + this->signal_op_cores_noc_coords = (uint32_t*)get_arg_addr(increment_arg_idx(rt_args_idx, this->num_fused_op_cores_to_signal * 2)); + this->signal_op_sem_addr = get_semaphore(get_arg_val(rt_args_idx++)); + + + this->wait_for_start_signal = get_arg_val(rt_args_idx++); + this->send_start_signal = get_arg_val(rt_args_idx++); + if (this->send_start_signal) { + this->start_signal_receiver_core_noc = 
(uint32_t*)get_arg_addr(increment_arg_idx(rt_args_idx, 2)); + this->start_signal_receiver_sem_addr = get_semaphore(get_arg_val(rt_args_idx++)); + } uint32_t master_worker_noc_x = this->workers_noc_coords[0]; uint32_t master_worker_noc_y = this->workers_noc_coords[1]; @@ -112,7 +133,32 @@ struct OpSignaler { ASSERT(this->initialized); if (this->curr_worker_is_master) { - master_sync_slaves(this->num_workers_to_sync, this->workers_noc_coords, this->worker_sync_sem_addr, this->signal_op_sem_noc_addr); + master_sync_slaves( + this->num_workers_to_sync, + this->workers_noc_coords, + this->worker_sync_sem_addr, + + this->num_fused_op_cores_to_signal, + this->signal_op_cores_noc_coords, + this->signal_op_sem_addr, + + this->wait_for_start_signal + ); + + // Once start signal is received, no need to wait for it again + this->wait_for_start_signal = false; + + if (this->send_start_signal) { + uint64_t remote_master_l1_semaphore_addr = get_noc_addr( + this->start_signal_receiver_core_noc[0], + this->start_signal_receiver_core_noc[1], + this->start_signal_receiver_sem_addr + ); + noc_semaphore_inc(remote_master_l1_semaphore_addr, 1); + + // Once start signal is sent, no need to send it again + this->send_start_signal = false; + } } else { slave_sync_master(this->workers_noc_coords, this->worker_sync_sem_addr); } @@ -151,3 +197,157 @@ FORCE_INLINE void advance_start_page_idx( } } + + +struct MatmulOpReceiver { + static constexpr uint32_t num_directions = 2; // ASSUMPTION: Always 2 directions + uint32_t num_tensor_slices; + + bool wait_for_op_signal; + uint32_t num_transfers; + uint32_t ring_size; + uint32_t tensor_slice_shape_width; // In tiles + uint32_t output_page_offset; + uint32_t last_output_page_offset; + + uint32_t num_blocks; + uint32_t num_blocks_per_slice; + + // Used to track internal state + std::array ring_idxs; + std::array start_page_idxs; + std::array is_clockwise_dirs; + std::array signal_op_semaphore_addr_ptrs; + uint32_t curr_dir; + uint32_t curr_transfer_idx; + + + bool initialized = false; + + MatmulOpReceiver() {} + + MatmulOpReceiver( + bool wait_for_op_signal, + uint32_t& rt_args_idx, + uint32_t num_blocks, + uint32_t tiles_per_block // Across the same dimension as tensor_slice_shape_width + ) : wait_for_op_signal(wait_for_op_signal), + num_blocks(num_blocks) + { + + // Runtime args + this->num_transfers = get_arg_val(rt_args_idx++); + this->ring_size = get_arg_val(rt_args_idx++); + uint32_t start_ring_index = get_arg_val(rt_args_idx++); + this->tensor_slice_shape_width = get_arg_val(rt_args_idx++); + this->output_page_offset = get_arg_val(rt_args_idx++); + this->last_output_page_offset = get_arg_val(rt_args_idx++); + uint32_t is_clockwise_direction = get_arg_val(rt_args_idx++); + + if (this->wait_for_op_signal) { + this->signal_op_semaphore_addr_ptrs[0] = + reinterpret_cast(get_semaphore(get_arg_val(rt_args_idx++))); + this->signal_op_semaphore_addr_ptrs[1] = + reinterpret_cast(get_semaphore(get_arg_val(rt_args_idx++))); + } + + this->num_tensor_slices = this->num_transfers * this->num_directions; + + // Setup internal states for bi-direction + this->ring_idxs[0] = start_ring_index; + this->ring_idxs[1] = start_ring_index; + + this->start_page_idxs[0] = this->ring_idxs[0] * this->output_page_offset; + this->start_page_idxs[1] = this->ring_idxs[1] * this->output_page_offset; + + this->is_clockwise_dirs[0] = is_clockwise_direction; + this->is_clockwise_dirs[1] = !is_clockwise_direction; + + this->num_blocks_per_slice = this->tensor_slice_shape_width / tiles_per_block; + 
ASSERT(this->num_tensor_slices * this->num_blocks_per_slice == this->num_blocks); + + this->curr_dir = is_clockwise_direction ? 1 : 0; // Anti-clockwise direction is the first since it has local slice + this->curr_transfer_idx = 0; + + this->initialized = true; + } + + + void update_current_block_start_tile_id( + const uint32_t& block_idx, + uint32_t& curr_block_start_tile_id, + const uint32_t& tensor_start_tile_id + ) { + ASSERT(this->initialized); + + if (block_idx % this->num_blocks_per_slice == 0) { // Aligned to the start of a tensor slice + + if (this->curr_transfer_idx != 0) { // Skip update for local slice + + // Update the start page idx of the tensor slice in curr_direction + advance_start_page_idx( + this->start_page_idxs[this->curr_dir], + this->ring_idxs[this->curr_dir], + this->ring_size, + this->is_clockwise_dirs[this->curr_dir], + this->output_page_offset, + this->last_output_page_offset + ); + } + + // Use the new start page idx to find the start tile id of the current tensor slice + curr_block_start_tile_id = tensor_start_tile_id + this->start_page_idxs[this->curr_dir]; + + // Index of the current tensor slice in a certain direction + uint32_t tensor_slice_cnt = (this->curr_transfer_idx) / this->num_directions; + + // Wait for a sempaphore signal to start processing the tensor slice + if (this->wait_for_op_signal) { + noc_semaphore_wait_min(this->signal_op_semaphore_addr_ptrs[this->curr_dir], tensor_slice_cnt + 1); + } + + // Update the relevant internal states + this->curr_transfer_idx++; + this->curr_dir = !this->curr_dir; // Change direction + } + } + + uint32_t align_to_slice_and_sync(uint32_t block_idx, uint32_t sender_id) { + ASSERT(this->initialized); + + // Align the id to the start of the tensor slice in order of processing from all gather + uint32_t block_id = this->ring_idxs[this->curr_dir]; + + if (block_idx % this->num_blocks_per_slice == 0) { // Aligned to the start of a tensor slice + + if (this->curr_transfer_idx != 0) { // Skip update for local slice + // Change direction + this->curr_dir = !this->curr_dir; + + // Update the start page idx of the tensor slice in curr_direction + // We only want to know the update for the ring index + advance_start_page_idx( + this->start_page_idxs[this->curr_dir], + this->ring_idxs[this->curr_dir], + this->ring_size, + this->is_clockwise_dirs[this->curr_dir], + this->output_page_offset, + this->last_output_page_offset + ); + } + + // Update the alignment + block_id = this->ring_idxs[this->curr_dir]; + + // Wait for a sempaphore signal to start processing the tensor slice + if (this->wait_for_op_signal && block_id == sender_id) { + uint32_t tensor_slice_cnt = (this->curr_transfer_idx) / this->num_directions; + noc_semaphore_wait_min(this->signal_op_semaphore_addr_ptrs[this->curr_dir], tensor_slice_cnt + 1); + } + + this->curr_transfer_idx++; + } + + return block_id; + } +}; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp new file mode 100644 index 00000000000..b023a12977b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp" +// #include "ttnn/cpp/ttnn/multi_device.hpp" + +namespace ttnn { +namespace operations::experimental::ccl { + + +std::vector ExecuteAllGatherMatmul::invoke( + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight_tensor, + const uint32_t dim, + const CoreCoord all_gather_core_grid_offset, + const uint32_t num_links, + const std::optional& memory_config_ag, + const std::optional& memory_config_mm, + const bool transpose_a, + const bool transpose_b, + const std::optional dtype, + const std::optional program_config, + const std::optional& activation, + const std::optional compute_kernel_config, + const std::optional core_grid +) { + return ttnn::operations::experimental::ccl::all_gather_matmul(input_tensor, weight_tensor, dim, all_gather_core_grid_offset, num_links, memory_config_ag, memory_config_mm, transpose_a, transpose_b, dtype, program_config, activation, compute_kernel_config, core_grid); +} + + +} // namespace opereations::experimental::ccl +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp similarity index 79% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_op.hpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp index 34e01113bd4..e3164b54db6 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp @@ -4,6 +4,8 @@ #pragma once +#include "ttnn/decorators.hpp" +#include "common/core_coord.h" #include "ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp" #include "ttnn/cpp/ttnn/multi_device.hpp" @@ -17,16 +19,15 @@ struct ExecuteAllGatherMatmul { const uint32_t dim, const CoreCoord all_gather_core_grid_offset, const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt, + const std::optional& memory_config_ag = std::nullopt, + const std::optional& memory_config_mm = std::nullopt, const bool transpose_a = false, const bool transpose_b = false, const std::optional dtype = std::nullopt, const std::optional program_config = std::nullopt, const std::optional& activation = std::nullopt, const std::optional compute_kernel_config = std::nullopt, - const std::optional core_grid = std::nullopt) { - return ttnn::operations::experimental::ccl::all_gather_matmul(input_tensor, weight_tensor, dim, all_gather_core_grid_offset, num_links, memory_config, transpose_a, transpose_b, dtype, program_config, activation, compute_kernel_config, core_grid); - } + const std::optional core_grid = std::nullopt); }; } // namespace opereations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp index b06b70e2575..139be18826f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp @@ -8,7 +8,7 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" -#include 
"ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_op.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.hpp" #include "ttnn/types.hpp" namespace ttnn::operations::experimental::ccl { @@ -32,7 +32,8 @@ void py_bind_all_gather_matmul(pybind11::module& module) { Keyword Args: * :attr:`num_links` (int): Number of links to use for the all-gather operation. - * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + * :attr:`memory_config_ag` (Optional[ttnn.MemoryConfig]): Memory configuration for the All Gather operation. + * :attr:`memory_config_mm` (Optional[ttnn.MemoryConfig]): Memory configuration for the Matmul operation. * :attr:`transpose_a` (bool) * :attr:`transpose_b` (bool) * :attr:`dtype` (Optional[DataType]) @@ -55,7 +56,8 @@ void py_bind_all_gather_matmul(pybind11::module& module) { py::arg("all_gather_core_grid_offset"), py::kw_only(), py::arg("num_links") = 1, - py::arg("memory_config") = std::nullopt, + py::arg("memory_config_ag") = std::nullopt, + py::arg("memory_config_mm") = std::nullopt, py::arg("transpose_a") = false, py::arg("transpose_b") = false, py::arg("dtype") = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp index cb3a046ca0b..b98d1a941dc 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -5,14 +5,9 @@ #include "common/core_coord.h" #include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/math.hpp" - #include "tt_metal/host_api.hpp" - #include "ttnn/tensor/tensor_utils.hpp" - #include "eth_l1_address_map.h" - - #include "ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp" /* All Gather Matmul fusion includes */ @@ -27,11 +22,36 @@ void AllGatherMatmul::validate(const std::vector &input_tensors, const s TT_ASSERT(input_tensors.size() == 4, "AllGatherMatmul requires 4 input tensors: [input, weight, all_gather_output, datacopy_output]"); + auto& input_tensor = input_tensors[0]; + auto& all_gather_output_tensor = input_tensors[1]; + auto& weight_tensor = input_tensors[2]; + // All Gather validate - this->all_gather_struct.validate({input_tensors[0]}); + this->all_gather_struct.validate({input_tensor}); // Matmul validate. 
- this->matmul_struct.validate({input_tensors[1], input_tensors[2]}, optional_input_tensors); + this->matmul_struct.validate({all_gather_output_tensor, weight_tensor}, optional_input_tensors); + + // All Gather Matmul validate + TT_FATAL(this->all_gather_struct.dim == 3, "AllGatherMatmul requires dim=3 for the AllGather operaitons."); + TT_FATAL(input_tensor.get_legacy_shape()[0] == 1 && input_tensor.get_legacy_shape()[1] == 1, "AllGatherMatmul requires input tensor to have batch size of 1."); + std::visit([&] (const auto& config) { + using ProgramConfigType = std::decay_t; + if (not (std::is_same_v || std::is_same_v)) { + TT_FATAL("Unsupported MatmulProgramConfig type for AllGatherMatmul."); + } + }, this->matmul_struct.program_config.value()); + + + const auto& all_gather_output_tensor_shard_spec = all_gather_output_tensor.shard_spec(); + if (all_gather_output_tensor_shard_spec.has_value()) { + + auto const& shard_grid = all_gather_output_tensor_shard_spec->grid.bounding_box(); + auto const& shard_grid_start = shard_grid.start_coord; + auto const& shard_grid_end = shard_grid.end_coord; + const uint32_t num_all_gather_output_shards = (shard_grid_end.y - shard_grid_start.y + 1) * (shard_grid_end.x - shard_grid_start.x + 1); + TT_FATAL(this->all_gather_struct.ring_size == num_all_gather_output_shards, "AllGatherMatmul requires number of tensor slices to equal the number of output shards of the all_gather."); + } } std::vector AllGatherMatmul::compute_output_shapes(const std::vector &input_tensors) const { @@ -63,7 +83,30 @@ std::vector AllGatherMatmul::create_output_tensors(const std::vector & input_tensors, const std::vector>& optional_input_tensors, std::vector &output_tensors) const { // Return the AllGatherMatmul program with callbacks - return all_gather_matmul_multi_core_with_workers(input_tensors[0], output_tensors[0], output_tensors[2], this->all_gather_struct.dim, this->all_gather_struct.num_links, this->all_gather_struct.ring_size, this->all_gather_struct.ring_index, this->all_gather_struct.receiver_device_id, this->all_gather_struct.sender_device_id, this->all_gather_struct.topology, this->all_gather_core_grid_offset); + return all_gather_matmul_multi_core_with_workers( + input_tensors[0], // input_tensor + output_tensors[0], // all_gather_output_tensor + output_tensors[2], // datacopy_output_tensor + input_tensors[2], // weight_tensor + output_tensors[1], // matmul_output_tensor + + /* All Gather Params */ + this->all_gather_struct.dim, + this->all_gather_struct.num_links, + this->all_gather_struct.ring_size, + this->all_gather_struct.ring_index, + this->all_gather_struct.receiver_device_id, + this->all_gather_struct.sender_device_id, + this->all_gather_struct.topology, + this->all_gather_core_grid_offset, + + /* Matmul Params */ + {}, // Bias + this->matmul_struct.bcast_batch.value(), + this->matmul_struct.compute_kernel_config.value(), + this->matmul_struct.program_config.value(), + this->matmul_struct.untilize_out + ); } } // namespace experimental @@ -78,7 +121,8 @@ std::vector all_gather_matmul( const uint32_t dim, const CoreCoord all_gather_core_grid_offset, const uint32_t num_links, - const std::optional& memory_config, + const std::optional& memory_config_ag, + const std::optional& memory_config_mm, const bool transpose_a, const bool transpose_b, const std::optional dtype, @@ -98,7 +142,7 @@ std::vector all_gather_matmul( operation::launch_op( - [dim, all_gather_core_grid_offset, num_links, memory_config, transpose_a, transpose_b, dtype, program_config, activation, 
compute_kernel_config, core_grid, devices]( + [dim, all_gather_core_grid_offset, num_links, memory_config_ag, memory_config_mm, transpose_a, transpose_b, dtype, program_config, activation, compute_kernel_config, core_grid, devices]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -107,7 +151,7 @@ std::vector all_gather_matmul( const auto& weight_tensor = input_tensors[1]; /* AllGather setup */ - ttnn::AllGather all_gather_struct = ttnn::create_all_gather_struct(input_tensor, dim, num_links, memory_config, devices); + ttnn::AllGather all_gather_struct = ttnn::create_all_gather_struct(input_tensor, dim, num_links, memory_config_ag, devices); // Create the all gather output tensor used as input (activation) to the matmul ttnn::Tensor all_gather_out_tensor = all_gather_struct.create_output_tensors({input_tensor})[0]; @@ -128,7 +172,7 @@ std::vector all_gather_matmul( /*parameters=*/operations::matmul::Matmul{ program_config, /*bcast_batch=*/std::nullopt, - memory_config.value_or(input_tensor.memory_config()), + memory_config_mm.value_or(input_tensor.memory_config()), dtype.value_or(input_tensor.get_dtype()), compute_kernel_config, /*untilize_out=*/false, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp index 2dae7fe7ad3..ed469dc924b 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp @@ -23,7 +23,7 @@ /* Fusion includes */ #include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" #include "ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp" -#include "ttnn/operations/experimental/ccl/ccl_op_fusion.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" namespace ttnn { @@ -56,6 +56,8 @@ operation::ProgramWithCallbacks all_gather_matmul_multi_core_with_workers( /* General Params */ const Tensor& input_tensor, Tensor& all_gather_output_tensor, + Tensor& datacopy_output_tensor, + const Tensor& weight_tensor, Tensor& matmul_output_tensor, const uint32_t dim, const uint32_t num_links, @@ -64,24 +66,14 @@ operation::ProgramWithCallbacks all_gather_matmul_multi_core_with_workers( const std::optional receiver_device_id, const std::optional sender_device_id, all_gather_op::Topology topology, - const CoreCoord core_grid_offset = CoreCoord(0, 0) + const CoreCoord core_grid_offset, /* Matmul Params */ - // const std::optional bias, - // Tensor &mm_output_tensor, - // bool bcast_batch, - // CoreCoord compute_with_storage_grid_size, - // DeviceComputeKernelConfig compute_kernel_config, - // uint32_t in0_block_w, - // uint32_t out_subblock_h, - // uint32_t out_subblock_w, - // uint32_t per_core_M, - // uint32_t per_core_N, - // bool fuse_batch, - // bool transpose_mcast, - // std::optional fused_activation, - // bool untilize_out - + const std::optional bias, + bool bcast_batch, + DeviceComputeKernelConfig compute_kernel_config, + const operations::matmul::MatmulProgramConfig program_config, + bool untilize_out ); } // namespace experimental @@ -96,7 +88,8 @@ std::vector all_gather_matmul( const uint32_t dim, const CoreCoord all_gather_core_grid_offset, const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt, + const std::optional& memory_config_ag = std::nullopt, + 
const std::optional& memory_config_mm = std::nullopt, const bool transpose_a = false, const bool transpose_b = false, const std::optional dtype = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/kernels/datacopy.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/kernels/datacopy.cpp index db29d122a6a..5d5447c645a 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/kernels/datacopy.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/kernels/datacopy.cpp @@ -38,11 +38,16 @@ void kernel_main() { const uint32_t signal_op_sem_addr_dir1 = get_semaphore(get_compile_time_arg_val(14)); constexpr uint32_t max_buffer_size = get_compile_time_arg_val(15); + // Compile time args for matmul signal semaphore + constexpr uint32_t num_matmul_cores_to_signal = get_compile_time_arg_val(16); - // Runtime args - const uint32_t dram_buffer_src_addr = get_arg_val(0); - const uint32_t dram_buffer_dst_addr = get_arg_val(1); + // Runtime args + uint32_t rt_args_idx = 0; + const uint32_t dram_buffer_src_addr = get_arg_val(rt_args_idx++); + const uint32_t dram_buffer_dst_addr = get_arg_val(rt_args_idx++); + const uint32_t* matmul_signal_sems = (uint32_t*)get_arg_addr(increment_arg_idx(rt_args_idx, 2)); // Matmul signal semaphore address + const uint32_t* matmul_cores_noc_coords = (uint32_t*)get_arg_addr(increment_arg_idx(rt_args_idx, 2 * num_matmul_cores_to_signal)); // Matmul core NOC coordinates [x1, y1, x2, y2...] // Setup buffers constexpr uint32_t cb_id_in0 = tt::CB::c_in0; @@ -75,9 +80,7 @@ void kernel_main() { ttnn::ccl::coord_t tensor_slice_shape = {tensor_slice_shape_width, tensor_slice_shape_height}; uint32_t ring_index_dir0 = start_ring_index; - // Adjust to include copying over the local tensor slice, which is at start_ring_index. If clockwise, then dir1 will be anticlockwise, which means that the ring index will update in ascending order. - // Therefore, to undo that, we subtract 1. If anticlockwise, then dir1 will be clockwise, which means that the ring index will update in descending order. Therefore, to undo that, we add 1. - uint32_t ring_index_dir1 = (is_clockwise_direction ? 
start_ring_index - 1 : start_ring_index + 1) % ring_size; + uint32_t ring_index_dir1 = start_ring_index; uint32_t start_page_idx_dir0 = ring_index_dir0 * output_page_offset; uint32_t start_page_idx_dir1 = ring_index_dir1 * output_page_offset; @@ -89,13 +92,20 @@ void kernel_main() { // Main for loop where each iteration handles a tensor slice // The loop alternates between the two directions, hence it runs for double the number of transfers - for (uint32_t i = 0, dir = 0; i < num_transfers * 2; i++, dir = !dir) { + for (uint32_t i = 0, dir = 1; i < num_transfers * 2; i++, dir = !dir) { uint32_t tensor_slice_cnt = i / 2; // Since we are alternating between the two directions, we need to divide by 2 to get the correct tensor slice count in each direction // Update location in input and output tensor in DRAM - advance_start_page_idx(start_page_idxs[dir], ring_idxs[dir], ring_size, is_clockwise_dirs[dir], output_page_offset, last_output_page_offset); - - // DPRINT << "DIRECTION 0 RING INDEX>>>> " << ring_index_dir0 << ENDL(); + if (i > 0) { // Skip update for local tensor slice + advance_start_page_idx( + start_page_idxs[dir], + ring_idxs[dir], + ring_size, + is_clockwise_dirs[dir], + output_page_offset, + last_output_page_offset + ); + } uint32_t curr_page_in_idx = start_page_idxs[dir]; uint32_t curr_page_out_idx = start_page_idxs[dir]; @@ -109,11 +119,17 @@ void kernel_main() { ttnn::ccl::coord_t offset_worker_slice = {0, 0}; // DPRINT << "WAITING FOR OP SIGNAL IN DATACOPY" << ENDL(); - if ((!dir && tensor_slice_cnt < num_transfers) || (dir && tensor_slice_cnt < num_transfers - 1)) { // Using dir as a selector to select which logic to choose, because dir = 1 will have 1 less semaphore (because one is local already) - noc_semaphore_wait_min(signal_op_semaphore_ptrs[dir], tensor_slice_cnt + 1); - } + noc_semaphore_wait_min(signal_op_semaphore_ptrs[dir], tensor_slice_cnt + 1); // DPRINT << "RECEIVED OP SIGNAL IN DATACOPY" << ENDL(); + // Signal matmul to begin + for (uint32_t i = 0; i < num_matmul_cores_to_signal; i++) { + auto& matmul_core_noc_x = matmul_cores_noc_coords[i * 2]; + auto& matmul_core_noc_y = matmul_cores_noc_coords[i * 2 + 1]; + auto remote_matmul_signal_sem_addr = get_noc_addr(matmul_core_noc_x, matmul_core_noc_y, get_semaphore(matmul_signal_sems[dir])); + noc_semaphore_inc(remote_matmul_signal_sem_addr, 1); + } + // To account for the granularity based on restrictions on the buffer size of L1 for (uint32_t pages = 0; pages < num_pages;) { uint32_t num_pages_to_transfer = std::min(num_pages - pages, max_buffer_size); @@ -161,5 +177,4 @@ void kernel_main() { } - } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp index af5e5d7d19d..3f74f14ae07 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp @@ -21,8 +21,8 @@ #include #include "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp" -#include "ttnn/operations/experimental/ccl/ccl_op_fusion.hpp" - +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" using namespace tt::constants; @@ -33,7 +33,7 @@ using Tensors = std::vector; // Used to hold the return 
values for setup_datacopy struct DatacopyParams { - std::vector datacopy_cores; + std::vector datacopy_cores_noc; std::vector datacopy_signal_semaphore_ids; std::optional> datacopy_override_runtime_arguments_callback; }; @@ -49,12 +49,14 @@ DatacopyParams setup_datacopy( const uint32_t ring_index, all_gather_op::Topology topology, - CoreCoord datacopy_core_coord + CoreCoord datacopy_core_coord, + const ttnn::experimental::ccl::MatmulFusedOpSignaler& matmul_fused_op_signaler ) { std::size_t num_edm_buffers_per_channel = 2; + const auto& device = input_tensor.device(); auto const& all_gather_config = ttnn::AllGatherConfig(input_tensor, all_gather_output_tensor, dim, ring_size, num_links, topology, num_edm_buffers_per_channel, true); - const uint32_t num_transfers = 4; // ring_size - 1; + const uint32_t num_transfers = 4; auto tensor_slicer = ttnn::ccl::InterleavedRingAllGatherTensorSlicer ( input_tensor, @@ -67,6 +69,10 @@ DatacopyParams setup_datacopy( // Select cores for datacopy (single core for now) CoreRangeSet datacopy_workers = CoreRangeSet({CoreRange(datacopy_core_coord)}); std::vector all_datacopy_cores = corerange_to_cores(datacopy_workers, std::nullopt, true); + std::vector all_datacopy_cores_noc; + for (auto core : all_datacopy_cores) { + all_datacopy_cores_noc.push_back(device->worker_core_from_logical_core(core)); + } // Setup semaphores used to signal datacopy. TODO: instead of datacopy, this should be matmul cores // Dir0: first half of all gather (clockwise), Dir1: second half of all gather (counter-clockwise) @@ -110,6 +116,7 @@ DatacopyParams setup_datacopy( static_cast(datacopy_signal_semaphore_id_dir0), static_cast(datacopy_signal_semaphore_id_dir1), static_cast(datacopy_buffer_size), + static_cast(matmul_fused_op_signaler.num_fused_op_cores_to_signal) }; uint32_t cb_id_in0 = tt::CB::c_in0; @@ -123,8 +130,16 @@ DatacopyParams setup_datacopy( std::vector datacopy_rt_args = { static_cast(all_gather_output_buffer->address()), static_cast(datacopy_output_buffer->address()), + static_cast(matmul_fused_op_signaler.fused_op_receiver_signal_semaphores[0]), + static_cast(matmul_fused_op_signaler.fused_op_receiver_signal_semaphores[1]), }; + // Push the matmul core NOC coordinates + for (auto coord : matmul_fused_op_signaler.fused_op_receiver_cores_noc) { + datacopy_rt_args.push_back(static_cast(coord.x)); + datacopy_rt_args.push_back(static_cast(coord.y)); + } + std::map kernel_defines = { {"TILED_LAYOUT", "1"}, {"INTERLEAVED_MEM_LAYOUT", "1"} @@ -170,7 +185,7 @@ DatacopyParams setup_datacopy( // Return the core coordinates and semaphore address return { - .datacopy_cores = all_datacopy_cores, + .datacopy_cores_noc = all_datacopy_cores_noc, .datacopy_signal_semaphore_ids = {datacopy_signal_semaphore_id_dir0, datacopy_signal_semaphore_id_dir1}, .datacopy_override_runtime_arguments_callback = override_runtime_arguments_callback }; @@ -180,34 +195,192 @@ DatacopyParams setup_datacopy( // For ring all-gather, we can send sub-sections of input tensor in opposite directions // For linear all-gather though, we must ensure we send full tensors in BOTH directions // (in other words, disable the "bidirectional" send flag) -operation::ProgramWithCallbacks experimental::all_gather_matmul_multi_core_with_workers(const Tensor& input_tensor, Tensor& all_gather_output_tensor, Tensor& datacopy_output_tensor, const uint32_t dim, const uint32_t num_links, const uint32_t ring_size, const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, 
all_gather_op::Topology topology, const CoreCoord core_grid_offset) { +operation::ProgramWithCallbacks experimental::all_gather_matmul_multi_core_with_workers( + const Tensor& input_tensor, + Tensor& all_gather_output_tensor, + Tensor& datacopy_output_tensor, + const Tensor& weight_tensor, + Tensor& matmul_output_tensor, + + /* All Gather Params */ + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional receiver_device_id, + const std::optional sender_device_id, + all_gather_op::Topology topology, + const CoreCoord core_grid_offset, + + /* Matmul Params */ + const std::optional bias, + bool bcast_batch, + DeviceComputeKernelConfig compute_kernel_config, + const operations::matmul::MatmulProgramConfig program_config, + bool untilize_out + +) { tt::tt_metal::Program program{}; + bool use_datacopy = false; /* Enable for debugging purposes */ - DatacopyParams datacopy_params = setup_datacopy(program, input_tensor, all_gather_output_tensor, datacopy_output_tensor, dim, num_links, ring_size, ring_index, topology, {0, 0}); - const auto& datacopy_override_runtime_arguments_callback = datacopy_params.datacopy_override_runtime_arguments_callback; + ////////////// Params for fused op signalers ////////////// - std::optional fused_op_signaler = AllGatherFusedOpSignaler(datacopy_params.datacopy_cores, datacopy_params.datacopy_signal_semaphore_ids); + auto tensor_slicer = ttnn::ccl::InterleavedRingAllGatherTensorSlicer ( + input_tensor, + all_gather_output_tensor, + dim, + ring_index + ); + bool is_clockwise_direction = true; + const uint32_t num_transfers = 4; + const uint32_t weight_tensor_width = weight_tensor.get_legacy_shape()[3] / 32; + + //////////////////////////////////////////////////////// + + // Create a matmul signal info object that gets populated by the matmul kernel + std::optional matmul_fused_op_signaler = ttnn::experimental::ccl::MatmulFusedOpSignaler(); + matmul_fused_op_signaler->init_all_gather( + num_transfers, + ring_size, + ring_index, + tensor_slicer.num_cols, + tensor_slicer.output_page_offset, + is_clockwise_direction, + tensor_slicer.num_cols * weight_tensor_width /* weight_output_page_offset: stride across a tensor slice in the weight_tensor */ + ); - // Pass in the datacopy cores and sempahore address (Using optional arguments) - operation::ProgramWithCallbacks program_with_callbacks = ttnn::all_gather_multi_core_with_workers_helper(program, input_tensor, all_gather_output_tensor, dim, num_links, ring_size, ring_index, receiver_device_id, sender_device_id, topology, fused_op_signaler, core_grid_offset); - auto all_gather_override_runtime_arguments_callback = program_with_callbacks.override_runtime_arguments_callback; + // Matmul + std::optional matmul_program_with_callbacks; + std::optional> matmul_override_runtime_arguments_callback; + + std::visit([&] (const auto& config) { + using ProgramConfigType = std::decay_t; + if (std::is_same_v) { + matmul_program_with_callbacks = operations::matmul::matmul_multi_core_reuse_mcast_2d_optimized_helper( + program, + all_gather_output_tensor, + weight_tensor, + bias, + matmul_output_tensor, + bcast_batch, + compute_kernel_config, + config, + untilize_out, + matmul_fused_op_signaler + ); + matmul_override_runtime_arguments_callback = matmul_program_with_callbacks->override_runtime_arguments_callback; + } else if (std::is_same_v) { + matmul_program_with_callbacks = operations::matmul::matmul_multi_core_reuse_mcast_1d_optimized_helper( + program, + 
all_gather_output_tensor, + weight_tensor, + bias, + matmul_output_tensor, + bcast_batch, + compute_kernel_config, + config, + untilize_out, + matmul_fused_op_signaler + ); + matmul_override_runtime_arguments_callback = matmul_program_with_callbacks->override_runtime_arguments_callback; + } else { + TT_FATAL("Unsupported MatmulProgramConfig type"); + } + }, program_config); + + if (!matmul_program_with_callbacks.has_value()) { + TT_FATAL("Matmul program with callbacks not created"); + } + + + // Datacopy + const CoreCoord datacopy_core_coord = {0, 7}; // Pick a location that doesn't overlap with all_gather/matmul + DatacopyParams datacopy_params; + if (use_datacopy) { + datacopy_params = setup_datacopy( + matmul_program_with_callbacks->program, + input_tensor, + all_gather_output_tensor, + datacopy_output_tensor, + dim, + num_links, + ring_size, + ring_index, + topology, + datacopy_core_coord, + matmul_fused_op_signaler.value() + ); + } + + // Create the all gather fused op signaler + std::optional all_gather_fused_op_signaler = AllGatherFusedOpSignaler(); + if (use_datacopy) { + all_gather_fused_op_signaler->init_fused_op( + datacopy_params.datacopy_cores_noc, + datacopy_params.datacopy_signal_semaphore_ids + ); + } else { + all_gather_fused_op_signaler->init_fused_op( + matmul_fused_op_signaler->fused_op_receiver_cores_noc, + matmul_fused_op_signaler->fused_op_receiver_signal_semaphores + ); + } + + // All Gather + operation::ProgramWithCallbacks program_with_callbacks = ttnn::all_gather_multi_core_with_workers_helper( + matmul_program_with_callbacks->program, + input_tensor, + all_gather_output_tensor, + dim, + num_links, + ring_size, + ring_index, + receiver_device_id, + sender_device_id, + topology, + all_gather_fused_op_signaler, + core_grid_offset); + const auto all_gather_override_runtime_arguments_callback = program_with_callbacks.override_runtime_arguments_callback; - // Fuse the datacopy and all-gather overriden runtime arguments callbacks - auto override_runtime_arguments_callback = [all_gather_override_runtime_arguments_callback, datacopy_override_runtime_arguments_callback] ( + + + // Fuse the override runtime arguments callbacks + auto override_runtime_arguments_callback = [use_datacopy, all_gather_override_runtime_arguments_callback, matmul_override_runtime_arguments_callback, datacopy_params] ( const void* operation, Program& program, const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector& output_tensors ) { + if (matmul_override_runtime_arguments_callback.has_value()) { + matmul_override_runtime_arguments_callback.value()( + operation, + program, + {input_tensors[1], input_tensors[2]}, /* all gather output tensor, weight tensor */ + optional_input_tensors, + {output_tensors[1]} /* matmul output tensor */ + ); + } if (all_gather_override_runtime_arguments_callback.has_value()) { - all_gather_override_runtime_arguments_callback.value()(operation, program, input_tensors, optional_input_tensors, output_tensors); + all_gather_override_runtime_arguments_callback.value()( + operation, + program, + {input_tensors[0], output_tensors[0]}, /* input tensor, all gather output tensor */ + optional_input_tensors, + {output_tensors[0]} /* all gather output tensor */ + ); } - if (datacopy_override_runtime_arguments_callback.has_value()) { - datacopy_override_runtime_arguments_callback.value()(operation, program, input_tensors, optional_input_tensors, output_tensors); + if (use_datacopy && 
datacopy_params.datacopy_override_runtime_arguments_callback.has_value()) { + datacopy_params.datacopy_override_runtime_arguments_callback.value()( + operation, + program, + {input_tensors[0], output_tensors[0]}, /* input tensor, all gather output tensor */ + optional_input_tensors, + {output_tensors[2]} /* datacopy output tensor */ + ); } }; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_op_fusion.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_op_fusion.hpp deleted file mode 100644 index b94e270a9bd..00000000000 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_op_fusion.hpp +++ /dev/null @@ -1,117 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "tt_metal/host_api.hpp" -#include "tt_metal/impl/program/program.hpp" - -namespace ttnn { -namespace experimental { -namespace ccl { - -struct AllGatherFusedOpSignaler { - - std::vector fused_op_receiver_cores; - std::vector fused_op_receiver_cores_noc; - std::vector fused_op_receiver_signal_semaphores; - - /* All Gather specific */ - std::vector all_gather_worker_cores_noc; - uint32_t all_gather_worker_sync_semaphore; - - bool initialized_fused_op = false; - bool initialized_all_gather = false; - - - AllGatherFusedOpSignaler( - const std::vector& fused_op_receiver_cores, - const std::vector& fused_op_receiver_signal_semaphores) - : fused_op_receiver_cores(fused_op_receiver_cores), - fused_op_receiver_signal_semaphores(fused_op_receiver_signal_semaphores) { - - } - - void init_fused_op( - Device const* device - ) { - // Get the noc coords for the fused op receiver cores - this->fused_op_receiver_cores_noc.clear(); - for (const auto& core : this->fused_op_receiver_cores) { - this->fused_op_receiver_cores_noc.push_back(device->worker_core_from_logical_core(core)); - } - initialized_fused_op = true; - } - - void init_all_gather( - Program& program, - Device const* device, - - CoreRangeSet const& all_gather_workers, - std::vector& all_gather_worker_cores - ) { - // Create the sync semaphore for the all gather workers - this->all_gather_worker_sync_semaphore = CreateSemaphore(program, all_gather_workers, 0); - - // Get the noc coords for the all gather workers - this->all_gather_worker_cores_noc.clear(); - for (const auto& core : all_gather_worker_cores) { - this->all_gather_worker_cores_noc.push_back(device->worker_core_from_logical_core(core)); - } - initialized_all_gather = true; - } - - void emit_all_gather_fused_op_ct_args( - std::vector& ct_args, - - uint32_t num_workers_to_sync, - uint32_t curr_worker_index - ) { - TT_ASSERT(initialized_fused_op && initialized_all_gather, "AllGatherFusedOpSignaler not initialized fully."); - - ct_args.push_back(static_cast(num_workers_to_sync)); - ct_args.push_back(static_cast(curr_worker_index)); - ct_args.push_back(static_cast(this->all_gather_worker_sync_semaphore)); - - } - - - void emit_all_gather_fused_op_rt_args( - std::vector& rt_args, - - bool all_gather_direction - ) { - TT_ASSERT(initialized_fused_op && initialized_all_gather, "AllGatherFusedOpSignaler not initialized fully."); - - // Push the worker core noc coords - for (const auto& core : this->all_gather_worker_cores_noc) { - rt_args.push_back(static_cast(core.x)); - rt_args.push_back(static_cast(core.y)); - } - - // Push the fused op receiver core noc coords - for (const auto& core : this->fused_op_receiver_cores_noc) { - rt_args.push_back(static_cast(core.x)); - rt_args.push_back(static_cast(core.y)); - } - - // Push the fused op signal 
semaphore addrs - // Direction 0: clockwise - // Direction 1: counter-clockwise - rt_args.push_back( - static_cast(this->fused_op_receiver_signal_semaphores[all_gather_direction]) - ); - - } - - static uint32_t get_num_ct_args() { - return 3; - } - -}; - - -} // namespace ccl -} // namespace experimental -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp index 95e671239c0..2e76f54632e 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp @@ -6,19 +6,24 @@ #include "dataflow_api.h" #include "hostdevcommon/common_values.hpp" +#include "debug/dprint.h" + +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp" void kernel_main() { + + uint32_t rt_args_idx = 0; // in0 tensor args - const uint32_t in0_tensor_addr = get_arg_val(0); - uint32_t in0_tensor_start_tile_id = get_arg_val(1); + const uint32_t in0_tensor_addr = get_arg_val(rt_args_idx++); + uint32_t in0_tensor_start_tile_id = get_arg_val(rt_args_idx++); // in0 mcast args - const uint32_t in0_mcast_dest_noc_start_x = get_arg_val(2); - const uint32_t in0_mcast_dest_noc_start_y = get_arg_val(3); - const uint32_t in0_mcast_dest_noc_end_x = get_arg_val(4); - const uint32_t in0_mcast_dest_noc_end_y = get_arg_val(5); + const uint32_t in0_mcast_dest_noc_start_x = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_start_y = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_end_x = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_end_y = get_arg_val(rt_args_idx++); // padding args - const uint32_t last_block_h = get_arg_val(6); + const uint32_t last_block_h = get_arg_val(rt_args_idx++); // COMPILE TIME ARGS // interleaved accessor args @@ -46,6 +51,18 @@ void kernel_main() { constexpr uint32_t MtKt = get_compile_time_arg_val(15); // if 0 constexpr uint32_t batch = get_compile_time_arg_val(16); + constexpr bool fuse_op = (bool)get_compile_time_arg_val(17); + + MatmulOpReceiver fused_op_receiver; + if constexpr (fuse_op) { + fused_op_receiver = MatmulOpReceiver( + true, /* wait_for_op_signal */ + rt_args_idx, + num_blocks, + in0_block_w /* tiles_per_block (in the same dimension as tensor slice) */ + ); + } + constexpr uint32_t cb_id_in0 = 0; constexpr uint32_t in0_single_tile_size_bytes = get_tile_size(cb_id_in0); constexpr uint32_t in0_block_size_bytes = in0_block_num_tiles * in0_single_tile_size_bytes; @@ -98,6 +115,13 @@ void kernel_main() { for (uint32_t b = 0; b < batch; ++b) { uint32_t in0_tensor_current_block_start_tile_id = in0_tensor_start_tile_id; for (uint32_t block = 0; block < num_blocks; ++block) { + if constexpr (fuse_op) { + fused_op_receiver.update_current_block_start_tile_id( + block, + in0_tensor_current_block_start_tile_id, + in0_tensor_start_tile_id + ); + } #ifndef IN0_SHARDED // Operand 0 cb_reserve_back(cb_id_in0, in0_block_num_tiles); diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp index c3fdda3f5bd..5f81e38fdc2 100644 --- 
a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp @@ -6,6 +6,7 @@ #include "dataflow_api.h" #include "hostdevcommon/common_values.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp" void kernel_main() { constexpr bool core_has_output_block_work = (bool)get_compile_time_arg_val(0); @@ -28,14 +29,27 @@ void kernel_main() { constexpr uint32_t in0_block_w = get_compile_time_arg_val(14); constexpr uint32_t batch = get_compile_time_arg_val(15); - - const uint32_t sender_id = get_arg_val(0); - const uint32_t in0_mcast_dest_noc_start_x = get_arg_val(1); - const uint32_t in0_mcast_dest_noc_start_y = get_arg_val(2); - const uint32_t in0_mcast_dest_noc_end_x = get_arg_val(3); - const uint32_t in0_mcast_dest_noc_end_y = get_arg_val(4); - tt_l1_ptr uint32_t* in0_mcast_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(5)); - tt_l1_ptr uint32_t* in0_mcast_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(5 + num_x)); + constexpr bool fuse_op = (bool)get_compile_time_arg_val(16); + + uint32_t rt_args_idx = 0; + const uint32_t sender_id = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_start_x = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_start_y = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_end_x = get_arg_val(rt_args_idx++); + const uint32_t in0_mcast_dest_noc_end_y = get_arg_val(rt_args_idx++); + tt_l1_ptr uint32_t* in0_mcast_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(increment_arg_idx(rt_args_idx, num_x))); + tt_l1_ptr uint32_t* in0_mcast_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(increment_arg_idx(rt_args_idx, num_y))); + + + MatmulOpReceiver fused_op_receiver; + if constexpr (fuse_op) { + fused_op_receiver = MatmulOpReceiver( + true, /* wait_for_op_signal */ + rt_args_idx, + num_blocks, + in0_block_w /* tiles_per_block (in the same dimension as tensor slice) */ + ); + } constexpr uint32_t cb_id_in0 = 0; constexpr uint32_t cb_id_in2 = 2; // Sharded cb @@ -112,7 +126,11 @@ void kernel_main() { for (uint32_t b = 0; b < batch; ++b) { for (uint32_t block = 0; block < num_blocks; ++block) { - const uint32_t block_id = block / num_blocks_per_shard; + uint32_t block_id = block / num_blocks_per_shard; + if constexpr (fuse_op) { // If used fused op, make block_id conform to ordering of tensor slices from all gather + block_id = fused_op_receiver.align_to_slice_and_sync(block, sender_id); + } + cb_reserve_back(cb_id_in0, in0_block_num_tiles); // All cores in receiver grid need to participate in receiving regardless if they produce output work or diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp index c08e421d6bb..41d09cc92b4 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp @@ -6,33 +6,35 @@ #include "dataflow_api.h" #include "hostdevcommon/common_values.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp" void kernel_main() { // READER + uint32_t rt_args_idx = 0; // in1 tensor args - const uint32_t in1_tensor_addr = get_arg_val(0); - uint32_t 
in1_tensor_start_tile_id = get_arg_val(1); + const uint32_t in1_tensor_addr = get_arg_val(rt_args_idx++); + uint32_t in1_tensor_start_tile_id = get_arg_val(rt_args_idx++); // in1 mcast args - const uint32_t in1_mcast_dest_noc_start_x = get_arg_val(2); - const uint32_t in1_mcast_dest_noc_start_y = get_arg_val(3); - const uint32_t in1_mcast_dest_noc_end_x = get_arg_val(4); - const uint32_t in1_mcast_dest_noc_end_y = get_arg_val(5); + const uint32_t in1_mcast_dest_noc_start_x = get_arg_val(rt_args_idx++); + const uint32_t in1_mcast_dest_noc_start_y = get_arg_val(rt_args_idx++); + const uint32_t in1_mcast_dest_noc_end_x = get_arg_val(rt_args_idx++); + const uint32_t in1_mcast_dest_noc_end_y = get_arg_val(rt_args_idx++); // WRITER // out tensor args - const uint32_t out_tensor_addr = get_arg_val(6); - uint32_t out_tensor_start_tile_id = get_arg_val(7); + const uint32_t out_tensor_addr = get_arg_val(rt_args_idx++); + uint32_t out_tensor_start_tile_id = get_arg_val(rt_args_idx++); // padding args (READER) - const uint32_t last_block_w = get_arg_val(8); + const uint32_t last_block_w = get_arg_val(rt_args_idx++); // padding args (WRITER) - const uint32_t out_num_nonzero_subblocks_h = get_arg_val(9); - const uint32_t out_last_subblock_h = get_arg_val(10); - const uint32_t padded_block_tiles_h_skip = get_arg_val(11); - const uint32_t out_num_nonzero_subblocks_w = get_arg_val(12); - const uint32_t out_last_subblock_w = get_arg_val(13); - const uint32_t padded_subblock_tiles_addr_skip = get_arg_val(14); - const uint32_t padded_block_tiles_w_skip = get_arg_val(15); + const uint32_t out_num_nonzero_subblocks_h = get_arg_val(rt_args_idx++); + const uint32_t out_last_subblock_h = get_arg_val(rt_args_idx++); + const uint32_t padded_block_tiles_h_skip = get_arg_val(rt_args_idx++); + const uint32_t out_num_nonzero_subblocks_w = get_arg_val(rt_args_idx++); + const uint32_t out_last_subblock_w = get_arg_val(rt_args_idx++); + const uint32_t padded_subblock_tiles_addr_skip = get_arg_val(rt_args_idx++); + const uint32_t padded_block_tiles_w_skip = get_arg_val(rt_args_idx++); // COMPILE TIME ARGS // interleaved accessor args @@ -76,8 +78,8 @@ void kernel_main() { #ifdef FUSE_BIAS // in3 mcast args - const uint32_t in3_tensor_addr = get_arg_val(16); - const uint32_t in3_tensor_start_tile_id = get_arg_val(17); + const uint32_t in3_tensor_addr = get_arg_val(rt_args_idx++); + const uint32_t in3_tensor_start_tile_id = get_arg_val(rt_args_idx++); constexpr bool in3_is_dram = get_compile_time_arg_val(24) == 1; constexpr uint32_t in3_tensor_stride_w = get_compile_time_arg_val(25); @@ -92,24 +94,40 @@ void kernel_main() { .bank_base_address = in3_tensor_addr, .page_size = bias_single_tile_size_bytes, .data_format = bias_data_format}; +#else + rt_args_idx += 2; // Skip over placeholders #endif + constexpr bool fuse_op = (bool)get_compile_time_arg_val(26); + + MatmulOpReceiver fused_op_receiver; + if constexpr(fuse_op) { + fused_op_receiver = MatmulOpReceiver( + false, /* wait_for_op_signal */ + rt_args_idx, + num_blocks, + in1_block_h /* tiles_per_block (in the same dimension */ + ); + } + + // RT and COMPILE TIME ARGS for DRAM sharded weights #ifdef IN1_DRAM_SHARDED - const uint32_t vc = get_arg_val(18); - const uint32_t num_dram_shards_to_read = get_arg_val(19); - const uint32_t dram_tensor_start_offset = get_arg_val(20); - tt_l1_ptr uint32_t* in1_block_w_dram_stride_bytes = (tt_l1_ptr uint32_t*)get_arg_addr(21); - tt_l1_ptr uint32_t* current_dram_bank_id = (tt_l1_ptr uint32_t*)get_arg_addr(22); - - constexpr uint32_t 
in1_dram_block_num_tiles = get_compile_time_arg_val(26); - constexpr uint32_t in1_block_w_dram_bytes = get_compile_time_arg_val(27); + const uint32_t vc = get_arg_val(rt_args_idx++); + const uint32_t num_dram_shards_to_read = get_arg_val(rt_args_idx++); + const uint32_t dram_tensor_start_offset = get_arg_val(rt_args_idx++); + tt_l1_ptr uint32_t* in1_block_w_dram_stride_bytes = (tt_l1_ptr uint32_t*)get_arg_addr(rt_args_idx++); + tt_l1_ptr uint32_t* current_dram_bank_id = (tt_l1_ptr uint32_t*)get_arg_addr(rt_args_idx++); + + constexpr uint32_t in1_dram_block_num_tiles = get_compile_time_arg_val(27); + constexpr uint32_t in1_block_w_dram_bytes = get_compile_time_arg_val(28); #endif constexpr uint32_t cb_id_in1 = 1; constexpr uint32_t in1_single_tile_size_bytes = get_tile_size(cb_id_in1); constexpr uint32_t in1_block_size_bytes = in1_block_num_tiles * in1_single_tile_size_bytes; + // READER #ifdef IN1_SHARDED cb_reserve_back(cb_id_in1, in1_block_num_tiles); @@ -166,6 +184,13 @@ void kernel_main() { #endif uint32_t in1_tensor_current_block_start_tile_id = in1_tensor_start_tile_id; for (uint32_t block = 0; block < num_blocks; ++block) { + if constexpr(fuse_op) { + fused_op_receiver.update_current_block_start_tile_id( + block, + in1_tensor_current_block_start_tile_id, + in1_tensor_start_tile_id + ); + } #ifdef IN1_DRAM_SHARDED // Operand 1 cb_reserve_back(cb_id_in1, in1_block_num_tiles); diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp index 29758afe1c4..0247bf0d3c8 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp @@ -12,6 +12,8 @@ #include "ttnn/tensor/tensor_utils.hpp" #include "ttnn/types.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" + namespace ttnn { namespace operations { @@ -186,6 +188,29 @@ Matmul create_matmul_struct( const struct Matmul ¶meters ); +operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_helper( + tt::tt_metal::Program& program, + const Tensor &input_tensor_a, + const Tensor &input_tensor_b, + const std::optional bias, + Tensor &output_tensor, + bool bcast_batch, + DeviceComputeKernelConfig compute_kernel_config, + const MatmulProgramConfig program_config, + bool untilize_out, + std::optional &fused_op_signaler); +operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_helper( + tt::tt_metal::Program& program, + const Tensor &input_tensor_a, + const Tensor &input_tensor_b, + const std::optional bias, + Tensor &output_tensor, + bool bcast_batch, + DeviceComputeKernelConfig compute_kernel_config, + const MatmulProgramConfig program_config, + bool untilize_out, + std::optional &matmul_signal_info); + Tensor matmul( const Tensor &input_tensor_a, const Tensor &input_tensor_b, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp index 61fa10766a3..0c6296d3135 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp @@ -23,6 +23,7 @@ using ttnn::operations::unary::UnaryWithParam; namespace reuse_mcast_1d_optimized_helpers { operation::ProgramWithCallbacks create_program_mcast_in0( + tt_metal::Program& program, const Tensor& a, tt_metal::Device* device, MathFidelity math_fidelity, @@ -51,8 
+52,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0( tt::DataFormat output_data_format, bool in0_is_sharded, bool output_is_sharded, - bool untilize_out) { - tt_metal::Program program{}; + bool untilize_out, + std::optional &fused_op_signaler) { + bool fuse_op = fused_op_signaler.has_value(); uint32_t num_blocks = K / in0_block_w; // Only enable packer l1 accumulation when there are spills, otherwise @@ -260,6 +262,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)B // batch }; } + in0_sender_compile_time_args.push_back((std::uint32_t)fuse_op); + std::vector in1_sender_writer_compile_time_args = { // interleaved accessor args (std::uint32_t)in1_is_dram, @@ -302,7 +306,13 @@ operation::ProgramWithCallbacks create_program_mcast_in0( if (bias_buffer != nullptr) { in1_sender_writer_compile_time_args.push_back((std::uint32_t)in3_is_dram); in1_sender_writer_compile_time_args.push_back((std::uint32_t)1); + } else { + in1_sender_writer_compile_time_args.push_back(0); // Placeholder; not used + in1_sender_writer_compile_time_args.push_back(0); // Placeholder; not used } + + in1_sender_writer_compile_time_args.push_back((std::uint32_t)fuse_op); + std::vector in0_receiver_compile_time_args = { // in0 block args (std::uint32_t)in0_block_w * per_core_M, // in0_block_num_tiles @@ -356,6 +366,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0( tt_metal::NOC in0_noc = detail::GetPreferredNOCForDRAMWrite(device->arch()); tt_metal::NOC in1_noc = detail::GetPreferredNOCForDRAMRead(device->arch()); + if (fuse_op) { + // Create semaphores + fused_op_signaler->init_fused_op(program, device, in0_mcast_cores_with_work_and_in_receiver_grid); + } + auto mm_kernel_in0_mcast_cores_with_work_and_in_receiver_grid_id = tt_metal::CreateKernel( program, in0_is_sharded @@ -625,6 +640,10 @@ operation::ProgramWithCallbacks create_program_mcast_in0( mm_in0_sender_args.insert(mm_in0_sender_args.end(), in0_mcast_noc_y.begin(), in0_mcast_noc_y.end()); if (i < num_cores_with_work) { + if (fuse_op) { + fused_op_signaler->push_matmul_fused_op_rt_args(mm_in0_sender_args, false); + } + tt_metal::SetRuntimeArgs( program, mm_kernel_in0_mcast_cores_with_work_and_in_receiver_grid_id, @@ -659,6 +678,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0( // padding args (std::uint32_t)per_core_M // last_block_h }; + + if (fuse_op) { + fused_op_signaler->push_matmul_fused_op_rt_args(mm_in0_sender_args, false); + } + tt_metal::SetRuntimeArgs( program, mm_kernel_in0_mcast_cores_with_work_and_in_receiver_grid_id, @@ -723,7 +747,15 @@ operation::ProgramWithCallbacks create_program_mcast_in0( mm_in1_sender_writer_args.push_back((std::uint32_t)bias_buffer->address()); mm_in1_sender_writer_args.push_back( (std::uint32_t)per_core_N * output_idx_x); // in3_tensor_start_tile_id + } else { // Placeholder args + mm_in1_sender_writer_args.push_back(0); + mm_in1_sender_writer_args.push_back(0); } + + if (fuse_op) { + fused_op_signaler->push_matmul_fused_op_rt_args(mm_in1_sender_writer_args, true); + } + tt_metal::SetRuntimeArgs( program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); // RISCV_0_default } @@ -823,6 +855,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( bool untilize_out) { tt_metal::Program program{}; + bool fuse_op = false; + uint32_t num_blocks = K / in0_block_w; // Only enable packer l1 accumulation when there are num_blocks > 2, otherwise // unnecessary overhead for reconfigs are added. 
Last iteration of l1 accumulation @@ -949,6 +983,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( (std::uint32_t)M * K, // MtKt (std::uint32_t)B // batch }; + in0_sender_compile_time_args.push_back((std::uint32_t)fuse_op); + std::vector in1_sender_writer_compile_time_args = { // interleaved accessor args (std::uint32_t)in1_is_dram, @@ -991,7 +1027,13 @@ operation::ProgramWithCallbacks create_program_mcast_in1( if (bias_buffer != nullptr) { in1_sender_writer_compile_time_args.push_back((std::uint32_t)in3_is_dram); in1_sender_writer_compile_time_args.push_back((std::uint32_t)1); + } else { + in1_sender_writer_compile_time_args.push_back(0); // Placeholder; not used + in1_sender_writer_compile_time_args.push_back(0); // Placeholder; not used } + + in1_sender_writer_compile_time_args.push_back((std::uint32_t)fuse_op); + std::vector in1_receiver_writer_compile_time_args = { // interleaved accessor args (std::uint32_t)out_is_dram, @@ -1474,6 +1516,7 @@ namespace operations { namespace matmul { operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( + tt::tt_metal::Program& program, const Tensor& a, const Tensor& b, const std::optional bias, @@ -1489,7 +1532,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( bool fuse_batch, std::optional fused_activation, bool mcast_in0, - bool untilize_out) { + bool untilize_out, + std::optional &fused_op_signaler) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); // CB dataformats @@ -1603,6 +1647,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( //////////////////////////////////////////////////////////////////////////// if (mcast_in0) { return reuse_mcast_1d_optimized_helpers::create_program_mcast_in0( + program, a, device, math_fidelity, @@ -1631,7 +1676,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( output_data_format, a.memory_config().is_sharded(), output.memory_config().is_sharded(), - untilize_out); + untilize_out, + fused_op_signaler); } else { return reuse_mcast_1d_optimized_helpers::create_program_mcast_in1( device, @@ -1682,7 +1728,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized( std::optional fused_activation, bool mcast_in0, bool untilize_out) { + + tt_metal::Program program{}; /* Create a program */ + std::optional empty_fused_op_signaler; + return matmul_multi_core_reuse_mcast_1d_optimized_( + program, a, b, bias, @@ -1698,7 +1749,43 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized( fuse_batch, fused_activation, mcast_in0, - untilize_out); + untilize_out, + empty_fused_op_signaler); +} + +operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_helper( + tt::tt_metal::Program& program, + const Tensor& a, + const Tensor& b, + const std::optional bias, + Tensor& output_tensor, + bool broadcast_batch, + DeviceComputeKernelConfig compute_kernel_config, + const MatmulProgramConfig program_config, + bool untilize_out, + std::optional &fused_op_signaler) { + + MatmulMultiCoreReuseMultiCast1DProgramConfig config = std::get(program_config); + + return matmul_multi_core_reuse_mcast_1d_optimized_( + program, + a, + b, + bias, + output_tensor, + broadcast_batch, + config.compute_with_storage_grid_size, + compute_kernel_config, + config.in0_block_w, + config.out_subblock_h, + config.out_subblock_w, + config.per_core_M, + config.per_core_N, + config.fuse_batch, + config.fused_activation, + config.mcast_in0, + untilize_out, 
+ fused_op_signaler); } } // namespace matmul diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp index 3b46bebda5d..6606685711f 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp @@ -24,6 +24,7 @@ using namespace tt; using namespace tt_metal; operation::ProgramWithCallbacks create_program_mcast_in0_in1( + tt_metal::Program& program, tt_metal::Device* device, MathFidelity math_fidelity, bool fp32_dest_acc_en, @@ -49,9 +50,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( tt::DataFormat in1_data_format, tt::DataFormat bias_data_format, tt::DataFormat output_data_format, - bool untilize_out) { + bool untilize_out, + std::optional &fused_op_signaler) { + bool fuse_op = fused_op_signaler.has_value(); + TensorMemoryLayout in0_memory_layout = in0_buffer->buffer_layout(); - tt_metal::Program program{}; uint32_t num_blocks = K / in0_block_w; @@ -343,6 +346,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( (std::uint32_t)B // batch }; } + in0_sender_compile_time_args.push_back((std::uint32_t)fuse_op); + std::vector in1_sender_writer_compile_time_args = { // interleaved accessor args (std::uint32_t)in1_is_dram, @@ -389,6 +394,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( in1_sender_writer_compile_time_args.push_back(0); // Placeholder; not used in1_sender_writer_compile_time_args.push_back(0); // Placeholder; not used } + + in1_sender_writer_compile_time_args.push_back((std::uint32_t)fuse_op); + if (in1_is_sharded and in1_is_dram) { in1_sender_writer_compile_time_args.push_back((std::uint32_t)per_core_N_storage * in0_block_w); in1_sender_writer_compile_time_args.push_back((std::uint32_t)per_core_N_storage * in1_single_tile_size); @@ -524,6 +532,12 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( .defines = mm_kernel_in0_sender_sharded_defines}); } } else { + + if (fuse_op) { + // Create semaphores + fused_op_signaler->init_fused_op(program, device, in0_sender_interleaved); + } + mm_kernel_in0_sender_id = tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp", @@ -900,6 +914,10 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( mm_in0_sender_args.push_back(per_core_M); } + if (fuse_op) { + fused_op_signaler->push_matmul_fused_op_rt_args(mm_in0_sender_args, false); + } + tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_sender_args); // RISCV_0_default // in0 receiver @@ -1036,6 +1054,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( } mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 19, num_iter); } + if (fuse_op) { + fused_op_signaler->push_matmul_fused_op_rt_args(mm_in1_sender_writer_args, true); + } tt_metal::SetRuntimeArgs( program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); // RISCV_1_default @@ -1198,6 +1219,7 @@ namespace operations { namespace matmul { operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( + tt::tt_metal::Program& program, const Tensor& a, const Tensor& b, const std::optional bias, @@ -1213,7 +1235,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( bool fuse_batch, bool 
transpose_mcast, std::optional fused_activation, - bool untilize_out) { + bool untilize_out, + std::optional &fused_op_signaler) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); // CB dataformats @@ -1333,6 +1356,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( // Application Setup //////////////////////////////////////////////////////////////////////////// return reuse_mcast_optimized_helpers::create_program_mcast_in0_in1( + program, device, math_fidelity, fp32_dest_acc_en, @@ -1358,7 +1382,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( in1_data_format, bias_data_format, output_data_format, - untilize_out); + untilize_out, + fused_op_signaler); } operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized( @@ -1378,7 +1403,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized( bool transpose_mcast, std::optional fused_activation, bool untilize_out) { + + tt_metal::Program program{}; /* Create a program */ + std::optional empty_fused_op_signaler; + return matmul_multi_core_reuse_mcast_2d_optimized_( + program, a, b, bias, @@ -1394,7 +1424,43 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized( fuse_batch, transpose_mcast, fused_activation, - untilize_out); + untilize_out, + empty_fused_op_signaler); +} + +operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_helper( + tt_metal::Program& program, /* Take program as input by reference */ + const Tensor& a, + const Tensor& b, + const std::optional bias, + Tensor& output_tensor, + bool broadcast_batch, + DeviceComputeKernelConfig compute_kernel_config, + const MatmulProgramConfig program_config, + bool untilize_out, + std::optional &fused_op_signaler) { + + MatmulMultiCoreReuseMultiCastProgramConfig config = std::get(program_config); + + return matmul_multi_core_reuse_mcast_2d_optimized_( + program, + a, + b, + bias, + output_tensor, + broadcast_batch, + config.compute_with_storage_grid_size, + compute_kernel_config, + config.in0_block_w, + config.out_subblock_h, + config.out_subblock_w, + config.per_core_M, + config.per_core_N, + config.fuse_batch, + config.transpose_mcast, + config.fused_activation, + untilize_out, + fused_op_signaler); } } // namespace matmul
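
The fused op above is held together by a producer/consumer semaphore handshake between the all-gather worker (or the debug datacopy kernel) and the matmul reader cores: once a tensor slice of the all-gather output has landed, the producer increments a semaphore on every matmul core, and the matmul readers wait on that count before consuming the corresponding K blocks. The following device-kernel sketch shows only that pattern; the compile-time/runtime argument layout, semaphore ids, and slice counts are hypothetical and do not match the real datacopy.cpp, but the semaphore/NOC calls (get_semaphore, noc_semaphore_wait_min, get_noc_addr, noc_semaphore_inc) and the rt_args_idx convention are the same ones used in the diff.

// Hypothetical producer-side kernel: once tensor slice `slice` of the all-gather
// output is available locally (signalled on slice_ready_sem), bump a semaphore on
// every matmul reader core so it may start consuming that slice.
#include "dataflow_api.h"

void kernel_main() {
    // Illustrative compile-time args (not the op's real layout).
    constexpr uint32_t num_matmul_cores_to_signal = get_compile_time_arg_val(0);
    constexpr uint32_t num_tensor_slices = get_compile_time_arg_val(1);

    // Illustrative runtime args, read with the same rt_args_idx pattern as the diff.
    uint32_t rt_args_idx = 0;
    const uint32_t slice_ready_sem_id = get_arg_val<uint32_t>(rt_args_idx++);    // incremented by the all-gather worker
    const uint32_t matmul_signal_sem_id = get_arg_val<uint32_t>(rt_args_idx++);  // semaphore id on each matmul reader core
    tt_l1_ptr uint32_t* matmul_cores_noc = (tt_l1_ptr uint32_t*)(get_arg_addr(rt_args_idx));  // [x0, y0, x1, y1, ...]
    rt_args_idx += 2 * num_matmul_cores_to_signal;

    volatile tt_l1_ptr uint32_t* slice_ready_sem_ptr =
        reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_semaphore(slice_ready_sem_id));

    for (uint32_t slice = 0; slice < num_tensor_slices; ++slice) {
        // Wait until slice `slice` has been written out by the all-gather worker.
        noc_semaphore_wait_min(slice_ready_sem_ptr, slice + 1);

        // Tell every matmul reader core that one more slice is available.
        for (uint32_t i = 0; i < num_matmul_cores_to_signal; ++i) {
            const uint32_t noc_x = matmul_cores_noc[2 * i];
            const uint32_t noc_y = matmul_cores_noc[2 * i + 1];
            const uint64_t remote_sem_addr = get_noc_addr(noc_x, noc_y, get_semaphore(matmul_signal_sem_id));
            noc_semaphore_inc(remote_sem_addr, 1);
        }
    }
}

On the consumer side, the diff's MatmulOpReceiver (constructed in the in0/in1 reader kernels when fuse_op is set) handles the matching wait: update_current_block_start_tile_id() and align_to_slice_and_sync() block until enough slices have been signalled and remap the block index so blocks are consumed in the order the all-gather slices arrive rather than strictly 0..num_blocks-1.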
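On the host side, the diff builds the matmul program first, adds the all-gather (and, optionally, the datacopy) kernels into the same tt_metal::Program, and then composes the individual override_runtime_arguments callbacks into one, handing each sub-op only its own slice of the fused op's tensor lists. The snippet below is a framework-free model of that composition using plain std:: types; TensorList, Callback, and fuse_callbacks are stand-ins invented for illustration and are not tt-metal APIs, but the index remapping mirrors the fused callback in all_gather_matmul_op_multi_core.cpp (inputs[1]/inputs[2] and outputs[1] go to the matmul, inputs[0] and outputs[0] go to the all-gather).

// Plain-C++ model of fusing two runtime-args callbacks into one.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

using TensorList = std::vector<std::string>;  // stand-in for std::vector<Tensor>
using Callback = std::function<void(const TensorList& inputs, const TensorList& outputs)>;

// AllGatherMatmul tensor ordering, as indexed by the fused callback in the diff:
//   inputs  = [input, all_gather_output, weight, ...]
//   outputs = [all_gather_output, matmul_output, datacopy_output]
Callback fuse_callbacks(Callback matmul_cb, Callback all_gather_cb) {
    return [matmul_cb, all_gather_cb](const TensorList& inputs, const TensorList& outputs) {
        // Matmul sees: [all_gather_output, weight] -> [matmul_output]
        matmul_cb({inputs[1], inputs[2]}, {outputs[1]});
        // All-gather sees: [input, all_gather_output] -> [all_gather_output]
        all_gather_cb({inputs[0], outputs[0]}, {outputs[0]});
    };
}

int main() {
    Callback matmul_cb = [](const TensorList& in, const TensorList& out) {
        std::cout << "matmul updates buffers for " << in[0] << " x " << in[1] << " -> " << out[0] << "\n";
    };
    Callback all_gather_cb = [](const TensorList& in, const TensorList& out) {
        std::cout << "all_gather updates buffers for " << in[0] << " -> " << out[0] << "\n";
    };

    Callback fused = fuse_callbacks(matmul_cb, all_gather_cb);
    fused({"input", "ag_out", "weight"}, {"ag_out", "mm_out", "dc_out"});
    return 0;
}

Splitting the tensor lists this way keeps each sub-op's callback oblivious to the fusion: when buffer addresses change between launches, the fused callback simply re-dispatches the remapped views, which is why the matmul, all-gather, and datacopy callbacks can be reused unmodified.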