From 7764bf559a3d6bd303e0a9fce8509a5a122c1681 Mon Sep 17 00:00:00 2001
From: Sean Nijjar
Date: Tue, 7 Jan 2025 21:22:36 -0500
Subject: [PATCH] Add noc read/write burst command support to CCL command kernel. Also add automated command lowering to these noc commands (#16461)

Add two new commands to the CCL command infrastructure:
- noc read burst
- noc write burst

The program factory can specify bursts of noc reads and writes by providing a base address and then sequences of src/dest addresses and sizes (whether src or dest is needed depends on whether the burst is a read or a write). The new commands are currently implemented only in the existing reference kernel. Future work will add a dedicated kernel that supports only noc reads/writes and can therefore be more easily optimized to reach peak utilization.

Additionally, an initial command lowering pass was added and enabled (conditionally) in all-gather; it lowers tensor slice commands to noc read/write burst command streams. This eliminates all tensor iteration and page address lookup overhead at runtime from within the kernel. The lowering is performed with a single call in the program factory.

With this approach, runtime arg overrides change, and the details of which runtime args require overriding become hidden from the user. To account for this, infrastructure was added to automatically track which runtime args require updates as the program is invoked over time (with different tensors). This further generalizes command stream use and simplifies adoption, since the user no longer needs to carry state (implicitly or explicitly) through the program factory to manage runtime argument overrides on op reruns.

Note that for the time being, the noc burst count is limited by runtime arg counts. This limitation will be lifted in the future.
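For context: on the host side, each lowered burst is encoded into the runtime args as a header word (carrying `num_transfers_total` inline), the full 32-bit bank base address, and then, per transfer group, a transfer count followed by (noc addr lo32, noc addr hi32, size bytes) triples, which the kernel consumes three args at a time.

The sketch below illustrates the runtime-arg tracking mechanism this change introduces. It is a minimal, self-contained stand-in, not the real implementation: a plain `std::vector<uint32_t>` takes the place of `tt::tt_metal::RuntimeArgsData`, and the struct mirrors the `add_tensor()` / `add_runtime_arg_index()` / `override_runtime_args()` surface of `ttnn::ccl::tensor_address_runtime_args_overrider`. The idea: the program factory records which runtime-arg slots hold a tensor address (including the bank base address of a lowered noc burst) while args are emitted, so on re-invocation only those slots get patched.

```cpp
// Minimal sketch of the runtime-args overrider pattern (stand-in types only).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct overrider_sketch {
    // One list of runtime-arg indices per tracked tensor.
    std::vector<std::vector<size_t>> indices_per_tensor;

    size_t add_tensor() {
        indices_per_tensor.emplace_back();
        return indices_per_tensor.size() - 1;
    }
    void add_runtime_arg_index(size_t tensor_idx, size_t arg_idx) {
        indices_per_tensor.at(tensor_idx).push_back(arg_idx);
    }
    // On re-invocation, rewrite only the slots that hold this tensor's address.
    void override_runtime_args(size_t tensor_idx, uint32_t new_value,
                               std::vector<uint32_t>& rt_args) const {
        for (size_t idx : indices_per_tensor.at(tensor_idx)) {
            rt_args.at(idx) = new_value;
        }
    }
};

int main() {
    overrider_sketch overrider;
    std::vector<uint32_t> rt_args;

    // Program-factory time: record address slots as args are emitted.
    size_t t0 = overrider.add_tensor();
    overrider.add_runtime_arg_index(t0, rt_args.size());
    rt_args.push_back(0x1000);  // tensor address arg
    rt_args.push_back(42);      // unrelated arg; never touched on re-run
    overrider.add_runtime_arg_index(t0, rt_args.size());
    rt_args.push_back(0x1000);  // e.g. noc-burst bank base address, same tensor

    // Op re-run with a different buffer: patch only the recorded slots.
    overrider.override_runtime_args(t0, 0x2000, rt_args);
    std::cout << std::hex << rt_args[0] << " " << std::dec << rt_args[1] << " "
              << std::hex << rt_args[2] << "\n";  // prints: 2000 42 2000
}
```

In the patch itself, `generate_ccl_noc_transfer_burst_command_args` and `generate_multi_input_command_stream_kernel_rt_args` call `add_runtime_arg_index` at each point a tensor address or burst bank base address is pushed into the arg vector.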
### Ticket [Link to Github Issue](https://github.com/tenstorrent/tt-metal/issues/16395) --- .../operations/ccl/test_new_all_gather.py | 90 ++-- ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt | 2 + .../ccl/common/host/ccl_worker_builder.cpp | 123 ++++-- .../ccl/common/host/ccl_worker_builder.hpp | 10 +- ...command_backend_runtime_args_overrider.cpp | 51 +++ ...command_backend_runtime_args_overrider.hpp | 68 +++ .../kernel_common/algorithms.hpp | 20 + .../fabric_connection_manager.hpp | 67 +++ .../kernel_common/io_descriptors.hpp | 20 + .../kernel_common/kernel_writers.hpp | 116 ++++++ .../kernel_common/noc_addr.hpp | 36 ++ .../kernels/ccl_send_reader_two_input.cpp | 388 ++++++++++-------- .../ccl/common/kernels/command_processor.hpp | 12 +- .../common/types/ccl_types_args_emitters.cpp | 2 +- .../common/types/ccl_types_args_emitters.hpp | 4 + .../ccl/common/uops/ccl_command.hpp | 111 +++-- .../ccl/common/uops/ccl_host_commands.cpp | 111 +++++ .../ccl/common/uops/ccl_host_commands.hpp | 33 ++ .../ccl/common/uops/command_lowering.cpp | 224 ++++++++++ .../ccl/common/uops/command_lowering.hpp | 22 + .../sharded_tensor_addr_gen.hpp | 13 +- .../device/all_gather_async_program.cpp | 67 ++- .../device/reduce_scatter_async_program.cpp | 13 +- 23 files changed, 1315 insertions(+), 288 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/algorithms.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/kernel_writers.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py index 83ada0b11a1..72304350cf1 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -184,27 +184,34 @@ def run_all_gather_impl( output_mem_config = mem_config ### - if rand_tensor: - output_tensor = torch.rand(output_shape).bfloat16() - else: - output_tensor = torch.zeros(output_shape) - tile_id = 1 - for w in range(output_shape[0]): - for z in range(output_shape[1]): - for y in range(0, output_shape[2], 32): - for x in range(0, output_shape[3], 32): - output_tensor[w, z, y : y + 32, x : x + 32] = tile_id - tile_id += 1 - - input_tensors = torch.chunk(output_tensor, num_devices, dim) - tt_input_tensors = [] - for i, t in enumerate(input_tensors): - tt_input_tensors.append( - ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], input_mem_config) - ) - logger.info(f"using device {mesh_device.get_devices()[i].id()}") + input_tensor_mesh_list = [] + output_tensor_goldens_list = [] + + for i in range(num_iters): + if rand_tensor: + output_tensor = torch.rand(output_shape).bfloat16() + else: + output_tensor = torch.zeros(output_shape) + tile_id = 1 + for w in range(output_shape[0]): 
+ for z in range(output_shape[1]): + for y in range(0, output_shape[2], 32): + for x in range(0, output_shape[3], 32): + output_tensor[w, z, y : y + 32, x : x + 32] = tile_id + tile_id += 1 + + output_tensor_goldens_list.append(output_tensor) + input_tensors = torch.chunk(output_tensor, num_devices, dim) + tt_input_tensors = [] + for i, t in enumerate(input_tensors): + tt_input_tensors.append( + ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], input_mem_config) + ) + logger.info(f"using device {mesh_device.get_devices()[i].id()}") - input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + + input_tensor_mesh_list.append(input_tensor_mesh) compute_grid_size = mesh_device.compute_with_storage_grid_size() worker_sub_device = ttnn.SubDevice( @@ -220,22 +227,24 @@ def run_all_gather_impl( mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric ) + tt_out_tensor_list = [] if trace_mode: tt_out_tensor = run_with_trace( mesh_device, all_gather_topology, - input_tensor_mesh, + input_tensor_mesh_list[0], dim, num_links, output_mem_config, num_iter=num_iters, subdevice_id=worker_sub_device_id, ) + tt_out_tensor_list.append(tt_out_tensor) else: for i in range(num_iters): if use_cluster_axis_api: tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor_mesh, + input_tensor_mesh_list[i], dim, cluster_axis=cluster_axis, mesh_device=mesh_device, @@ -249,7 +258,7 @@ def run_all_gather_impl( else: tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor_mesh, + input_tensor_mesh_list[i], dim, num_links=num_links, memory_config=output_mem_config, @@ -257,6 +266,7 @@ def run_all_gather_impl( subdevice_id=worker_sub_device_id, enable_persistent_fabric_mode=enable_persistent_fabric, ) + tt_out_tensor_list.append(tt_out_tensor) logger.info(f"Waiting for op {i}") for d in mesh_device.get_devices(): @@ -266,17 +276,20 @@ def run_all_gather_impl( if enable_persistent_fabric and teardown_persistent_fabric: teardown_fabric_interface(mesh_device) - for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): - tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() - logger.info(f"Checking for device {t.device().id()}") + for tensor_index in range(len(tt_out_tensor_list)): + tt_out_tensor = tt_out_tensor_list[tensor_index] + output_tensor = output_tensor_goldens_list[tensor_index] + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + logger.info(f"Checking for device {t.device().id()}") - if input_dtype == ttnn.bfloat16: - eq, output = comp_equal(tt_output_tensor, output_tensor) - else: - eq, output = comp_pcc(tt_output_tensor, output_tensor) - if not eq: - logger.error(f"output mismatch for tensor {i}") - assert eq, f"{i} FAILED: {output}" + if input_dtype == ttnn.bfloat16: + eq, output = comp_equal(tt_output_tensor, output_tensor) + else: + eq, output = comp_pcc(tt_output_tensor, output_tensor) + if not eq: + logger.error(f"output mismatch for tensor {i}") + assert eq, f"{i} FAILED: {output}" # Enumerate the post-commit cases explicitly @@ -386,6 +399,15 @@ def test_all_gather( ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ), + ( + 4, + [1, 4, 32, 1280], + 3, + ttnn.TILE_LAYOUT, + (32, 128), + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 4))}), + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ), ], ) 
@pytest.mark.parametrize("num_links", [1]) @@ -431,6 +453,8 @@ def test_all_gather_sharded( input_shard_shape=input_shard_shape, shard_grid=shard_grid, tensor_mem_layout=tensor_mem_layout, + create_persistent_fabric=True, + teardown_persistent_fabric=True, ) diff --git a/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt index 37b365f248a..4b1ecea7020 100644 --- a/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt +++ b/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt @@ -5,7 +5,9 @@ set(CCL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/ccl_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl_host_datastructures.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/types/ccl_types_args_emitters.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/host/command_backend_runtime_args_overrider.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_command.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/uops/command_lowering.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_host_commands.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/host/ccl_worker_builder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/host/ccl_command_stream_builders.cpp diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp index e09ddf81d93..5356defc888 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp @@ -20,6 +20,7 @@ #include #include +#include namespace ttnn::ccl::worker_detail { @@ -571,6 +572,32 @@ size_t generate_ccl_core_descriptor_info_command_args( return num_ccl_command_args_added; } +static size_t generate_ccl_noc_transfer_burst_command_args( + const ttnn::ccl::cmd::HostCclCommandNocTransferBurst& noc_burst_descriptor, + size_t tensor_index, + ttnn::ccl::tensor_address_runtime_args_overrider &rt_args_overrider_out, + std::vector& args_out) { + ttnn::ccl::cmd::CclCommandArgHeader hdr; + hdr.code = ttnn::ccl::cmd::CclCommandArgCode::SET_NOC_TRANSFER_BURST_START_INFO; + TT_FATAL(noc_burst_descriptor.num_transfers_total > 0, "Internal Error. 
num_transfers_total uninitialized when generating runtime args for noc read/write commands"); + hdr.inline_value0 = noc_burst_descriptor.num_transfers_total; + // Bank base address must be set in the next arg since we may need the full 32-bit value + args_out.push_back(hdr.to_uint32()); + rt_args_overrider_out.add_runtime_arg_index(tensor_index, args_out.size()); + args_out.push_back(noc_burst_descriptor.bank_base_address); + + for (auto const& transfer_group : noc_burst_descriptor.transfer_burst_groupings) { + args_out.push_back(transfer_group.num_transfers_per_packet); + for (auto const& transfer : transfer_group.transfer_infos) { + args_out.push_back(transfer.noc_addr & 0xFFFFFFFF); + args_out.push_back(transfer.noc_addr >> 32); + args_out.push_back(transfer.noc_transfer_size_bytes); + } + } + + return 1; +} + void validate_ccl_command_dest_args(ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { bool valid = std::holds_alternative(dest_args) || std::holds_alternative(dest_args) || @@ -597,8 +624,14 @@ void validate_command(ttnn::ccl::cmd::CclHostLowLevelWorkerCommand const& comman void generate_ccl_command_stream_to_kernel_args( std::vector const& ccl_command_stream, + std::optional tensor_index, + std::optional> const& tensor_indices, + ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider_out, std::vector& rt_args_out) { std::optional last_tensor_slice = std::nullopt; + + bool fill_args_overrider = rt_args_overrider_out != nullptr; + TT_FATAL(!fill_args_overrider || tensor_index != std::nullopt, "Internal Error: When generating CCL command stream to kernel args, a runtime args overrider was provided but no tensor command index map was provided."); std::optional> last_src_addr_type = std::nullopt; std::optional> @@ -618,9 +651,30 @@ void generate_ccl_command_stream_to_kernel_args( static_assert(sizeof(ttnn::ccl::cmd::CclCommandHeader) == sizeof(uint32_t)); const size_t old_rt_args_start_index = rt_args_out.size(); rt_args_out.push_back(0); - // populate the body (ccl command args)of the command size_t num_ccl_command_args_added = 0; + + // populate the src_addr_type + num_ccl_command_args_added += generate_ccl_address_info_command_args( + last_src_addr_type, + {command.source_addr_type, command.source_addr_args}, + ttnn::ccl::cmd::SRC_DEST_TYPE::SRC, + rt_args_out); + last_src_addr_type = {command.source_addr_type, command.source_addr_args}; + + // populate the dest_addr_type + num_ccl_command_args_added += generate_ccl_address_info_command_args( + last_dest_addr_type, + {command.dest_addr_type, command.dest_addr_args}, + ttnn::ccl::cmd::SRC_DEST_TYPE::DEST, + rt_args_out); + last_dest_addr_type = {command.dest_addr_type, command.dest_addr_args}; + + // populate the core_desc_type + num_ccl_command_args_added += generate_ccl_core_descriptor_info_command_args( + last_core_descriptor, {command.core_desc_type, command.core_desc_args}, rt_args_out); + last_core_descriptor = {command.core_desc_type, command.core_desc_args}; + switch (command.command_code) { case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: { @@ -645,6 +699,26 @@ void generate_ccl_command_stream_to_kernel_args( std::get(command.command_args), rt_args_out); break; + case ttnn::ccl::cmd::CclCommandCode::NOC_READ_BURST: + TT_FATAL(fill_args_overrider, "Internal Error: When generating noc read burst command args, an rt args override must be provided so that runtime args can be overridden on re-invocations of the owning operation"); + 
num_ccl_command_args_added += generate_ccl_noc_transfer_burst_command_args( + std::get(command.command_args), tensor_indices->at(tensor_index.value()), *rt_args_overrider_out, rt_args_out); + break; + + case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_BURST: + TT_FATAL(fill_args_overrider, "Internal Error: When generating noc write burst command args, an rt args override must be provided so that runtime args can be overridden on re-invocations of the owning operation"); + num_ccl_command_args_added += generate_ccl_noc_transfer_burst_command_args( + std::get(command.command_args), tensor_indices->at(tensor_index.value()), *rt_args_overrider_out, rt_args_out); + break; + + case ttnn::ccl::cmd::CclCommandCode::FLOW_CONTROLLED_NOC_READ_BURST: + TT_THROW("Command encoding support for CclCommandCode::FLOW_CONTROLLED_NOC_READ_BURST is unimplemented"); + break; + + case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_AND_ATOMIC_INC: + TT_THROW("Command encoding support for CclCommandCode::NOC_WRITE_AND_ATOMIC_INC is unimplemented"); + break; + case ttnn::ccl::cmd::CclCommandCode::STREAM_EDM_TO_TENSOR: TT_THROW( "CCL command STREAM_EDM_TO_TENSOR is not useable, supported, or intended to be supported in CCL " @@ -660,26 +734,6 @@ void generate_ccl_command_stream_to_kernel_args( break; } - // populate the src_addr_type - num_ccl_command_args_added += generate_ccl_address_info_command_args( - last_src_addr_type, - {command.source_addr_type, command.source_addr_args}, - ttnn::ccl::cmd::SRC_DEST_TYPE::SRC, - rt_args_out); - last_src_addr_type = {command.source_addr_type, command.source_addr_args}; - - // populate the dest_addr_type - num_ccl_command_args_added += generate_ccl_address_info_command_args( - last_dest_addr_type, - {command.dest_addr_type, command.dest_addr_args}, - ttnn::ccl::cmd::SRC_DEST_TYPE::DEST, - rt_args_out); - last_dest_addr_type = {command.dest_addr_type, command.dest_addr_args}; - // populate the core_desc_type - num_ccl_command_args_added += generate_ccl_core_descriptor_info_command_args( - last_core_descriptor, {command.core_desc_type, command.core_desc_args}, rt_args_out); - last_core_descriptor = {command.core_desc_type, command.core_desc_args}; - // populate the fabric_transfer_type // Handled by header log_trace( @@ -916,6 +970,9 @@ static void log_command_stream(ttnn::ccl::cmd::CclHostLowLevelCommandSequence co [&ss](CclCommandWaitValue const& a) { ss << fmt::format("(wait_value: {})", a.target_value); }, [&ss](CclCommandInlineReadWrite const& a) { ss << fmt::format("(value: {})", a.value); }, [&ss](CclCommandReadWrite const& a) { ss << fmt::format("(size_bytes: {})", a.size_bytes); }, + [&ss](HostCclCommandNocTransferBurst const& a) { + ss << fmt::format("(base_addr: {}, n_transfers: {})", a.bank_base_address, a.num_transfers_total); + }, [&ss](auto const&&) { ss << "ERROR"; }}, args); }; @@ -934,6 +991,7 @@ static void log_command_stream(ttnn::ccl::cmd::CclHostLowLevelCommandSequence co a.noc0_end_x, a.noc0_end_y); }, + [&ss](CclCommandCoreDescriptorTypeNone const& a) { ss << fmt::format("(None)"); }, }, args); }; @@ -999,7 +1057,22 @@ void generate_multi_input_command_stream_kernel_rt_args( std::optional const& ccl_command_stream1, std::optional const& forward_fabric_connections, std::optional const& backward_fabric_connections, - std::optional> const& tensor_device_override) { + std::optional> const& tensor_device_override, + std::optional> const& tensor_indices, + ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider) { + + bool fill_args_overrider = 
rt_args_overrider != nullptr; + + if (fill_args_overrider) { + TT_FATAL(tensor_indices.has_value(), "Internal Error. Tensor indices must be provided when using rt_args_overrider"); + TT_FATAL(tensor_indices.value().size() == tensors.size(), "Internal Error. Tensor indices must match the number of tensors"); + for (auto tensor_index : tensor_indices.value()) { + while (rt_args_overrider->size() <= tensor_index) { + rt_args_overrider->add_tensor(); + } + } + } + // TODO: see if we can pull the kernel defines to understand if we built the kernel in single command stream mode log_trace( tt::LogOp, @@ -1031,6 +1104,9 @@ void generate_multi_input_command_stream_kernel_rt_args( rt_args.reserve(100); for (size_t i = 0; i < tensors.size(); i++) { if (tensors[i]) { + if (fill_args_overrider) { + rt_args_overrider->add_runtime_arg_index(tensor_indices.value()[i], rt_args.size()); + } rt_args.push_back(tensors[i]->buffer()->address()); } else { // take up the rt arg with filler value in case user built a kernel across a core range @@ -1095,7 +1171,7 @@ void generate_multi_input_command_stream_kernel_rt_args( // Update the command stream start arg index argument to point to here (i.e. where // this command stream's commands will start) rt_args[command_stream_start_arg_indices[i]] = rt_args.size(); - generate_ccl_command_stream_to_kernel_args((*command_streams[i]), rt_args); + generate_ccl_command_stream_to_kernel_args((*command_streams[i]), i, tensor_indices, rt_args_overrider, rt_args); } log_trace(tt::LogOp, "\tMulti-input command processor RT Args"); @@ -1104,6 +1180,7 @@ void generate_multi_input_command_stream_kernel_rt_args( log_trace(tt::LogOp, "\t\t{}: {}", i, arg); } tt::tt_metal::SetRuntimeArgs(program, kernel_id, worker_core_range, rt_args); + } void generate_multi_command_stream_kernel_rt_args( diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp index 79699816337..30d224f021f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp @@ -8,6 +8,7 @@ #include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" #include "ttnn/operations/ccl/common/uops/ccl_command.hpp" #include "ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp" #include #include @@ -56,7 +57,10 @@ void generate_ccl_cb_to_tensor_slice_sequence_commands( ttnn::ccl::cmd::CclCommandDestArgs const& dest_args); void generate_ccl_command_stream_to_kernel_args( std::vector const& ccl_command_stream, - std::vector& args_out); + std::optional tensor_index, + std::optional> const& tensor_indices, + ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider_out, + std::vector& rt_args_out); // TODO: eventually take a fabric handle void generate_multi_input_command_stream_kernel_rt_args( @@ -71,7 +75,9 @@ void generate_multi_input_command_stream_kernel_rt_args( std::optional> const& ccl_command_stream1, std::optional const& forward_fabric_connections, std::optional const& backward_fabric_connections, - std::optional> const& tensor_device_override = std::nullopt); + std::optional> const& tensor_device_override = std::nullopt, + std::optional> const& tensor_indices = std::nullopt, + ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider = nullptr); // Helper functions for building command processing datamovement kernels // TODO: Bundle into command 
bundle per command stream to cut down
// on args and improve usability
diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.cpp
new file mode 100644
index 00000000000..b8b5fbc2392
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.cpp
@@ -0,0 +1,51 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp"
+
+#include "tt_metal/impl/kernels/runtime_args_data.hpp"
+#include "common/assert.hpp"
+
+namespace ttnn::ccl {
+size_t tensor_address_runtime_args_overrider::add_tensor() {
+    size_t tensor_idx = tensor_address_runtime_arg_indices.size();
+    tensor_address_runtime_arg_indices.push_back(std::vector<size_t>());
+    return tensor_idx;
+}
+
+std::vector<size_t> tensor_address_runtime_args_overrider::get_runtime_arg_indices(size_t tensor_idx) const {
+    TT_FATAL(
+        tensor_idx < tensor_address_runtime_arg_indices.size(),
+        "Internal Error. Invalid tensor index when getting runtime arg indices in "
+        "tensor_address_runtime_args_overrider");
+    return tensor_address_runtime_arg_indices[tensor_idx];
+}
+
+void tensor_address_runtime_args_overrider::add_runtime_arg_index(size_t tensor_idx, size_t runtime_arg_index) {
+    TT_FATAL(
+        tensor_idx < tensor_address_runtime_arg_indices.size(),
+        "Invalid tensor index when adding runtime arg index. tensor_idx: {}, highest_available: {}",
+        tensor_idx,
+        tensor_address_runtime_arg_indices.size());
+    tensor_address_runtime_arg_indices[tensor_idx].push_back(runtime_arg_index);
+}
+
+void tensor_address_runtime_args_overrider::override_runtime_args(
+    size_t tensor_idx, uint32_t new_value, tt::tt_metal::RuntimeArgsData& runtime_args_to_modify) const {
+    TT_FATAL(
+        tensor_idx < tensor_address_runtime_arg_indices.size(), "Invalid tensor index when overriding runtime args");
+
+    const auto& indices = tensor_address_runtime_arg_indices[tensor_idx];
+    TT_FATAL(!indices.empty(), "No runtime arg indices associated with tensor");
+
+    log_trace(tt::LogOp, "Overriding {} runtime args for tensor {} to value {}", indices.size(), tensor_idx, new_value);
+    for (size_t idx : indices) {
+        TT_FATAL(idx < runtime_args_to_modify.size(), "Runtime arg index out of bounds when overriding args");
+        log_trace(tt::LogOp, "\t- {}", idx);
+        runtime_args_to_modify[idx] = new_value;
+    }
+}
+
+size_t tensor_address_runtime_args_overrider::size() const { return tensor_address_runtime_arg_indices.size(); }
+}  // namespace ttnn::ccl
diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp b/ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp
new file mode 100644
index 00000000000..187d911235d
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp
@@ -0,0 +1,68 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+namespace tt::tt_metal {
+struct RuntimeArgsData;
+}
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace ttnn::ccl {
+/*
+ * Tracks tensors (by user defined index) and the runtime arg indices that are used to
+ * reference their address.
+ *
+ * Future work will let a user track runtime args by various attributes (such as shape, dtype, etc.)
+ * -- whatever the user decides to track in runtime args that may change across invocations
+ */
+struct tensor_address_runtime_args_overrider {
+public:
+    using runtime_args_t = std::vector<uint32_t>;
+
+    /*
+     * Add a tensor to be tracked by this overrider.
+     *
+     * @return: the index assigned to this tensor, for use in future lookups by user
+     */
+    size_t add_tensor();
+
+    /*
+     * Add a runtime arg index to the tensor's runtime args
+     *
+     * @param tensor_idx: the index of the tensor to add the runtime arg index to, assigned by add_tensor()
+     * @param runtime_arg_index: the index of the runtime arg to add
+     */
+    void add_runtime_arg_index(size_t tensor_idx, size_t runtime_arg_index);
+
+    /*
+     * Get the runtime arg indices that are associated with the specified tensor
+     *
+     * @param tensor_idx: the index of the tensor to get the runtime args for, assigned by add_tensor()
+     * @return: the runtime args that are associated with the tensor
+     */
+    std::vector<size_t> get_runtime_arg_indices(size_t tensor_idx) const;
+
+    /*
+     * Override the runtime args that are used to reference the tensor's address, setting them to the new value
+     *
+     * @param tensor_idx: the index of the tensor to override the runtime args for, assigned by add_tensor()
+     * @param new_value: the new value to set the tensor's associated runtime args to
+     * @param runtime_args_to_modify: the runtime args to modify. These will be modified in place.
+     */
+    void override_runtime_args(
+        size_t tensor_idx, uint32_t new_value, tt::tt_metal::RuntimeArgsData& runtime_args_to_modify) const;
+
+    /*
+     * Get the number of tensors in the overrider
+     */
+    size_t size() const;
+
+private:
+    std::vector<std::vector<size_t>> tensor_address_runtime_arg_indices;
+};
+}  // namespace ttnn::ccl
diff --git a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/algorithms.hpp b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/algorithms.hpp
new file mode 100644
index 00000000000..9f3e664640d
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/algorithms.hpp
@@ -0,0 +1,20 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp"
+#include <cstdint>
+
+inline size_t get_flat_index_from_shape(
+    const ttnn::ccl::Shape4D<uint32_t>& shape, const ttnn::ccl::Shape4D<uint32_t>& index) {
+    std::size_t offset = index.x;
+    std::size_t inner_volume = shape.x;
+    offset += index.y * inner_volume;
+    inner_volume *= shape.y;
+    offset += index.z * inner_volume;
+    inner_volume *= shape.z;
+    offset += index.w * inner_volume;
+    return offset;
+}
diff --git a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp
new file mode 100644
index 00000000000..3a4c40961e2
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp
@@ -0,0 +1,67 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +class FabricConnectionManager final { +public: + // return if there is/should be a connection - doesn't return whether or not the connection + // is actually live + inline bool is_logically_connected() const { return has_forward_connection() || has_backward_connection(); } + + // make the connection live + inline void open() { + if (has_forward_connection()) { + forward_fabric_sender.open(); + } + if (has_backward_connection()) { + backward_fabric_sender.open(); + } + } + inline bool has_forward_connection() const { return connection_flags & FORWARD_CONNECTION_FLAG_MASK; } + inline bool has_backward_connection() const { return connection_flags & BACKWARD_CONNECTION_FLAG_MASK; } + inline void close() { + if (has_forward_connection()) { + forward_fabric_sender.close(); + } + if (has_backward_connection()) { + backward_fabric_sender.close(); + } + } + + static FabricConnectionManager build_from_args(std::size_t& arg_idx) { + FabricConnectionManager connection_manager; + connection_manager.connection_flags = static_cast(get_arg_val(arg_idx++) != 0) + << FORWARD_CONNECTION_FLAG_OFFSET; + if (connection_manager.has_forward_connection()) { + connection_manager.forward_fabric_sender = + tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); + } + connection_manager.connection_flags |= static_cast(get_arg_val(arg_idx++) != 0) + << BACKWARD_CONNECTION_FLAG_OFFSET; + if (connection_manager.has_backward_connection()) { + connection_manager.backward_fabric_sender = + tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); + } + return connection_manager; + } + + tt::fabric::WorkerToFabricEdmSender& get_forward_connection() { + ASSERT(has_forward_connection()); + return forward_fabric_sender; + } + tt::fabric::WorkerToFabricEdmSender& get_backward_connection() { + ASSERT(has_backward_connection()); + return backward_fabric_sender; + } + +private: + static constexpr uint8_t FORWARD_CONNECTION_FLAG_MASK = 0x01; + static constexpr uint8_t BACKWARD_CONNECTION_FLAG_MASK = 0x02; + static constexpr uint8_t FORWARD_CONNECTION_FLAG_OFFSET = 0x0; + static constexpr uint8_t BACKWARD_CONNECTION_FLAG_OFFSET = 0x1; + tt::fabric::WorkerToFabricEdmSender forward_fabric_sender; + tt::fabric::WorkerToFabricEdmSender backward_fabric_sender; + uint8_t connection_flags; +}; diff --git a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp new file mode 100644 index 00000000000..9e34b165d49 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" + +#include + +struct address_info_t { + size_t address = 0; +}; + +struct core_descriptor_info_t { + union { + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY noc_unicast; + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast noc_multicast; + } core_desc_args; +}; diff --git a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/kernel_writers.hpp b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/kernel_writers.hpp new file mode 100644 index 00000000000..967171b5d85 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/kernel_writers.hpp @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +// CCL Kernel common includes +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/command_interpreter_base.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/ccl_command_base.hpp" + +// Metal includes +#include "dataflow_api.h" + +// System includes +#include +#include "debug/dprint.h" + +template +FORCE_INLINE void write_and_advance_local_read_address_for_fabric_write( + uint64_t noc0_dest_noc_addr, + size_t packet_header_buffer_addr, + const CclCommandHeader& current_cmd_header, + FabricConnectionManager& fabric_connection, + size_t& l1_read_addr, + uint32_t payload_size_bytes) { + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + const size_t payload_l1_address = l1_read_addr; + + auto pkt_hdr = reinterpret_cast(packet_header_buffer_addr); +#ifdef DEBUG_PRINT_ENABLED + pkt_hdr->reserved2 = my_chip_id; +#endif + + size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); + pkt_hdr->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); + + switch (current_cmd_header.dest_type) { + case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: { + const auto& unicast_args = current_cmd_header.get_unicast_dest_args(); + auto& fabric_conn = unicast_args.is_forward_direction ? 
fabric_connection.get_forward_connection() + : fabric_connection.get_backward_connection(); + + pkt_hdr->to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{unicast_args.distance_in_hops}); + fabric_conn.wait_for_empty_write_slot(); + fabric_conn.send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); + fabric_conn.send_payload_flush_blocking_from_address((uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } break; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: { + noc_async_write( + payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + const auto& mcast_args = current_cmd_header.get_multicast_dest_args(); + if (fabric_connection.has_forward_connection()) { + pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ + 1, static_cast(mcast_args.num_targets_forward_direction)}); + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } + + if (fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ + 1, static_cast(mcast_args.num_targets_backward_direction)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } + } break; + default: { + ASSERT(false); + } break; + } + + l1_read_addr += payload_size_bytes; +} + +template +FORCE_INLINE void write_payload_then_advance_read_address( + uint64_t noc0_dest_noc_addr, + size_t packet_header_buffer_addr, + const CclCommandHeader& current_cmd_header, + FabricConnectionManager& fabric_connection, + size_t& l1_read_addr, + size_t payload_size_bytes) { + static_assert( + ((sizeof(tt::fabric::PacketHeader) - 1) & sizeof(tt::fabric::PacketHeader)) == 0, + "sizeof(sizeof(tt::fabric::PacketHeader)) is not a power of two which violates the below assertion"); + + switch (current_cmd_header.dest_type) { + case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: [[fallthrough]]; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: + write_and_advance_local_read_address_for_fabric_write( + noc0_dest_noc_addr, + packet_header_buffer_addr, + current_cmd_header, + fabric_connection, + l1_read_addr, + payload_size_bytes); + break; + + case ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY: { + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + // Convert to our local noc_index based address + noc_async_write( + l1_read_addr, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + l1_read_addr += payload_size_bytes; + } break; + } +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp new file mode 100644 index 00000000000..e801b8dea1d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent 
Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" + +#include "dataflow_api.h" +#include + +// NOTE: This will eventually be updated with an official API +static constexpr size_t VIRTUAL_COORDS_START_X = 16; +static constexpr size_t VIRTUAL_COORDS_START_Y = 16; +FORCE_INLINE bool is_using_noc_coords(uint16_t noc_x, uint16_t noc_y) { + return noc_x < VIRTUAL_COORDS_START_X && noc_y < VIRTUAL_COORDS_START_Y; +} + +FORCE_INLINE uint64_t safe_get_noc_addr(uint8_t dest_noc_x, uint8_t dest_noc_y, uint32_t dest_bank_addr) { + bool using_noc_coords = is_using_noc_coords(dest_noc_x, dest_noc_y); + uint8_t noc_x = dest_noc_x; + uint8_t noc_y = dest_noc_y; + if (using_noc_coords) { + noc_x = NOC_X_PHYS_COORD(dest_noc_x); + noc_y = NOC_Y_PHYS_COORD(dest_noc_y); + } + return get_noc_addr(noc_x, noc_y, dest_bank_addr); +} +// TODO: COMMONIZE WITH THE ONE IN `ccl_send_writer.cpp` +FORCE_INLINE std::pair get_noc_address_components(uint64_t noc_addr) { + const size_t bank_addr = noc_addr & 0xFFFFFFFF; + const size_t noc_x = (noc_addr >> NOC_ADDR_LOCAL_BITS) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + const size_t noc_y = + (noc_addr >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + return {ttnn::ccl::WorkerXY(noc_x, noc_y), bank_addr}; +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp index ca6e26a33e0..ebda3916524 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp @@ -17,6 +17,9 @@ #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" #include "ttnn/cpp/ttnn/tensor/enum_types.hpp" #include #include @@ -31,6 +34,7 @@ struct no_addrgen {}; static constexpr size_t num_packet_headers_storable = 8; constexpr uint16_t my_chip_id = get_compile_time_arg_val(0); constexpr uint32_t reserved_packet_header_cb_id = get_compile_time_arg_val(1); + #ifdef NO_TENSOR_MODE constexpr TensorMemoryLayout tensor0_layout = TensorMemoryLayout::INTERLEAVED; constexpr BufferType buffer0_type = BufferType::DRAM; @@ -53,15 +57,6 @@ constexpr uint32_t cb1_id = get_compile_time_arg_val(9); #endif #endif -// TODO: COMMONIZE WITH THE ONE IN `ccl_send_writer.cpp` -FORCE_INLINE std::pair get_noc_address_components(uint64_t noc_addr) { - const size_t bank_addr = noc_addr & 0xFFFFFFFF; - const size_t noc_x = (noc_addr >> NOC_ADDR_LOCAL_BITS) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); - const size_t noc_y = - (noc_addr >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); - return {WorkerXY(noc_x, noc_y), bank_addr}; -} - struct sharded_addrgen_fields { bool is_sharded = false; uint8_t tensor_shard_grid_height = 0; @@ -98,10 +93,22 @@ constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = { get_compile_time_arg_val(16) != 0}; #endif -static_assert(in0_sharded_addrgen_fields.tensor_shard_grid_height > 0, 
"Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_height\" was resolved to 0 but it must not be 0."); -static_assert(in0_sharded_addrgen_fields.tensor_shard_grid_width > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_width\" was resolved to 0 but it must not be 0."); -static_assert(in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but it must not be 0."); -static_assert(in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but it must not be 0."); +static_assert( + in0_sharded_addrgen_fields.tensor_shard_grid_height > 0, + "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_height\" was resolved to 0 but it " + "must not be 0."); +static_assert( + in0_sharded_addrgen_fields.tensor_shard_grid_width > 0, + "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_width\" was resolved to 0 but it must " + "not be 0."); +static_assert( + in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, + "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but " + "it must not be 0."); +static_assert( + in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, + "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but " + "it must not be 0."); #else constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = {false, 0, 0, 0, 0, 0, 0, 0}; #endif @@ -131,35 +138,27 @@ constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = { get_compile_time_arg_val(16) != 0}; #endif -static_assert(in1_sharded_addrgen_fields.tensor_shard_grid_height > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_height\" was resolved to 0 but it must not be 0."); -static_assert(in1_sharded_addrgen_fields.tensor_shard_grid_width > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_width\" was resolved to 0 but it must not be 0."); -static_assert(in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but it must not be 0."); -static_assert(in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but it must not be 0."); +static_assert( + in1_sharded_addrgen_fields.tensor_shard_grid_height > 0, + "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_height\" was resolved to 0 but it " + "must not be 0."); +static_assert( + in1_sharded_addrgen_fields.tensor_shard_grid_width > 0, + "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_width\" was resolved to 0 but it must " + "not be 0."); +static_assert( + in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, + "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but " + "it must not be 0."); +static_assert( + in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, + "Misconfigured sharded addrgen fields for tensor1. 
Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but " + "it must not be 0."); #else constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {0, 0, 0, 0, 0, 0, 0, 0}; #endif #endif - -// NOTE: This will eventually be updated with an official API -static constexpr size_t VIRTUAL_COORDS_START_X = 16; -static constexpr size_t VIRTUAL_COORDS_START_Y = 16; -FORCE_INLINE bool is_using_noc_coords(uint16_t noc_x, uint16_t noc_y) { - return noc_x < VIRTUAL_COORDS_START_X && noc_y < VIRTUAL_COORDS_START_Y; -} - -FORCE_INLINE uint64_t safe_get_noc_addr(uint8_t dest_noc_x, uint8_t dest_noc_y, uint32_t dest_bank_addr) { - bool using_noc_coords = is_using_noc_coords(dest_noc_x, dest_noc_y); - uint8_t noc_x = dest_noc_x; - uint8_t noc_y = dest_noc_y; - if (using_noc_coords) { - noc_x = NOC_X_PHYS_COORD(dest_noc_x); - noc_y = NOC_Y_PHYS_COORD(dest_noc_y); - } - return get_noc_addr(noc_x, noc_y, dest_bank_addr); -} - - template < tt::tt_metal::TensorMemoryLayout tensor_layout, tt::tt_metal::BufferType buffer_type, @@ -168,7 +167,7 @@ FORCE_INLINE auto build_source_address_generator( std::size_t& arg_idx, address_t tensor_address, std::size_t page_size, - sharded_addrgen_fields const& tensor_sharded_addrgen_fields, + const sharded_addrgen_fields& tensor_sharded_addrgen_fields, uint32_t cb_id_in) -> typename source_tensor_addrgen::type { constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout); constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED; @@ -234,92 +233,26 @@ struct remote_sem_change_context { using remote_sem_wait_context = remote_sem_change_context; using remote_atomic_inc_context = remote_sem_change_context; +struct noc_transfer_burst_context { + uint32_t bank_base_address = 0; + uint16_t num_transfers_total = 0; + uint16_t current_noc_transfer = 0; +}; union cmd_specific_context { wrapped_worker_slice_read_context wrapped_worker_slice_read_ctx; + noc_transfer_burst_context noc_transfer_burst_ctx; // sem wait and atomic inc inline_value_context inline_value_ctx; cmd_specific_context() {} }; -class FabricConnectionManager final { -public: - // return if there is/should be a connection - doesn't return whether or not the connection - // is actually live - bool is_logically_connected() const { return has_forward_connection() || has_backward_connection(); } - - // make the connection live - void open() { - if (has_forward_connection()) { - forward_fabric_sender.open(); - } - if (has_backward_connection()) { - backward_fabric_sender.open(); - } - } - bool has_forward_connection() const { return connection_flags & FORWARD_CONNECTION_FLAG_MASK; } - bool has_backward_connection() const { return connection_flags & BACKWARD_CONNECTION_FLAG_MASK; } - void close() { - if (has_forward_connection()) { - forward_fabric_sender.close(); - } - if (has_backward_connection()) { - backward_fabric_sender.close(); - } - } - - static FabricConnectionManager build_from_args(std::size_t& arg_idx) { - FabricConnectionManager connection_manager; - connection_manager.connection_flags = static_cast(get_arg_val(arg_idx++) != 0) - << FORWARD_CONNECTION_FLAG_OFFSET; - if (connection_manager.has_forward_connection()) { - connection_manager.forward_fabric_sender = - tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); - } - connection_manager.connection_flags |= static_cast(get_arg_val(arg_idx++) != 0) - << BACKWARD_CONNECTION_FLAG_OFFSET; - if (connection_manager.has_backward_connection()) { - connection_manager.backward_fabric_sender = - 
tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); - } - return connection_manager; - } - - tt::fabric::WorkerToFabricEdmSender& get_forward_connection() { - ASSERT(has_forward_connection()); - return forward_fabric_sender; - } - tt::fabric::WorkerToFabricEdmSender& get_backward_connection() { - ASSERT(has_backward_connection()); - return backward_fabric_sender; - } - -private: - static constexpr uint8_t FORWARD_CONNECTION_FLAG_MASK = 0x01; - static constexpr uint8_t BACKWARD_CONNECTION_FLAG_MASK = 0x02; - static constexpr uint8_t FORWARD_CONNECTION_FLAG_OFFSET = 0x0; - static constexpr uint8_t BACKWARD_CONNECTION_FLAG_OFFSET = 0x1; - tt::fabric::WorkerToFabricEdmSender forward_fabric_sender; - tt::fabric::WorkerToFabricEdmSender backward_fabric_sender; - uint8_t connection_flags; -}; - -struct address_info_t { - size_t address = 0; -}; - -struct core_descriptor_info_t { - union { - ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY noc_unicast; - ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast noc_multicast; - } core_desc_args; -}; template struct command_context_t; template void update_ccl_command( - arg_idx_t& arg_idx, command_context_t& cmd_ctx, ttnn::ccl::cmd::CclCommandHeader const& cmd_header); + arg_idx_t& arg_idx, command_context_t& cmd_ctx, const ttnn::ccl::cmd::CclCommandHeader& cmd_header); template struct command_context_t final { @@ -397,11 +330,11 @@ struct command_context_t final { case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM: { #ifndef NO_TENSOR_MODE - shape_t const worker_start_offset_global = v2::worker_wrapped_offset_to_coord( + const shape_t worker_start_offset_global = v2::worker_wrapped_offset_to_coord( command_tensor.tensor_slice_shape, command_tensor.worker_start_offset_in_slice); - shape_t const global_offset = command_tensor.tensor_slice_offset + worker_start_offset_global; + const shape_t global_offset = command_tensor.tensor_slice_offset + worker_start_offset_global; - size_t const curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset); + const size_t curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset); cmd_specific_ctx.wrapped_worker_slice_read_ctx = wrapped_worker_slice_read_context{curr_tile_id}; #endif } break; @@ -415,7 +348,7 @@ struct command_context_t final { template void update_ccl_command( - arg_idx_t& arg_idx, command_context_t& cmd_ctx, ttnn::ccl::cmd::CclCommandHeader const& cmd_header) { + arg_idx_t& arg_idx, command_context_t& cmd_ctx, const ttnn::ccl::cmd::CclCommandHeader& cmd_header) { using namespace ttnn::ccl::cmd; arg_idx_t arg_idx_old = arg_idx; @@ -497,6 +430,7 @@ void update_ccl_command( cmd_ctx.core_desc_type = static_cast(command_arg_header.inline_value0); switch (cmd_ctx.core_desc_type) { + case ttnn::ccl::cmd::CclCommandCoreDescriptorType::NONE: case ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN: case ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL: break; case ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY: @@ -514,6 +448,12 @@ void update_ccl_command( } break; + case CclCommandArgCode::SET_NOC_TRANSFER_BURST_START_INFO: + cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx.num_transfers_total = command_arg_header.inline_value0; + cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx.bank_base_address = get_arg_val(arg_idx++); + cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx.current_noc_transfer = 0; + break; + default: { ASSERT(false); } @@ -554,9 +494,9 @@ FORCE_INLINE void 
try_advance_inline_write_or_atomic_inc(command_context_tto_write(); } - #ifdef DEBUG_PRINT_ENABLED +#ifdef DEBUG_PRINT_ENABLED pkt_hdr->reserved2 = my_chip_id; - #endif +#endif pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ dest_bank_addr, static_cast(value), @@ -639,8 +579,7 @@ FORCE_INLINE void try_advance_read_tensor_to_cb(command_context_t& cmd_ uint32_t l1_write_addr = l1_write_addr_base; for (uint16_t i = 0; i < max_pages_readable; i += contig_pages_advanced) { - DPRINT << "t_id: " << (uint32_t)cmd_specific_ctx.curr_tile_id << "\n"; - auto const [noc_addr, contig_pages_] = get_noc_addr_and_contiguous_pages( + const auto [noc_addr, contig_pages_] = get_noc_addr_and_contiguous_pages( cmd_specific_ctx.curr_tile_id, cmd_specific_ctx.offset_into_worker_slice, cmd_ctx.command_tensor.worker_start_offset_in_slice, @@ -656,7 +595,6 @@ FORCE_INLINE void try_advance_read_tensor_to_cb(command_context_t& cmd_ } l1_write_addr += cmd_ctx.page_size * contig_pages_advanced; - bool done_worker_slice = ttnn::ccl::v2::advance_worker_global_page( cmd_specific_ctx.curr_tile_id, // Updated internally cmd_specific_ctx.offset_into_worker_slice, @@ -674,91 +612,98 @@ FORCE_INLINE void try_advance_read_tensor_to_cb(command_context_t& cmd_ } #endif -template FORCE_INLINE void write_and_advance_local_read_address_for_fabric_write( uint64_t noc0_dest_noc_addr, - command_context_t& cmd_ctx, - wrapped_worker_slice_read_context& cmd_specific_ctx, + size_t packet_header_buffer_addr, + const ttnn::ccl::cmd::CclCommandHeader& current_cmd_header, + FabricConnectionManager& fabric_connection, size_t& l1_read_addr, - uint16_t contig_pages_advanced) { - // All fabric writes have noc0 coordinates specified in the header. Therefore, we need to regenerate the noc - // address noc_index coordinates - const size_t payload_size_bytes = static_cast(contig_pages_advanced) * cmd_ctx.page_size; + uint32_t payload_size_bytes) { const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); const size_t payload_l1_address = l1_read_addr; - auto pkt_hdr = reinterpret_cast(cmd_ctx.packet_header_buffer_addr); - #ifdef DEBUG_PRINT_ENABLED + + auto pkt_hdr = reinterpret_cast(packet_header_buffer_addr); +#ifdef DEBUG_PRINT_ENABLED pkt_hdr->reserved2 = my_chip_id; - #endif +#endif + size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); pkt_hdr->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); - switch (cmd_ctx.current_cmd_header.dest_type) { + + switch (current_cmd_header.dest_type) { case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: { - const auto& unicast_args = cmd_ctx.current_cmd_header.get_unicast_dest_args(); - auto& fabric_connection = unicast_args.is_forward_direction - ? cmd_ctx.fabric_connection.get_forward_connection() - : cmd_ctx.fabric_connection.get_backward_connection(); + const auto& unicast_args = current_cmd_header.get_unicast_dest_args(); + auto& fabric_conn = unicast_args.is_forward_direction ? 
fabric_connection.get_forward_connection() + : fabric_connection.get_backward_connection(); pkt_hdr->to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{unicast_args.distance_in_hops}); - fabric_connection.wait_for_empty_write_slot(); - fabric_connection.send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); - fabric_connection.send_payload_flush_blocking_from_address( - (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + fabric_conn.wait_for_empty_write_slot(); + fabric_conn.send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); + fabric_conn.send_payload_flush_blocking_from_address((uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); } break; case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: { - noc_async_write(payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); - const auto& mcast_args = cmd_ctx.current_cmd_header.get_multicast_dest_args(); - if (cmd_ctx.fabric_connection.has_forward_connection()) { + noc_async_write( + payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + const auto& mcast_args = current_cmd_header.get_multicast_dest_args(); + if (fabric_connection.has_forward_connection()) { pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ 1, static_cast(mcast_args.num_targets_forward_direction)}); - cmd_ctx.fabric_connection.get_forward_connection().wait_for_empty_write_slot(); - cmd_ctx.fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); - cmd_ctx.fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( - (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); } - // Write the mcast packet (backward) - if (cmd_ctx.fabric_connection.has_backward_connection()) { + if (fabric_connection.has_backward_connection()) { pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ 1, static_cast(mcast_args.num_targets_backward_direction)}); - cmd_ctx.fabric_connection.get_backward_connection().wait_for_empty_write_slot(); - cmd_ctx.fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); - cmd_ctx.fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( - (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); } } break; default: { + DPRINT << "default\n"; ASSERT(false); } break; } - // Don't advance (payload + header) because we want to make sure we keep sizeof(tt::fabric::PacketHeader) space - // that's safe to use, preceeding the next hypothetical packet in L1. 
l1_read_addr += payload_size_bytes; } -template FORCE_INLINE void write_payload_then_advance_read_address( uint64_t noc0_dest_noc_addr, - command_context_t& cmd_ctx, - wrapped_worker_slice_read_context& cmd_specific_ctx, + size_t packet_header_buffer_addr, + const ttnn::ccl::cmd::CclCommandHeader& current_cmd_header, + FabricConnectionManager& fabric_connection, size_t& l1_read_addr, - uint16_t contig_pages_advanced) { + size_t payload_size_bytes) { static_assert( ((sizeof(tt::fabric::PacketHeader) - 1) & sizeof(tt::fabric::PacketHeader)) == 0, "sizeof(sizeof(tt::fabric::PacketHeader)) is not a power of two which violates the below assertion"); - switch (cmd_ctx.current_cmd_header.dest_type) { + + switch (current_cmd_header.dest_type) { case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: [[fallthrough]]; case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: write_and_advance_local_read_address_for_fabric_write( - noc0_dest_noc_addr, cmd_ctx, cmd_specific_ctx, l1_read_addr, contig_pages_advanced); + noc0_dest_noc_addr, + packet_header_buffer_addr, + current_cmd_header, + fabric_connection, + l1_read_addr, + payload_size_bytes); break; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY: { - auto const [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); - // Conver to our local noc_index based address - noc_async_write(l1_read_addr, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), cmd_ctx.page_size * contig_pages_advanced); - l1_read_addr += cmd_ctx.page_size * contig_pages_advanced; + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + // Convert to our local noc_index based address + noc_async_write( + l1_read_addr, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + l1_read_addr += payload_size_bytes; } break; } } @@ -794,8 +739,7 @@ FORCE_INLINE void try_advance_write_tensor_from_cb(command_context_t& c // However, if we're writing locally, then we need to actually write using `noc_index` based coordinates. 
// This can lead to a discrepancy, so to stay consistent, we always generate noc0 based addresses here // so we can reliably translate to `noc_index` based addresses writing locally, inside the write function - DPRINT << "t_id: " << (uint32_t)cmd_specific_ctx.curr_tile_id << "\n"; - auto const [noc0_dest_noc_addr, contig_pages_] = + const auto [noc0_dest_noc_addr, contig_pages_] = get_noc_addr_and_contiguous_pages_for_fabric_write( cmd_specific_ctx.curr_tile_id, cmd_specific_ctx.offset_into_worker_slice, @@ -806,7 +750,12 @@ FORCE_INLINE void try_advance_write_tensor_from_cb(command_context_t& c contig_pages_advanced = std::min(cmd_ctx.packet_size_in_pages - i, contig_pages_); write_payload_then_advance_read_address( - noc0_dest_noc_addr, cmd_ctx, cmd_specific_ctx, l1_read_addr, contig_pages_advanced); + noc0_dest_noc_addr, + cmd_ctx.packet_header_buffer_addr, + cmd_ctx.current_cmd_header, + cmd_ctx.fabric_connection, + l1_read_addr, + contig_pages_advanced * cmd_ctx.page_size); auto done_worker_slice = ttnn::ccl::v2::advance_worker_global_page( cmd_specific_ctx.curr_tile_id, // Updated internally @@ -824,6 +773,83 @@ FORCE_INLINE void try_advance_write_tensor_from_cb(command_context_t& c } #endif +static FORCE_INLINE ttnn::ccl::cmd::noc_transfer_info get_next_noc_transfer_in_burst(arg_idx_t& arg_idx) { + // Each transfer in a burst is encoded as three runtime args: + // [arg_idx] = bank address offset, [arg_idx + 1] = packed noc x/y, [arg_idx + 2] = transfer size in bytes + auto noc_yx_in_16bits_each = get_arg_val(arg_idx + 1); + noc_grid_index_t noc_x = static_cast(noc_yx_in_16bits_each & 0xFF); + noc_grid_index_t noc_y = static_cast((noc_yx_in_16bits_each >> 16) & 0xFF); + + uint32_t noc_transfer_size_bytes = get_arg_val(arg_idx + 2); + uint32_t bank_addr_offset = get_arg_val(arg_idx); + return {safe_get_noc_addr(noc_x, noc_y, bank_addr_offset), noc_transfer_size_bytes}; +} + +static FORCE_INLINE size_t get_args_consumed_by_noc_transfer_info_in_burst() { return 3; } + +FORCE_INLINE static ttnn::ccl::cmd::noc_transfer_info advance_to_next_noc_transaction_in_burst( + noc_transfer_burst_context& noc_burst_ctx, arg_idx_t& arg_idx) { + const auto noc_transfer_info = get_next_noc_transfer_in_burst(arg_idx); + arg_idx += get_args_consumed_by_noc_transfer_info_in_burst(); + + noc_burst_ctx.current_noc_transfer++; + return noc_transfer_info; +} + +FORCE_INLINE static void try_advance_noc_read_burst( + noc_transfer_burst_context& noc_burst_ctx, uint32_t cb_id, uint32_t packet_size_in_pages, arg_idx_t& arg_idx) { + if (!cb_pages_reservable_at_back(cb_id, packet_size_in_pages)) { + return; + } + + auto wrptr = get_write_ptr(cb_id); + size_t num_transfers_in_group = get_arg_val(arg_idx++); + for (size_t i = 0; i < num_transfers_in_group; i++) { + auto transfer_info = advance_to_next_noc_transaction_in_burst(noc_burst_ctx, arg_idx); + + // Add the offset to the base address to resolve the full address + uint64_t src_noc_addr = noc_burst_ctx.bank_base_address + transfer_info.noc_addr; + + noc_async_read(src_noc_addr, wrptr, transfer_info.noc_transfer_size_bytes); + wrptr += transfer_info.noc_transfer_size_bytes; + } + ASSERT(noc_burst_ctx.current_noc_transfer <= noc_burst_ctx.num_transfers_total); + + noc_async_read_barrier(); + cb_push_back(cb_id, packet_size_in_pages); +} + +static void try_advance_noc_write_burst( + FabricConnectionManager& fabric_connection, + noc_transfer_burst_context& noc_burst_ctx, + uint32_t cb_id, + uint32_t packet_size_in_pages, + size_t packet_header_buffer_addr, + const ttnn::ccl::cmd::CclCommandHeader& current_cmd_header, + arg_idx_t& arg_idx) { + if (!cb_pages_available_at_front(cb_id, 
packet_size_in_pages)) { + return; + } + size_t cb_rdptr = get_read_ptr(cb_id); + size_t num_transfers_in_group = get_arg_val(arg_idx++); + for (size_t i = 0; i < num_transfers_in_group; i++) { + auto transfer_info = advance_to_next_noc_transaction_in_burst(noc_burst_ctx, arg_idx); + + // Add the offset to the base address to resolve the full address + uint64_t dest_noc_addr = noc_burst_ctx.bank_base_address + transfer_info.noc_addr; + // Imported from the reference kernel + write_payload_then_advance_read_address( + dest_noc_addr, + packet_header_buffer_addr, + current_cmd_header, + fabric_connection, + cb_rdptr, + transfer_info.noc_transfer_size_bytes); + } + noc_async_writes_flushed(); + + cb_pop_front(cb_id, packet_size_in_pages); +} + template FORCE_INLINE void try_advance(command_context_t& cmd_ctx) { switch (cmd_ctx.current_cmd_header.code) { @@ -838,6 +864,25 @@ FORCE_INLINE void try_advance(command_context_t& cmd_ctx) { #endif break; + case ttnn::ccl::cmd::CclCommandCode::NOC_READ_BURST: + try_advance_noc_read_burst( + cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx, + cmd_ctx.cb_id, + cmd_ctx.packet_size_in_pages, + cmd_ctx.arg_idx); + break; + + case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_BURST: + try_advance_noc_write_burst( + cmd_ctx.fabric_connection, + cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx, + cmd_ctx.cb_id, + cmd_ctx.packet_size_in_pages, + cmd_ctx.packet_header_buffer_addr, + cmd_ctx.current_cmd_header, + cmd_ctx.arg_idx); + break; + case ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC: [[fallthrough]]; case ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: try_advance_inline_write_or_atomic_inc(cmd_ctx); @@ -872,6 +917,15 @@ FORCE_INLINE void try_advance(command_context_t& cmd_ctx) { cmd_ctx.complete_current_command(); } break; + + case ttnn::ccl::cmd::CclCommandCode::NOC_READ_BURST: [[fallthrough]]; + case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_BURST: + if (cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx.current_noc_transfer == + cmd_ctx.cmd_specific_ctx.noc_transfer_burst_ctx.num_transfers_total) { + DPRINT << "noc_burst cmd cmpl\n"; + cmd_ctx.complete_current_command(); + } + break; default: ASSERT(false); break; }; } diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp index 63596906d20..e37f4d1945d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp @@ -14,6 +14,7 @@ #include "dataflow_api.h" // for interleaved addrgen #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/algorithms.hpp" using shape_t = ttnn::ccl::Shape4D; using ttnn::ccl::coord_t; @@ -66,16 +67,7 @@ FORCE_INLINE shape_t worker_wrapped_offset_to_coord(shape_t const& slice_shape, return shape_t(0, 0, y, worker_slice_offset.x - (y * slice_shape.x)); } -FORCE_INLINE std::size_t get_flat_index_from_shape(const Shape4D& shape, const Shape4D& index) { - std::size_t offset = index.x; - std::size_t inner_volume = shape.x; - offset += index.y * inner_volume; - inner_volume *= shape.y; - offset += index.z * inner_volume; - inner_volume *= shape.z; - offset += index.w * inner_volume; - return offset; -} + namespace v2 { /* diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp index 
aaf889d66be..86a41cf65b2 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp @@ -72,7 +72,7 @@ args_list_t emit_address_generator_compile_time_args(tt::tt_metal::Tensor const& TT_ASSERT(false); } -static std::pair shard_grid_from_shard_spec(const ShardSpec& shard_spec) { +std::pair shard_grid_from_shard_spec(const ShardSpec& shard_spec) { auto const& core_range = shard_spec.grid.bounding_box(); log_trace( tt::LogOp, diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp index af6e2558a4d..5cd6a2124d6 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once +#include "common/core_coord.hpp" #include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" #include @@ -11,6 +12,7 @@ namespace tt { namespace tt_metal { class Tensor; +class ShardSpec; inline namespace v0 { class Device; @@ -49,6 +51,8 @@ args_list_t emit_address_generator_runtime_args( tt::tt_metal::Device const* const d, tt::tt_metal::Tensor const& tensor); args_list_t emit_address_generator_compile_time_args(tt::tt_metal::Tensor const& tensor); +std::pair shard_grid_from_shard_spec(const tt::tt_metal::ShardSpec& shard_spec); + struct ShardedAddrGenArgBuilder { static bool shard_grid_is_transposed(tt::tt_metal::Tensor const& t); static std::vector emit_ct_args(tt::tt_metal::Tensor const& t); diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp index 525bb7e5e77..6ccfa7339b8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp @@ -4,13 +4,14 @@ #pragma once +#include + #include #include #include #include #include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" - // For command dest type #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" @@ -60,6 +61,32 @@ struct CclCommandAtomicInc { uint32_t value = 1; uint32_t wrap_value = std::numeric_limits::max(); }; + +struct noc_transfer_info { + // When encoded into a command, the noc address contains the relative offset of the + // read/write from the base address, which must be stored elsewhere by the command + // interpreter/command stream. 
The base address is specified in the command. + uint64_t noc_addr = 0; + size_t noc_transfer_size_bytes = 0; +}; + +struct HostNocTransferBurstGrouping { + size_t num_transfers_per_packet = 0; + std::vector transfer_infos; +}; +struct HostCclCommandNocTransferBurst { + size_t bank_base_address = 0; + uint32_t num_transfers_total = 0; + std::vector transfer_burst_groupings; +}; +struct DeviceCclCommandNocTransferBurst { + size_t bank_base_address = 0; + uint32_t num_transfers_total = 0; + + // Populated as the burst is being completed and command args are being decoded + uint8_t num_transfers_per_packet = 0; +}; + struct CclCommandInlineReadWrite { uint32_t value = 0; }; @@ -71,7 +98,8 @@ using CclCommandArgs = std::variant< CclCommandWaitValue, CclCommandAtomicInc, CclCommandInlineReadWrite, - CclCommandReadWrite>; + CclCommandReadWrite, + HostCclCommandNocTransferBurst>; enum SRC_DEST_TYPE : uint8_t { SRC = 0, DEST = 1 }; @@ -95,6 +123,16 @@ enum class CclCommandArgCode : uint8_t { // core descriptor commands SET_CORE_DESCRIPTOR_INFO = 9, + // Specifies how many noc transfers make up the + // noc transaction burst command + SET_NOC_TRANSFER_BURST_START_INFO = 10, + + // Specifies how many noc transfers are expected to be performed back to back + // and packed into a single (CB) packet (though conceivably we could pack this + // into a common ethernet packet too for better utilization, provided that the + // receiver does the proper unpacking) + SET_NOC_TRANSFER_BURST_SIZE_PER_PACKET = 11, + INVALID = std::numeric_limits::max(), }; @@ -171,6 +209,14 @@ template <> struct command_arg_field { using type = uint32_t; }; +template <> +struct command_arg_field { + using type = DeviceCclCommandNocTransferBurst; +}; +template <> +struct command_arg_field { + using type = uint8_t; +}; template struct CclCommandArg {}; @@ -456,10 +502,14 @@ enum class CclCommandCoreDescriptorType : uint8_t { ADDRGEN = 0, LOCAL = 1, NOC_XY = 2, - RECTANGLE = 3 + RECTANGLE = 3, + + // used for noc bursts since the core is embedded in the noc command + NONE = 4 // Future types may include: list, rectangle_list, etc. }; struct CclCommandCoreDescriptorTypeAddrgen {}; +struct CclCommandCoreDescriptorTypeNone {}; struct CclCommandCoreDescriptorTypeLocal {}; struct CclCommandCoreDescriptorTypeNocXY { uint8_t x; @@ -493,7 +543,8 @@ using CclCommandCoreDescriptorArgs = std::variant< CclCommandCoreDescriptorTypeAddrgen, CclCommandCoreDescriptorTypeLocal, CclCommandCoreDescriptorTypeNocXY, - CclCommandCoreDescriptorTypeMcast>; + CclCommandCoreDescriptorTypeMcast, + CclCommandCoreDescriptorTypeNone>; // A command is composed of one or more arguments // This enum specifies the high level command @@ -512,14 +563,18 @@ enum class CclCommandCode : uint8_t { RAW_INLINE_WRITE_BYTES = 5, - // Behaviour reading/writing to CBs still a little unclear - // This mode isn't actually supported yet - RAW_READ_BYTES = 6, - RAW_WRITE_BYTES = 7, + NOC_READ_BURST = 6, + NOC_WRITE_BURST = 7, + + // Waits on semaphore values before performing reads. 
Every read waits for the target semaphore + // value before reading + FLOW_CONTROLLED_NOC_READ_BURST = 8, + NOC_WRITE_AND_ATOMIC_INC = 9, - INVALID = 8 + INVALID = 10 }; + enum CclCommandDestType : uint8_t { CHIP_UNICAST = tt::fabric::CHIP_UNICAST, CHIP_MULTICAST = tt::fabric::CHIP_MULTICAST, @@ -542,7 +597,6 @@ using LocalOnlyCommandDestArgs = DestTypeArgsNull; // Used only for host code paths using CclCommandDestArgs = std::variant; -namespace v2 {}; struct CclCommandHeader { CclCommandCode code : 6; @@ -559,8 +613,9 @@ struct CclCommandHeader { LocalOnlyCommandDestArgs local_only; } command_dest_args; - CclCommandHeader() : code(CclCommandCode::INVALID), dest_type(CclCommandDestType::CHIP_LOCAL_ONLY), arg_count(0) {} - CclCommandHeader(CclCommandCode code, CclCommandDestArgs const& args, uint8_t arg_count) : + CclCommandHeader() : + code(CclCommandCode::INVALID), dest_type(CclCommandDestType::CHIP_LOCAL_ONLY), arg_count(0) {} + CclCommandHeader(CclCommandCode code, const CclCommandDestArgs& args, uint8_t arg_count) : code(code), arg_count(arg_count) { if (std::holds_alternative(args)) { command_dest_args.unicast = std::get(args); @@ -573,36 +628,22 @@ struct CclCommandHeader { this->dest_type = CclCommandDestType::CHIP_LOCAL_ONLY; } } - CclCommandHeader(CclCommandCode code, MulticastCommandDestArgs const& multicast_args, uint8_t arg_count) : + CclCommandHeader(CclCommandCode code, const MulticastCommandDestArgs& multicast_args, uint8_t arg_count) : code(code), dest_type(CclCommandDestType::CHIP_MULTICAST), arg_count(arg_count) { this->command_dest_args.multicast = multicast_args; } - CclCommandHeader(CclCommandCode code, LocalOnlyCommandDestArgs const& local_only_args, uint8_t arg_count) : + CclCommandHeader(CclCommandCode code, const LocalOnlyCommandDestArgs& local_only_args, uint8_t arg_count) : code(code), dest_type(CclCommandDestType::CHIP_LOCAL_ONLY), arg_count(arg_count) { this->command_dest_args.local_only = local_only_args; } - static CclCommandHeader from_uint32(uint32_t cmd_header) { + static CclCommandHeader from_uint32_impl(uint32_t cmd_header) { CclCommandHeader decoded; reinterpret_cast(&decoded)[0] = cmd_header; return decoded; - // decoded.code = static_cast(cmd_header & 0xFF); - // decoded.dest_type = static_cast((cmd_header >> 6) & 0x3); - // switch (decoded.dest_type) { - // case CclCommandDestType::CHIP_UNICAST: - // decoded.command_dest_args.unicast = UnicastCommandDestArgs{static_cast((cmd_header >> 16) & - // 0xFF), static_cast((cmd_header >> 24) & 0x1)}; break; - // case CclCommandDestType::CHIP_MULTICAST: - // decoded.command_dest_args.multicast = MulticastCommandDestArgs{static_cast((cmd_header >> - // 16) & 0xFF), static_cast((cmd_header >> 24) & 0xFF)}; break; - // default: - // break; - // } - // decoded.arg_count = (cmd_header >> 8) & 0xF; - // return decoded; } - static uint32_t to_uint32(CclCommandHeader const& cmd_header) { + static uint32_t to_uint32(const CclCommandHeader& cmd_header) { uint32_t encoded = 0; encoded = (uint32_t)(cmd_header.code); encoded |= (cmd_header.dest_type << 6); @@ -621,11 +662,15 @@ struct CclCommandHeader { return encoded; } uint32_t to_uint32() const { return to_uint32(*this); } + static CclCommandHeader from_uint32(uint32_t cmd_header) { + return CclCommandHeader::from_uint32_impl(cmd_header); + } - UnicastCommandDestArgs const& get_unicast_dest_args() const { return command_dest_args.unicast; } - MulticastCommandDestArgs const& get_multicast_dest_args() const { return command_dest_args.multicast; } - 
LocalOnlyCommandDestArgs const& get_local_only_dest_args() const { return command_dest_args.local_only; } + const UnicastCommandDestArgs& get_unicast_dest_args() const { return command_dest_args.unicast; } + const MulticastCommandDestArgs& get_multicast_dest_args() const { return command_dest_args.multicast; } + const LocalOnlyCommandDestArgs& get_local_only_dest_args() const { return command_dest_args.local_only; } }; + static_assert(sizeof(CclCommandHeader) == sizeof(uint32_t)); } // namespace cmd diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp index af3de7d0e46..b7e62b9a9ab 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp @@ -376,6 +376,117 @@ CclHostLowLevelWorkerCommand fabric_unicast_absolute_address_semaphore_inc( ttnn::ccl::cmd::UnicastCommandDestArgs(unicast_args)); } + +// Noc Read/Write commands +// Densely packs as many transfers as possible into a single packet +static std::vector densely_pack_noc_transfers(tt::stl::Span const& transfer_infos, size_t cb_size_bytes) { + std::vector transfer_burst_groupings; + + size_t group_size_bytes = 0; + transfer_burst_groupings.push_back({}); + for (size_t i = 0; i < transfer_infos.size(); i++) { + group_size_bytes += transfer_infos[i].noc_transfer_size_bytes; + bool create_new_group = group_size_bytes >= cb_size_bytes; + if (create_new_group) { + transfer_burst_groupings.push_back({}); + group_size_bytes = 0; + } + + auto &group = transfer_burst_groupings.back(); + bool is_32B_aligned = (group_size_bytes & 0x1F) == 0; + if (!is_32B_aligned) { + group_size_bytes += 0x20 - (group_size_bytes & 0x1F); + } + + group.num_transfers_per_packet++; + group.transfer_infos.push_back(transfer_infos[i]); + } + + return transfer_burst_groupings; +} + +CclHostLowLevelWorkerCommand local_noc_read_burst_to_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id +) { + auto transfer_burst_groupings = densely_pack_noc_transfers(transfer_infos, cb_size_bytes); + + return CclHostLowLevelWorkerCommand( + CclCommandCode::NOC_READ_BURST, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::HostCclCommandNocTransferBurst{bank_base_address.absolute_address, transfer_infos.size(), transfer_burst_groupings}), + ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS, + ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{bank_base_address}, + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id} + ); +} + +CclHostLowLevelWorkerCommand local_noc_write_burst_from_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id +) { + auto transfer_burst_groupings = densely_pack_noc_transfers(transfer_infos, cb_size_bytes); + + return CclHostLowLevelWorkerCommand( + CclCommandCode::NOC_WRITE_BURST, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::HostCclCommandNocTransferBurst{bank_base_address.absolute_address, transfer_infos.size(), transfer_burst_groupings}), + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS, + ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{bank_base_address} + ); +} + +[[nodiscard]] CclHostLowLevelWorkerCommand 
fabric_unicast_noc_write_burst_from_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id, + UnicastCommandDestArgs const& unicast_args +) { + auto transfer_burst_groupings = densely_pack_noc_transfers(transfer_infos, cb_size_bytes); + + return CclHostLowLevelWorkerCommand( + CclCommandCode::NOC_WRITE_BURST, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::HostCclCommandNocTransferBurst{bank_base_address.absolute_address, transfer_infos.size(), transfer_burst_groupings}), + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS, + ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{bank_base_address}, + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NONE, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNone(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST, + ttnn::ccl::cmd::UnicastCommandDestArgs(unicast_args) + ); +} + +CclHostLowLevelWorkerCommand fabric_multicast_noc_write_burst_from_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id, + MulticastCommandDestArgs const& multicast_args +) { + auto transfer_burst_groupings = densely_pack_noc_transfers(transfer_infos, cb_size_bytes); + + return CclHostLowLevelWorkerCommand( + CclCommandCode::NOC_WRITE_BURST, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::HostCclCommandNocTransferBurst{bank_base_address.absolute_address, transfer_infos.size(), transfer_burst_groupings}), + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS, + ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{bank_base_address}, + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NONE, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNone(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST, + ttnn::ccl::cmd::MulticastCommandDestArgs(multicast_args) + ); +} + } // namespace uops } // namespace ttnn::ccl::cmd diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp index 8e8c22ea8b5..ce92d8a2485 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp @@ -84,5 +84,38 @@ using semaphore_id_t = std::variant const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id +); + +[[nodiscard]] CclHostLowLevelWorkerCommand local_noc_write_burst_from_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id +); + +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_unicast_noc_write_burst_from_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id, + UnicastCommandDestArgs const& unicast_args +); + +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_multicast_noc_write_burst_from_cb( + CclCommandAddrAbsoluteAddress const& bank_base_address, + tt::stl::Span const& transfer_infos, + size_t cb_size_bytes, + size_t cb_id, + MulticastCommandDestArgs const& multicast_args +); + + }; // namespace uops }; // namespace ttnn::ccl::cmd diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.cpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.cpp new file 
mode 100644 index 00000000000..3c511a6b46d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.cpp @@ -0,0 +1,224 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/algorithms.hpp" +#include "ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" + +namespace ttnn::ccl { + +// For page-aligned reads - never split pages across packets +template +void generate_noc_transfer_burst_for_tensor_slice( + const ttnn::ccl::v2::TensorSlice& tensor_slice, + ttnn::ccl::cmd::HostCclCommandNocTransferBurst& noc_transfer_burst_out, + const AddressGenerator& address_generator, + size_t page_size, + size_t packet_size_bytes) { + TT_FATAL(page_size > 0, "Internal error: page size is 0"); + + size_t packet_space_in_bytes_left = packet_size_bytes; + noc_transfer_burst_out.transfer_burst_groupings.push_back({}); + bool closed_out_last_group = false; + for (size_t w = 0; w < tensor_slice.tensor_slice_shape.w; w++) { + for (size_t z = 0; z < tensor_slice.tensor_slice_shape.z; z++) { + for (size_t y = 0; y < tensor_slice.tensor_slice_shape.y; y++) { + size_t pages_read = 0; + for (size_t x = 0; x < tensor_slice.tensor_slice_shape.x; x += pages_read) { + closed_out_last_group = false; + auto offset = ttnn::ccl::Shape4D{w, z, y, x} + tensor_slice.tensor_slice_offset; + auto& transfer_burst_grouping = noc_transfer_burst_out.transfer_burst_groupings.back(); + const size_t curr_page_idx = get_flat_index_from_shape(tensor_slice.tensor_shape, offset); + const auto& [noc_yx, page_index_into_shard, contig_pages_] = + address_generator.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); + pages_read = std::min( + {tensor_slice.tensor_slice_shape.x - x, packet_space_in_bytes_left / page_size, contig_pages_}); + size_t transfer_size_in_bytes = pages_read * page_size; + + TT_FATAL(pages_read > 0, "Internal error: hit infinite loop indicating a logical error"); + noc_transfer_burst_out.num_transfers_total++; + transfer_burst_grouping.num_transfers_per_packet++; + packet_space_in_bytes_left -= transfer_size_in_bytes; + auto byte_offset_in_shard = page_index_into_shard * page_size; + uint64_t noc_addr_offset = (static_cast(noc_yx.noc_y) << 48) | + (static_cast(noc_yx.noc_x) << 32) | + static_cast(byte_offset_in_shard); + transfer_burst_grouping.transfer_infos.push_back( + ttnn::ccl::cmd::noc_transfer_info{noc_addr_offset, transfer_size_in_bytes}); + + if (packet_space_in_bytes_left < page_size) { + closed_out_last_group = true; + packet_space_in_bytes_left = packet_size_bytes; + bool last_w = w == tensor_slice.tensor_slice_shape.w - 1; + bool last_z = z == tensor_slice.tensor_slice_shape.z - 1; + bool last_y = y == tensor_slice.tensor_slice_shape.y - 1; + bool last_x = x + pages_read == tensor_slice.tensor_slice_shape.x; + if (!(last_w && last_z && last_y && last_x)) { + noc_transfer_burst_out.transfer_burst_groupings.push_back({}); + } + } + } + } + } + } +} + +void validate_lowered_noc_commands(const ttnn::ccl::cmd::HostCclCommandNocTransferBurst& noc_transfer_burst) { + TT_FATAL(noc_transfer_burst.transfer_burst_groupings.size() > 0, "Internal error: No transfer burst groupings"); + 
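// Each grouping must contain at least one transfer and each transfer must move a nonzero number of bytes +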
for (auto& transfer_burst_grouping : noc_transfer_burst.transfer_burst_groupings) { + TT_FATAL(transfer_burst_grouping.num_transfers_per_packet > 0, "Internal error: No transfers per packet"); + for (auto& transfer_info : transfer_burst_grouping.transfer_infos) { + TT_FATAL(transfer_info.noc_transfer_size_bytes > 0, "Internal error: No transfer size bytes"); + } + } +} + +ttnn::ccl::cmd::CclHostLowLevelWorkerCommand lower_tensor_slice_command_to_noc_commands( + const ttnn::ccl::cmd::CclHostLowLevelWorkerCommand& command, + const tt::tt_metal::Tensor& tensor, + size_t packet_size_bytes) { + using namespace tt::tt_metal::address_generators; + using namespace tt::tt_metal; + + TT_FATAL(tensor.is_sharded(), "Only tensor slices for sharded tensors are able to be lowered to noc reads/writes"); + + ttnn::ccl::cmd::HostCclCommandNocTransferBurst noc_transfer_burst; + noc_transfer_burst.bank_base_address = tensor.buffer()->address(); + + const auto& tensor_slice = std::get(command.command_args); + auto page_size = tensor.buffer()->page_size(); + + auto coord_lookup = tt::tt_metal::address_generators::VirtualCoordWormholeWorkerToNocLookup(); + + const auto& [pages_per_shard_y, pages_per_shard_x] = tensor.buffer()->shard_spec().shape_in_pages(); + const auto& [shard_grid_start, shard_grid_end] = ttnn::ccl::shard_grid_from_shard_spec(tensor.shard_spec().value()); + const bool shard_grid_transposed = ttnn::ccl::ShardedAddrGenArgBuilder::shard_grid_is_transposed(tensor); + // shard_grid_height (cores) + const size_t shard_grid_height = shard_grid_end.y - shard_grid_start.y + 1; + TT_FATAL( + shard_grid_height > 0, "Internal error. Computed shard_grid height == 0 to sharded addrgen, which is invalid"); + // shard_grid_width (cores) + const size_t shard_grid_width = shard_grid_end.x - shard_grid_start.x + 1; + TT_FATAL( + shard_grid_width > 0, "Internal error. 
Computed shard_grid width == 0 to sharded addrgen, which is invalid"); + // Only page aligned for now since tensor slice is page based at the moment. + // Future work is to migrate tensor slice to be element based; at that + // point we can relax the page-alignment restriction here + switch (tensor.buffer()->buffer_layout()) { + case tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED: { + auto address_generator = build_sharded_addr_gen( + coord_lookup, + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + shard_grid_start.y, + shard_grid_start.x, + shard_grid_transposed), + noc_transfer_burst.bank_base_address, + page_size); + generate_noc_transfer_burst_for_tensor_slice( + tensor_slice, noc_transfer_burst, address_generator, page_size, packet_size_bytes); + break; + } + case tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED: { + auto address_generator = build_sharded_addr_gen( + coord_lookup, + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + shard_grid_start.y, + shard_grid_start.x, + shard_grid_transposed), + noc_transfer_burst.bank_base_address, + page_size); + generate_noc_transfer_burst_for_tensor_slice( + tensor_slice, noc_transfer_burst, address_generator, page_size, packet_size_bytes); + break; + } + case tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED: { + auto address_generator = build_sharded_addr_gen( + coord_lookup, + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + shard_grid_start.y, + shard_grid_start.x, + shard_grid_transposed), + noc_transfer_burst.bank_base_address, + page_size); + generate_noc_transfer_burst_for_tensor_slice( + tensor_slice, noc_transfer_burst, address_generator, page_size, packet_size_bytes); + break; + } + default: TT_FATAL(false, "Unsupported buffer layout"); + } + + validate_lowered_noc_commands(noc_transfer_burst); + + std::stringstream ss; + ss << "Lowering noc commands: \n"; + ss << fmt::format( + "Base_addr: {}, burst_size: {}", + noc_transfer_burst.bank_base_address, + noc_transfer_burst.num_transfers_total) + << "\n"; + for (auto& transfer : noc_transfer_burst.transfer_burst_groupings) { + ss << fmt::format("\tGroup_size: {}", transfer.num_transfers_per_packet) << "\n"; + for (auto& transfer_info : transfer.transfer_infos) { + ss << fmt::format("\t\taddr: {}, size: {}", transfer_info.noc_addr, transfer_info.noc_transfer_size_bytes) + << "\n"; + } + } + log_trace(tt::LogOp, "{}", ss.str()); + + // Generate the new (lowered to noc read/write) command + ttnn::ccl::cmd::CclHostLowLevelWorkerCommand lowered_command = command; + switch (command.command_code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: + lowered_command.dest_addr_type = ttnn::ccl::cmd::CclCommandAddrType::NONE; + lowered_command.dest_addr_args = ttnn::ccl::cmd::CclCommandAddrArgs(); + lowered_command.command_code = ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_BURST; + lowered_command.command_args = ttnn::ccl::cmd::HostCclCommandNocTransferBurst{noc_transfer_burst}; + break; + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: + lowered_command.source_addr_type = ttnn::ccl::cmd::CclCommandAddrType::NONE; + lowered_command.source_addr_args = ttnn::ccl::cmd::CclCommandAddrArgs(); + lowered_command.command_code = ttnn::ccl::cmd::CclCommandCode::NOC_READ_BURST; + lowered_command.command_args = 
ttnn::ccl::cmd::HostCclCommandNocTransferBurst{noc_transfer_burst}; + break; + default: TT_FATAL(false, "Only STREAM_CB_TO_TENSOR and STREAM_TENSOR_TO_CB commands are supported"); + } + + return lowered_command; +} + +std::vector tensor_slice_commands_to_noc_commands( + const std::vector& command_stream, + const tt::tt_metal::Tensor& tensor, + size_t packet_size_bytes) { + std::vector lowered_command_stream; + for (auto& command : command_stream) { + switch (command.command_code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: [[fallthrough]]; + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: + lowered_command_stream.push_back( + lower_tensor_slice_command_to_noc_commands(command, tensor, packet_size_bytes)); + break; + + default: lowered_command_stream.push_back(command); break; + } + } + return lowered_command_stream; +} + +} // namespace ttnn::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp new file mode 100644 index 00000000000..30bbc152647 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" + +#include + +namespace tt::tt_metal { +class Tensor; +} + +namespace ttnn::ccl { + +struct tensor_command_map; +std::vector tensor_slice_commands_to_noc_commands( + const std::vector& command_stream, + const tt::tt_metal::Tensor& tensor, + size_t packet_size_bytes); +} // namespace ttnn::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp index a51a6eff900..9df091f3ab3 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp @@ -47,17 +47,20 @@ struct VirtualCoordWormholeWorkerToNocLookup : address_generators::WorkerToNocCoordLookup { VirtualCoordWormholeWorkerToNocLookup() : address_generators::WorkerToNocCoordLookup() {} noc_grid_index_t get_noc_x_from_worker_x(noc_grid_index_t worker_x) const { - return worker_x #if defined(KERNEL_BUILD) - + VIRTUAL_TENSIX_START_X + return worker_x + VIRTUAL_TENSIX_START_X; + #else + constexpr noc_grid_index_t HOST_PLACEHOLDER_VIRTUAL_TENSIX_START_X = 18; + return worker_x + HOST_PLACEHOLDER_VIRTUAL_TENSIX_START_X; #endif - ; } noc_grid_index_t get_noc_y_from_worker_y(noc_grid_index_t worker_y) const { - return worker_y #if defined(KERNEL_BUILD) - + VIRTUAL_TENSIX_START_Y + return worker_y + VIRTUAL_TENSIX_START_Y + #else + constexpr noc_grid_index_t HOST_PLACEHOLDER_VIRTUAL_TENSIX_START_Y = 18; + return worker_y + HOST_PLACEHOLDER_VIRTUAL_TENSIX_START_Y; #endif ; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp index 495fc283cfd..1464bb66a70 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp @@ -19,12 +19,13 @@ #include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" #include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" 
+#include "ttnn/cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp" #include #include #include - -#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" - #include using namespace tt::constants; @@ -116,6 +117,20 @@ std::tuple> choose_worker_cores( return {sender_worker_core_range, corerange_to_cores(sender_worker_core_range, std::nullopt, true)}; } +static bool can_command_stream_be_lowered_to_noc_commands(const Tensor& input_tensor) { + static constexpr size_t baseline_arg_count = 12; + // approximately... this is only very rough estimate until unlimited command stream length is enabled + static constexpr size_t args_per_noc_command = 4; + static constexpr size_t max_noc_commands = 256; + size_t page_num_elements = + input_tensor.layout() == Layout::TILE ? TILE_HEIGHT * TILE_WIDTH : input_tensor.padded_shape()[-1]; + size_t num_tensor_pages = input_tensor.padded_shape().volume() / page_num_elements; + + // Interleaved tensors are currently not iterable on host so we can't resolve the page locations + return input_tensor.is_sharded() && + (num_tensor_pages * args_per_noc_command + baseline_arg_count < max_noc_commands); +} + // For ring all-gather, we can send sub-sections of input tensor in opposite directions // For linear all-gather though, we must ensure we send full tensors in BOTH directions // (in other words, disable the "bidirectional" send flag) @@ -133,6 +148,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( bool enable_persistent_fabric_mode) { tt::tt_metal::Program program{}; const bool enable_async_output_tensor = false; + const bool lower_command_stream_to_noc_commands = can_command_stream_be_lowered_to_noc_commands(input_tensor); TT_FATAL(semaphore_opt.has_value(), "Semaphore is required for compile time"); @@ -254,6 +270,13 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( log_trace(tt::LogOp, "reader_tensor_slices[0] size: {}", reader_tensor_slices[0].size()); CoreCoord drain_sync_core; + // For now these are a little disconnected from the commands - they'll need to be unified and explicitly + // associated with each other but this is for bootstrapping the feature + constexpr size_t reader_tensor_command_map_idx = 0; + constexpr size_t writer_tensor_command_map_idx = 1; + std::unordered_map reader_rt_args_overrider_map; + std::unordered_map writer_rt_args_overrider_map; + for (std::size_t link = 0; link < num_links; link++) { CoreCoord core = {num_workers_per_link - 1, link}; if (link == 0) { @@ -300,6 +323,10 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( ttnn::ccl::cmd::uops::read_tensor_slice_to_cb_for_eventual_fabric_write( input_worker_slice_v2, src0_cb_index)); + if (lower_command_stream_to_noc_commands) { + reader_cmd_stream = + ttnn::ccl::tensor_slice_commands_to_noc_commands(reader_cmd_stream, input_tensor, packet_size_bytes); + } ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( program, worker_sender_reader_kernel_id, @@ -309,9 +336,12 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( num_pages_per_packet, {core}, reader_cmd_stream, - std::nullopt, - std::nullopt, - std::nullopt); + std::nullopt, // cmd stream 1 + std::nullopt, // fabric fwd connection + std::nullopt, // fabric bwd connection + std::nullopt, // tensor device 
override + std::vector{reader_tensor_command_map_idx}, // tensor indices + &reader_rt_args_overrider_map[core]); // WRITER COMMAND STREAM and RT ARGS std::vector writer_cmd_stream; @@ -353,6 +383,11 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_core_semaphore_set(&semaphore, 0)); } + if (lower_command_stream_to_noc_commands) { + writer_cmd_stream = + ttnn::ccl::tensor_slice_commands_to_noc_commands(writer_cmd_stream, output_tensor, packet_size_bytes); + } + // set the rt args ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( program, @@ -365,7 +400,10 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( writer_cmd_stream, std::nullopt, {forward_fabric_connection}, - {backward_fabric_connection}); + {backward_fabric_connection}, + std::nullopt, + std::vector{writer_tensor_command_map_idx}, // tensor indices + &writer_rt_args_overrider_map[core]); } if (!enable_persistent_fabric_mode) { @@ -373,7 +411,14 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( } auto override_runtime_arguments_callback = - [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore, sender_worker_cores]( + [worker_sender_reader_kernel_id, + reader_rt_args_overrider_map, + writer_rt_args_overrider_map, + reader_tensor_command_map_idx, + writer_tensor_command_map_idx, + worker_sender_writer_kernel_id, + semaphore, + sender_worker_cores]( const void* operation, Program& program, const std::vector& input_tensors, @@ -388,10 +433,12 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( for (const auto& core : sender_worker_cores) { // reader auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; - worker_reader_sender_runtime_args.at(0) = input.buffer()->address(); + reader_rt_args_overrider_map.at(core).override_runtime_args( + reader_tensor_command_map_idx, input.buffer()->address(), worker_reader_sender_runtime_args); // writer auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; - worker_writer_sender_runtime_args.at(0) = output.buffer()->address(); + writer_rt_args_overrider_map.at(core).override_runtime_args( + writer_tensor_command_map_idx, output.buffer()->address(), worker_writer_sender_runtime_args); } }; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp index ff9b052a315..b024e8fb8fc 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp @@ -1839,6 +1839,7 @@ static void log_worker_command_streams(WorkerCommandStreams const& command_strea [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeLocal const& core) { return fmt::format("local"); }, [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY const& core) { return fmt::format("(x={},y={})", core.x, core.y); }, [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast const& core) { return fmt::format("mcast"); }, + [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNone const& core) { return fmt::format("NONE"); }, }, core); }; @@ -1878,10 +1879,14 @@ static void log_worker_command_streams(WorkerCommandStreams const& command_strea case 
ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: return "RAW_INL_WR"; - case ttnn::ccl::cmd::CclCommandCode::RAW_READ_BYTES: - return "RAW_RD"; - case ttnn::ccl::cmd::CclCommandCode::RAW_WRITE_BYTES: - return "RAW_WR"; + case ttnn::ccl::cmd::CclCommandCode::NOC_READ_BURST: + return "NOC_RD_BURST"; + case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_BURST: + return "NOC_WR_BURST"; + case ttnn::ccl::cmd::CclCommandCode::FLOW_CONTROLLED_NOC_READ_BURST: + return "NOC_RD_BURST_FC"; + case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_AND_ATOMIC_INC: + return "NOC_WR_AND_AT_INC"; case ttnn::ccl::cmd::CclCommandCode::STREAM_EDM_TO_TENSOR: TT_THROW("Got an unsupported command in a command stream (STREAM_EDM_TO_TENSOR). This command is deprecated and unsupported by this infrastructure. This will lead to undefined and invalid behaviour");