From 4f5f417d67e145cb64f55aa51f18fdc95e21cb29 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Sat, 21 Dec 2024 13:58:47 -0500 Subject: [PATCH] Initial CCL V2 infra push - add cmd interpreter and reduce scatter, all-gather Without going to deep into the weeds, there were numerous reasons why CCLs needed to be fundamentally rewritten but to summarize some of the reasons: - Writing CCLs was not scalable from a development effort - Even within a single op (e.g. all-gather) we need to be able to support many topologies (ring, line, mesh, tree, tree of mesh, etc.) and use cases (BW bound, latency bound, high reliability vs lower reliability with potentially better perf) - CCLs need to be able to be fused with just about any ops without it being a Herculean effort - New concepts like "async tensor" need to be supported to account for performance artifacts like (dispatch) skew between chips and to effectively hide latency of various operations - (minor) support the new fabric projects with CCLs ### Initial test coverage - Gtests that provide basic coverage for the CCL Command interpreter running on the transitionary EDM fabric (both in persistent and non-persistent modes) - Gtests for reduce scatter and all-gather also added - Basic all gather pytests Future work will expand test coverage ### What's changed Lots to discuss here: - What's the command interpreter - How's it work - How do we build ops with it - What's new with these CCLs? The bulk of this information is or will be included in a much larger doc that will be circulated more widely in the coming weeks so a summary is provided below (if you want more details before the doc is provided, ask and I will point you to what's in progress): A new "command interpreter" kernel is provided which executes various different command types. Some commands map nearly directly to the low level noc API but others map to higher level operations. High Level Operation Example: - Stream Tensor Slice (from: CB/addrgen) (to:raw addr, CB, (fabric) addrgen) Low Level Command: - Wait for semaphore value - Send semaphore update - Raw Read/Write These commands are specifiable on host and there is a whole optimization story for performance but to provide the general idea, here's the primary functional code needed for all-gather as an example (code reorganized for purpose of PR example - not 1:1 to `all_gather_async_program.cpp`: ``` // Create a "reader kernel" command stream std::vector reader_cmd_stream; reader_cmd_stream.push_back(ttnn::ccl::cmd::uops::read_tensor_slice_to_cb(input_worker_slice_v2, src0_cb_index)); // Create a "writer kernel" command stream std::vector writer_cmd_stream; // 1, do mcast of the tensor slice to all the destinations writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice( output_worker_slice_v2, src0_cb_index, mcast_dest_args)); // Really, for all-gather, that's basically it - the rest of the code is to choose core placement and get info - l // ike which core(s) are fabric endpoints to connect to fabric, etc.) // Now pass the commands to the kernel(s) ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( program, worker_sender_reader_kernel_id, ..., reader_cmd_stream, std::nullopt, std::nullopt, std::nullopt); ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( program, worker_sender_writer_kernel_id, ..., writer_cmd_stream, std::nullopt, {forward_fabric_connection}, {backward_fabric_connection}); ``` With the above, operations such as fusion become far simpler (in some cases, trivial). For example, in the case of fusing an all-reduce with split-qkv heads operation for example (note that the output side of all-reduce is basically all-gather in an optimized ring implementation), the basic fusion operation is first identifying the split/slice boundaries of split-qkv (this could potentially be obtained from the op directly) and propagating those cut lines to all of the tensor slices of the producer (like the tensor slices in the commands shown above) and simply splitting those slices and setting the correct output tensors for each accordingly. Note that many commands can be added to each given command stream - all-gather is just very simple. Reduce scatter is an example of one that is more complicated. ### Expanding to other operations: Here are some simple examples #### Send/receive - Take the all-gather as example, and rather than specifying an mcast on the tensor write command: ``` writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice( output_worker_slice_v2, src0_cb_index, mcast_dest_args)) ``` you would unicast it to the desired destination (replace `mcast_dest_args`) If running in synchronous tensor mode, add a command interpreter kernel at the destination chip with a wait_val command to wait on a sem inc. Append a seminc to the sender command stream #### Broadcast Invoke all-gather above but just from one chip. If running in synchronous tensor mode, add a command interpreter kernel at all the destination chips with a wait_val command to wait on a sem inc. Append a fabric multicast seminc to the sender command stream. #### Reduce - Build a tree on the cluster - Each producer chip unicast sends to the next node towards the root of the tree, send sync signal to downstream - if not a leaf, perform partial reduction with your received data and your local data and forward to the next node toward the root - Add a wait val before accepting your input data - Root node can do any number of reductions to reduce the incoming data streams (ensuring to first sync on any input stream before consuming We do something similar to the above for reduce scatter ### Note on APIs These APIs are expected to be refined over time. In the mean-time, I have introduces the named "micro-ops" as commands to grant us some flexibilitiy in changing underlying command encodings (both on host and device). This will let us optimize and improve the "IR" over time without requiring constant op implementation updates. --------- Co-authored-by: Jack Cai --- tests/ttnn/unit_tests/gtests/CMakeLists.txt | 1 + ...c_erisc_datamover_sender_worker_sender.cpp | 57 +- .../fabric_worker_sender_multi_input.cpp | 218 ++ .../ccl/kernels/test_kernels.common.hpp | 46 + .../test_ccl_reduce_scatter_host_helpers.cpp | 4 +- .../gtests/ccl/test_ccl_tensor_slicers.cpp | 780 ++--- ...erisc_data_mover_loopback_with_workers.cpp | 2573 ++++++++++++++++- .../ccl/test_sharded_address_generators.cpp | 641 ++++ .../operations/ccl/test_new_all_gather.py | 601 ++++ .../ccl/test_reduce_scatter_async.py | 342 +++ ttnn/cpp/pybind11/operations/__init__.hpp | 8 +- ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt | 4 + .../ccl/all_gather/all_gather_pybind.cpp | 5 - .../dataflow/worker_ring_gather_utils.hpp | 127 +- ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp | 766 ++++- ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp | 171 ++ ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp | 33 + ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp | 16 + .../host/ccl_command_stream_builders.cpp | 181 ++ .../host/ccl_command_stream_builders.hpp | 40 + .../ccl/common/host/ccl_worker_builder.cpp | 1570 ++++++++++ .../ccl/common/host/ccl_worker_builder.hpp | 172 ++ .../ccl/common/kernels/ccl_send_reader.cpp | 203 ++ .../kernels/ccl_send_reader_two_input.cpp | 1012 +++++++ .../ccl/common/kernels/ccl_send_utils.hpp | 330 +++ .../ccl/common/kernels/ccl_send_writer.cpp | 274 ++ .../common/kernels/ccl_wait_completion.cpp | 67 + .../ccl/common/kernels/command_processor.hpp | 136 + .../operations/ccl/common/types/ccl_types.hpp | 7 + .../common/types/ccl_types_args_emitters.cpp | 4 + .../ccl/common/types/ccl_types_device.hpp | 1 + .../ccl/common/uops/ccl_command.hpp | 459 ++- .../ccl/common/uops/ccl_command_device.hpp | 14 +- .../ccl/common/uops/ccl_host_commands.cpp | 381 +++ .../ccl/common/uops/ccl_host_commands.hpp | 88 + .../ccl/erisc_datamover_builder.cpp | 495 +++- .../ccl/erisc_datamover_builder.hpp | 143 +- .../ccl/kernel_common/worker_edm_utils.hpp | 6 +- .../ccl/kernels/edm/edm_handshake.hpp | 15 +- .../edm_fabric/edm_fabric_worker_adapters.hpp | 227 +- .../edm_fabric/fabric_edm_packet_header.hpp | 63 +- .../fabric_edm_packet_transmission.hpp | 61 +- .../edm_fabric/fabric_erisc_datamover.cpp | 133 +- .../fabric_erisc_datamover_channels.hpp | 1 + .../host/reduce_scatter_worker_builder.cpp | 7 +- .../host/reduce_scatter_worker_builder.hpp | 1 - .../reduce_scatter/reduce_scatter_pybind.cpp | 6 +- .../hetergeneous_data_structs.hpp | 146 + .../sharded_tensor_addr_gen.hpp | 58 +- .../experimental/ccl/CMakeLists.txt | 9 + .../ccl/all_gather_async/all_gather_async.cpp | 55 + .../ccl/all_gather_async/all_gather_async.hpp | 46 + .../all_gather_async_pybind.cpp | 139 + .../all_gather_async_pybind.hpp | 13 + .../device/all_gather_async_op.cpp | 334 +++ .../device/all_gather_async_op.hpp | 144 + .../device/all_gather_async_program.cpp | 410 +++ .../ccl/ccl_experimental_pybind.cpp | 22 + .../ccl/ccl_experimental_pybind.hpp | 16 + .../device/reduce_scatter_async_op.cpp | 374 +++ .../device/reduce_scatter_async_op.hpp | 150 + .../device/reduce_scatter_async_program.cpp | 2172 ++++++++++++++ .../reduce_scatter_async/reduce_scatter.cpp | 32 + .../reduce_scatter_async/reduce_scatter.hpp | 38 + .../reduce_scatter_pybind.cpp | 103 + .../reduce_scatter_pybind.hpp | 13 + .../experimental/experimental_pybind.cpp | 9 +- ttnn/ttnn/__init__.py | 2 + ttnn/ttnn/operations/ccl.py | 2 + 69 files changed, 15490 insertions(+), 1287 deletions(-) create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_worker_sender_multi_input.cpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators.cpp create mode 100644 tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py create mode 100644 tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py create mode 100644 ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_wait_completion.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.hpp diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index 573e382348b..f104ce6d7fe 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -16,6 +16,7 @@ set(TTNN_CCL_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_commands.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_helpers.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_tensor_slicers.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_sharded_address_generators.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_reduce_scatter_host_helpers.cpp ) diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp index 4d60b4243f7..39380695040 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp @@ -7,6 +7,7 @@ #include "dataflow_api.h" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp" struct unicast_mode { uint8_t distance; @@ -53,8 +54,7 @@ void kernel_main() { const uint32_t eth_sender_noc_x = get_arg_val(arg_idx++); const uint32_t eth_sender_noc_y = get_arg_val(arg_idx++); const uint32_t num_buffers_per_edm_channel = get_arg_val(arg_idx++); - size_t edm_connection_handshake_addr = - get_semaphore(get_arg_val(arg_idx++)); + size_t edm_connection_handshake_id = get_arg_val(arg_idx++); size_t edm_worker_location_info_addr = get_arg_val(arg_idx++); size_t edm_buffer_size_bytes = get_arg_val(arg_idx++); size_t dest_addr = get_arg_val(arg_idx++); @@ -62,10 +62,12 @@ void kernel_main() { reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); *last_message_semaphore_address = 0; auto worker_buffer_index_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); + bool connected_to_persistent_fabric = get_arg_val(arg_idx++) != 0; + // TODO: move to semaphore auto edm_buffer_index_sem_id = get_arg_val(arg_idx++); ASSERT(edm_buffer_index_sem_id < 8); - auto edm_buffer_index_address = get_semaphore(edm_buffer_index_sem_id); + auto edm_buffer_index_id = edm_buffer_index_sem_id; ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast(writer_send_sem_addr)); ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast(last_message_semaphore_address)); @@ -77,20 +79,22 @@ void kernel_main() { config.unicast.distance = static_cast(get_arg_val(arg_idx++)); } - const InterleavedAddrGen dest_addr_gen = {.bank_base_address = dest_addr, .page_size = page_size}; + const InterleavedAddrGen dest_addr_gen = { + .bank_base_address = dest_addr, .page_size = page_size}; ASSERT(num_buffers_per_channel > 0); auto sender = tt::fabric::WorkerToFabricEdmSender( + connected_to_persistent_fabric, eth_sender_noc_x, eth_sender_noc_y, eth_l1_base_addr, num_buffers_per_channel, eth_sender_l1_sem_id, - edm_connection_handshake_addr, + edm_connection_handshake_id, edm_worker_location_info_addr, edm_buffer_size_bytes, - edm_buffer_index_address, + edm_buffer_index_id, writer_send_sem_addr, worker_buffer_index_semaphore_addr); @@ -154,10 +158,8 @@ void kernel_main() { auto& packet_header = *reinterpret_cast(a_packet_header_addr); ASSERT(*last_message_semaphore_address == 0); - packet_header.reserved = 0xE; - packet_header.reserved2 = 0xFFFF; packet_header.to_atomic_inc(); - packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{1}); + packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{2}); packet_header.to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader( reinterpret_cast(last_message_semaphore_address), 1, 32, my_x[0], my_y[0])); @@ -167,40 +169,9 @@ void kernel_main() { noc_semaphore_wait(last_message_semaphore_address, 1); } - bool closed = false; - size_t num_endpoints_to_terminate = get_arg_val(arg_idx++); - for (size_t i = 0; i < num_endpoints_to_terminate; i++) { - size_t edm_noc_x = get_arg_val(arg_idx++); - size_t edm_noc_y = get_arg_val(arg_idx++); - size_t distance = get_arg_val(arg_idx++); - size_t termination_addr = get_arg_val(arg_idx++); - - if (!closed && distance == 0) { - closed = true; - sender.close(); - } - if (distance == 0) { - noc_inline_dw_write( - get_noc_addr(edm_noc_x, edm_noc_y, termination_addr), - tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE); - } else { - auto& packet_header = *reinterpret_cast(a_packet_header_addr); - reinterpret_cast(a_packet_header_addr)[sizeof(tt::fabric::PacketHeader) >> 2] = - tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE; - sender.wait_for_empty_write_slot(); - packet_header.to_write() - .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{static_cast(distance - 1)}) - .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ - termination_addr, - sizeof(tt::fabric::PacketHeader) + sizeof(uint32_t), - static_cast(edm_noc_x), - static_cast(edm_noc_y)}); - sender.send_payload_blocking_from_address( - a_packet_header_addr, packet_header.get_payload_size_including_header()); - noc_async_writes_flushed(); - } - } - if (!closed) { + bool closed_fabric_connection = terminate_fabric_endpoints_farthest_to_nearest(sender, a_packet_header_addr, arg_idx); + + if (!closed_fabric_connection) { sender.close(); } } diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_worker_sender_multi_input.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_worker_sender_multi_input.cpp new file mode 100644 index 00000000000..f699132dbca --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_worker_sender_multi_input.cpp @@ -0,0 +1,218 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp" + +struct unicast_mode { + uint8_t distance; +}; +struct mcast_mode { + uint8_t distance; + uint8_t range; +}; + +union transmit_config { + unicast_mode unicast; + mcast_mode mcast; +}; + +enum class ReadMode { + // Fully drain one input CB before advancing to the next + FULLY_ORDERED, + // Drain both inputs in at some specified ratio (e.g. if 5:3 then 3 pages from CB0 then 5 from CB1) + RATIOD_FORWARDING, + // Read pages from either CB as soon as they are available + ARBITRARILY_ORDERED +}; + +static constexpr size_t NORMALIZED_NOC_INDEX = 0; + +template +auto forward_to_fabric_from_cb( + size_t total_pages_to_send, + tt::fabric::WorkerToFabricEdmSender &sender, + uint32_t cb_id, + transmit_config const& config, + size_t page_size, + const AddrGen &dest_addr_gen, + size_t num_pages_per_send, + size_t current_page + ) { + + // for (uint32_t p = 0; p < total_pages_to_send; p += num_pages_per_send) { + uint32_t pages_to_send = std::min(num_pages_per_send, total_pages_to_send - current_page); + sender.wait_for_empty_write_slot(); + + // bit of a hack to extract X/Y + const auto dest_noc_address = get_noc_addr(current_page, dest_addr_gen, 0, NORMALIZED_NOC_INDEX); + const auto [dest_worker_noc, dest_addr] = get_noc_address_components(dest_noc_address); + const size_t packet_size = page_size + sizeof(tt::fabric::PacketHeader); + + auto packet_addr = get_read_ptr(cb_id); + auto &packet_header = *reinterpret_cast(packet_addr); + if constexpr (mcast_mode) { + packet_header.to_write() + .to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{config.mcast.distance, config.mcast.range}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, + (pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader), + static_cast(dest_worker_noc.x), + static_cast(dest_worker_noc.y)}); + } else { + packet_header.to_write() + .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{config.unicast.distance}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, + (pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader), + static_cast(dest_worker_noc.x), + static_cast(dest_worker_noc.y)}); + } + + uint64_t buffer_address = sender.edm_buffer_addr + (*sender.buffer_index_ptr * (sender.buffer_size_bytes + sizeof(eth_channel_sync_t))); + sender.send_payload_blocking_from_address(packet_addr, packet_size); + noc_async_writes_flushed(); + // } +} + +template +void non_blocking_read_and_forward(size_t ¤t_page_in, uint32_t cb_id, const AddrGen &dest_addr_gen, tt::fabric::WorkerToFabricEdmSender &sender, transmit_config const& config, uint32_t page_size, uint32_t total_pages_to_send, uint32_t num_pages_per_send) { + uint32_t pages_to_send = std::min(num_pages_per_send, total_pages_to_send - current_page_in); + if (!cb_pages_available_at_front(cb_id, pages_to_send)) { + return; + } + + current_page_in += num_pages_per_send; + cb_wait_front(cb_id, pages_to_send); + forward_to_fabric_from_cb( + total_pages_to_send, + sender, + cb_id, + config, + page_size, + dest_addr_gen, + num_pages_per_send, + current_page_in + ); + cb_pop_front(cb_id, pages_to_send); +} + +// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side). +// -> takes input from local cb and pushes to erisc L1 +void kernel_main() { + + // Test doesn't support multiple pages per send yet since we are writing + // to interleaved which will never have subsequent pages on the same core + // (and hence, able to share a packet header) + constexpr uint32_t num_pages_per_send = 1;//get_compile_time_arg_val(0); + constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3); + constexpr bool dest0_is_dram = get_compile_time_arg_val(4) != 0; + constexpr bool dest1_is_dram = get_compile_time_arg_val(5) != 0; + constexpr ReadMode read_mode = static_cast(get_compile_time_arg_val(6)); + + transmit_config config; + size_t arg_idx = 0; + auto sender = tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); + volatile uint32_t* const last_message_semaphore_address = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + size_t output_buffer0_addr = get_arg_val(arg_idx++); + size_t output_buffer1_addr = get_arg_val(arg_idx++); + config.unicast.distance = static_cast(get_arg_val(arg_idx++)); + + size_t read_ratio0 = (read_mode == ReadMode::ARBITRARILY_ORDERED) ? 0 : + (read_mode == ReadMode::FULLY_ORDERED) ? total_pages_to_send : + get_arg_val(arg_idx++); + size_t read_ratio1 = (read_mode == ReadMode::ARBITRARILY_ORDERED) ? 0 : + (read_mode == ReadMode::FULLY_ORDERED) ? total_pages_to_send : + get_arg_val(arg_idx++); + + + *last_message_semaphore_address = 0; + const InterleavedAddrGen dest_addr_gen0 = { + .bank_base_address = output_buffer0_addr, .page_size = page_size}; + const InterleavedAddrGen dest_addr_gen1 = { + .bank_base_address = output_buffer1_addr, .page_size = page_size}; + + ASSERT(num_buffers_per_channel > 0); + + sender.open(); + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + constexpr uint32_t cb_id_in1 = tt::CB::c_in0; + + // We need to normalize all noc addresses to be for a consistent noc ID + // so the remote sender core can correctly send the packet. In the future + // we can decide if it's better for the noc index to be embedded in the packet + // header (for now we don't do that) + constexpr size_t NORMALIZED_NOC_INDEX = 0; + + cb_wait_front(cb_id_in0, 1); + auto a_packet_header_addr = get_read_ptr(cb_id_in0); + + if constexpr (read_mode == ReadMode::FULLY_ORDERED || read_mode == ReadMode::RATIOD_FORWARDING) { + + size_t current_page_in0 = 0; + size_t current_page_in1 = 0; + while (current_page_in0 < total_pages_to_send || current_page_in1 < total_pages_to_send) { + for (size_t read = 0; read < read_ratio0 && current_page_in0 < total_pages_to_send; current_page_in0 += num_pages_per_send, read++) { + uint32_t pages_to_send = std::min(num_pages_per_send, total_pages_to_send - current_page_in0); + cb_wait_front(cb_id_in0, pages_to_send); + non_blocking_read_and_forward(current_page_in0, cb_id_in0, dest_addr_gen0, sender, config, page_size, total_pages_to_send, num_pages_per_send); + cb_pop_front(cb_id_in0, pages_to_send); + } + + for (size_t read = 0; read < read_ratio1 && current_page_in1 < total_pages_to_send; current_page_in1 += num_pages_per_send, read++) { + uint32_t pages_to_send = std::min(num_pages_per_send, total_pages_to_send - current_page_in1); + cb_wait_front(cb_id_in1, pages_to_send); + non_blocking_read_and_forward(current_page_in1, cb_id_in1, dest_addr_gen1, sender, config, page_size, total_pages_to_send, num_pages_per_send); + cb_pop_front(cb_id_in1, pages_to_send); + } + } + + } else if constexpr (read_mode == ReadMode::ARBITRARILY_ORDERED) { + size_t current_page_in0 = 0; + size_t current_page_in1 = 0; + while (current_page_in0 < total_pages_to_send || current_page_in1 < total_pages_to_send) { + if (current_page_in0 < total_pages_to_send) { + non_blocking_read_and_forward(current_page_in0, cb_id_in0, dest_addr_gen0, sender, config, page_size, total_pages_to_send, num_pages_per_send); + } + if (current_page_in1 < total_pages_to_send) { + non_blocking_read_and_forward(current_page_in1, cb_id_in1, dest_addr_gen1, sender, config, page_size, total_pages_to_send, num_pages_per_send); + } + } + } + + sender.wait_for_empty_write_slot(); + + constexpr size_t kLoopbackNumHopsToMyChip = 2; + auto &packet_header = *reinterpret_cast(a_packet_header_addr); + ASSERT(*last_message_semaphore_address == 0); + packet_header.reserved = 0xE; + packet_header.reserved2 = 0xFFFF; + packet_header.to_atomic_inc(); + packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{kLoopbackNumHopsToMyChip}); + packet_header.to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader( + reinterpret_cast(last_message_semaphore_address), + 1, + 32, + my_x[0], + my_y[0] + )); + + sender.send_payload_blocking_from_address(a_packet_header_addr, packet_header.get_payload_size_including_header()); + + noc_semaphore_wait(last_message_semaphore_address, 1); + + bool closed_fabric_connection = terminate_fabric_endpoints_farthest_to_nearest(sender, a_packet_header_addr, arg_idx); + + if (!closed_fabric_connection) { + sender.close(); + } +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp new file mode 100644 index 00000000000..53c102f6098 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" + +bool terminate_fabric_endpoints_farthest_to_nearest ( + tt::fabric::WorkerToFabricEdmSender &sender, + size_t a_packet_header_addr, + size_t arg_idx) { + + bool closed = false; + size_t num_endpoints_to_terminate = get_arg_val(arg_idx++); + for (size_t i = 0; i < num_endpoints_to_terminate; i++) { + size_t edm_noc_x = get_arg_val(arg_idx++); + size_t edm_noc_y = get_arg_val(arg_idx++); + size_t distance = get_arg_val(arg_idx++); + size_t termination_addr = get_arg_val(arg_idx++); + + if (!closed && distance == 0) { + closed = true; + sender.close(); + } + if (distance == 0) { + noc_inline_dw_write(get_noc_addr(edm_noc_x, edm_noc_y, termination_addr), tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE); + } else { + auto &packet_header = *reinterpret_cast(a_packet_header_addr); + reinterpret_cast(a_packet_header_addr)[sizeof(tt::fabric::PacketHeader) >> 2] = tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE; + sender.wait_for_empty_write_slot(); + packet_header.to_write() + .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{static_cast(distance)}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + termination_addr, + sizeof(tt::fabric::PacketHeader) + sizeof(uint32_t), + static_cast(edm_noc_x), + static_cast(edm_noc_y) + }); + sender.send_payload_blocking_from_address(a_packet_header_addr, packet_header.get_payload_size_including_header()); + noc_async_writes_flushed(); + } + } + + return closed; +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_ccl_reduce_scatter_host_helpers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_ccl_reduce_scatter_host_helpers.cpp index a940a6a5426..0e105e9a777 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_ccl_reduce_scatter_host_helpers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_ccl_reduce_scatter_host_helpers.cpp @@ -10,6 +10,8 @@ #include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" #include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" + #include #include @@ -38,7 +40,7 @@ TEST(LineReduceScatter, EmitCclSendSliceSequenceCommands_8Slices_1x1x32x2048Tens std::vector args; ASSERT_EQ(slices.size(), 8); - ttnn::ccl::reduce_scatter_detail::emit_ccl_send_slice_sequence_commands(slices, args); + ttnn::ccl::worker_detail::emit_ccl_send_slice_sequence_commands(slices, args); const std::size_t args_per_command_header = 1; const std::size_t args_per_command_arg_header = 1; diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_ccl_tensor_slicers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_ccl_tensor_slicers.cpp index f89f66fe646..3f5628fb75e 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_ccl_tensor_slicers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_ccl_tensor_slicers.cpp @@ -2,640 +2,186 @@ // // SPDX-License-Identifier: Apache-2.0 -#include -#include "gtest/gtest.h" -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" - -static constexpr std::array worker_to_routing_x_wormhole = {1, 2, 3, 4, 6, 7, 8, 9}; - -static constexpr std::array worker_to_routing_y_wormhole = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11}; - -namespace tt { -namespace tt_metal { - -struct UnharvestedWormholeWorkerToNocLookup - : address_generators::WorkerToNocCoordLookup { - static constexpr std::array worker_to_routing_x = {1, 2, 3, 4, 6, 7, 8, 9}; - static constexpr std::array worker_to_routing_y = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11}; - - noc_grid_index_t get_noc_x_from_worker_x(noc_grid_index_t worker_x) const { - // ASSERT worker_x < worker_to_routing_x_wormhole.size() - return worker_to_routing_x[worker_x]; - } - - noc_grid_index_t get_noc_y_from_worker_y(noc_grid_index_t worker_y) const { - // ASSERT worker_y < worker_to_routing_y_wormhole.size() - return worker_to_routing_y[worker_y]; - } -}; - -static void run_width_sharded_tensor_slice_indexer_get_page_location_test( - address_generators::WidthShardedAddressGenerator< - UnharvestedWormholeWorkerToNocLookup, - address_generators::DeviceWidthShardSpec>& addrgen, - - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x, - - std::size_t shard_grid_height, - std::size_t shard_grid_width, - - std::size_t worker_shard_cores_start_y, - std::size_t worker_shard_cores_start_x, - bool is_shard_grid_transposed) { - std::size_t page_id = 0; - // Takes a long time to sweep really large numbers so instead stride through the range. - // Really the only reason to test larger numbers is to catch overflow issues with smaller - // number types that may be carried around in the addrgen structs - std::size_t py_increment = pages_per_shard_y > 32 ? 7 : 1; - std::size_t px_increment = pages_per_shard_x > 32 ? 7 : 1; - - if (!is_shard_grid_transposed) { - for (std::size_t py = 0; py < pages_per_shard_y; py++) { - for (std::size_t y_logical = worker_shard_cores_start_y; - y_logical < worker_shard_cores_start_y + shard_grid_height; - y_logical++) { - for (std::size_t x_logical = worker_shard_cores_start_x; - x_logical < worker_shard_cores_start_x + shard_grid_width; - x_logical++) { - for (std::size_t px = 0; px < pages_per_shard_x; px++) { - if (px % px_increment == 0 && py % py_increment == 0 || - (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { - auto const& result = addrgen.get_page_location(page_id); - ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); - ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); - ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); - - auto const& result2 = - addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id); - ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x); - ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y); - ASSERT_EQ(result2.page_offset, result.page_offset); - ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px); - } - - page_id++; - } - } - } - } - } else { - for (std::size_t py = 0; py < pages_per_shard_y; py++) { - for (std::size_t x_logical = worker_shard_cores_start_x; - x_logical < worker_shard_cores_start_x + shard_grid_width; - x_logical++) { - for (std::size_t y_logical = worker_shard_cores_start_y; - y_logical < worker_shard_cores_start_y + shard_grid_height; - y_logical++) { - for (std::size_t px = 0; px < pages_per_shard_x; px++) { - if (px % px_increment == 0 && py % py_increment == 0 || - (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { - auto const& result = addrgen.get_page_location(page_id); - ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); - ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); - ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); - - auto const& result2 = - addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id); - ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x); - ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y); - ASSERT_EQ(result2.page_offset, result.page_offset); - ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px); - } - page_id++; - } - } - } - } - } -} - -static void run_width_sharded_tensor_slice_indexer_get_page_location_test( - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x, - - std::size_t shard_grid_height, - std::size_t shard_grid_width, - - std::size_t worker_shard_cores_start_y, - std::size_t worker_shard_cores_start_x, - - bool is_shard_grid_transposed) { - const std::size_t global_num_pages = pages_per_shard_y * pages_per_shard_x * shard_grid_width * shard_grid_height; - - auto addrgen = address_generators:: - WidthShardedAddressGenerator( - UnharvestedWormholeWorkerToNocLookup(), - address_generators::DeviceShardSpecTypeGetter::type( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed), - 1024, - 0x0); - - run_width_sharded_tensor_slice_indexer_get_page_location_test( - addrgen, - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); -} - -TEST(CclWidthShardedTensorSliceIndexer_Wormhole, basic_test_case) { - static constexpr std::size_t pages_per_shard_y = 1; - static constexpr std::size_t pages_per_shard_x = 8; - - static constexpr std::size_t shard_grid_height = 2; - static constexpr std::size_t shard_grid_width = 1; - - static constexpr std::size_t worker_shard_cores_start_y = 0; - static constexpr std::size_t worker_shard_cores_start_x = 0; - - bool is_shard_grid_transposed = false; - - run_width_sharded_tensor_slice_indexer_get_page_location_test( - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); -} - -TEST(CclWidthShardedTensorSliceIndexer_Wormhole, SweepWormhole) { - std::size_t max_worker_rows = 10; - std::size_t max_worker_cols = 8; - - for (auto pages_per_shard_y : {1, 2, 5, 8, 256}) { - for (auto pages_per_shard_x : {1, 2, 5, 8, 256}) { - for (auto shard_grid_offset_logical_y : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) { - for (auto shard_grid_offset_logical_x : {0, 1, 2, 3, 4, 5, 6, 7}) { - for (std::size_t shard_grid_height = 1; - shard_grid_height < (max_worker_rows - shard_grid_offset_logical_y); - shard_grid_height++) { - for (std::size_t shard_grid_width = 1; - shard_grid_width < (max_worker_cols - shard_grid_offset_logical_x); - shard_grid_width++) { - for (bool transpose_shard_grid : {false, true}) { - run_width_sharded_tensor_slice_indexer_get_page_location_test( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - shard_grid_offset_logical_y, - shard_grid_offset_logical_x, - transpose_shard_grid); - } - } - } - } - } - } - } -} - -static void run_height_sharded_tensor_slice_indexer_get_page_location_test( - address_generators::HeightShardedAddressGenerator< - UnharvestedWormholeWorkerToNocLookup, - address_generators::DeviceHeightShardSpec>& addrgen, - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x, - - std::size_t shard_grid_height, - std::size_t shard_grid_width, - - std::size_t worker_shard_cores_start_y, - std::size_t worker_shard_cores_start_x, - - bool is_shard_grid_transposed) { - std::size_t page_id = 0; - - // Takes a long time to sweep really large numbers so instead stride through the range. - // Really the only reason to test larger numbers is to catch overflow issues with smaller - // number types that may be carried around in the addrgen structs - std::size_t py_increment = pages_per_shard_y > 32 ? 7 : 1; - std::size_t px_increment = pages_per_shard_x > 32 ? 7 : 1; - - if (!is_shard_grid_transposed) { - for (std::size_t x_logical = worker_shard_cores_start_x; - x_logical < worker_shard_cores_start_x + shard_grid_width; - x_logical++) { - for (std::size_t y_logical = worker_shard_cores_start_y; - y_logical < worker_shard_cores_start_y + shard_grid_height; - y_logical++) { - for (std::size_t py = 0; py < pages_per_shard_y; py++) { - for (std::size_t px = 0; px < pages_per_shard_x; px++) { - if (px % px_increment == 0 && py % py_increment == 0 || - (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { - auto const& result = addrgen.get_page_location(page_id); - ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); - ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); - ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); - } - - page_id++; - } - } - } - } - } else { - for (std::size_t y_logical = worker_shard_cores_start_y; - y_logical < worker_shard_cores_start_y + shard_grid_height; - y_logical++) { - for (std::size_t x_logical = worker_shard_cores_start_x; - x_logical < worker_shard_cores_start_x + shard_grid_width; - x_logical++) { - for (std::size_t py = 0; py < pages_per_shard_y; py++) { - for (std::size_t px = 0; px < pages_per_shard_x; px++) { - if (px % px_increment == 0 && py % py_increment == 0 || - (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { - auto const& result = addrgen.get_page_location(page_id); - ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); - ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); - ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); - } - page_id++; - } - } - } - } - } -} +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" -static void run_height_sharded_tensor_slice_indexer_get_page_location_test( - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x, +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" - std::size_t shard_grid_height, - std::size_t shard_grid_width, - - std::size_t worker_shard_cores_start_y, - std::size_t worker_shard_cores_start_x, - - bool is_shard_grid_transposed) { - const std::size_t global_num_pages = pages_per_shard_y * pages_per_shard_x * shard_grid_width * shard_grid_height; - - auto addrgen = address_generators:: - HeightShardedAddressGenerator( - UnharvestedWormholeWorkerToNocLookup(), - address_generators::DeviceShardSpecTypeGetter::type( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed), - 1024, - 0x0); +#include "gtest/gtest.h" - run_height_sharded_tensor_slice_indexer_get_page_location_test( - addrgen, - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, +#include - is_shard_grid_transposed); +TEST( + CclTensorSlicer_SliceWorkerSplitting, + page_based_1worker_TensorShape_1_1_4_1__SliceShape_1_1_1_1__SliceOffset_0_0_3_0__Workers1) { + const auto worker_slices = ttnn::ccl::cmd::builder::split_tensor_slice_across_workers_wrapped_page_aligned( + ttnn::ccl::v2::TensorSlice{ + {1, 1, 4, 1}, // tensor_shape + {1, 1, 1, 1}, // tensor slice shape + {0, 0, 3, 0}, // tensor slice offset + {1, 1, 1, 1}, + {0, 0, 0, 0}}, + 1); + + ASSERT_EQ(worker_slices.size(), 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.w, 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.z, 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.y, 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.x, 1); + + ASSERT_EQ(worker_slices[0].tensor_slice_offset.w, 0); + ASSERT_EQ(worker_slices[0].tensor_slice_offset.z, 0); + ASSERT_EQ(worker_slices[0].tensor_slice_offset.y, 3); + ASSERT_EQ(worker_slices[0].tensor_slice_offset.x, 0); + + ASSERT_EQ(worker_slices[0].worker_slice_shape.w, 1); + ASSERT_EQ(worker_slices[0].worker_slice_shape.z, 1); + ASSERT_EQ(worker_slices[0].worker_slice_shape.y, 1); + ASSERT_EQ(worker_slices[0].worker_slice_shape.x, 1); + + ASSERT_EQ(worker_slices[0].worker_slice_offset.w, 0); + ASSERT_EQ(worker_slices[0].worker_slice_offset.z, 0); + ASSERT_EQ(worker_slices[0].worker_slice_offset.y, 0); + ASSERT_EQ(worker_slices[0].worker_slice_offset.x, 0); } -TEST(CclHeightShardedTensorSliceIndexer_Wormhole, basic_test_case) { - static constexpr std::size_t pages_per_shard_y = 8; - static constexpr std::size_t pages_per_shard_x = 1; - - static constexpr std::size_t shard_grid_height = 1; - static constexpr std::size_t shard_grid_width = 2; - - static constexpr std::size_t worker_shard_cores_start_y = 0; - static constexpr std::size_t worker_shard_cores_start_x = 0; - - bool is_shard_grid_transposed = false; - - run_height_sharded_tensor_slice_indexer_get_page_location_test( - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); +TEST( + CclTensorSlicer_SliceWorkerSplitting, + page_based_1worker_TensorShape_1_1_4_1__SliceShape_1_1_1_1__SliceOffset_0_0_0_0__Workers1) { + const auto worker_slices = ttnn::ccl::cmd::builder::split_tensor_slice_across_workers_wrapped_page_aligned( + ttnn::ccl::v2::TensorSlice{ + {1, 1, 4, 1}, // tensor_shape + {1, 1, 1, 1}, // tensor slice shape + {0, 0, 0, 0}, // tensor slice offset + {1, 1, 4, 1}, + {0, 0, 0, 0}}, + 1); + + ASSERT_EQ(worker_slices.size(), 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.w, 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.z, 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.y, 1); + ASSERT_EQ(worker_slices[0].tensor_slice_shape.x, 1); + + ASSERT_EQ(worker_slices[0].tensor_slice_offset.w, 0); + ASSERT_EQ(worker_slices[0].tensor_slice_offset.z, 0); + ASSERT_EQ(worker_slices[0].tensor_slice_offset.y, 0); + ASSERT_EQ(worker_slices[0].tensor_slice_offset.x, 0); + + ASSERT_EQ(worker_slices[0].worker_slice_shape.w, 1); + ASSERT_EQ(worker_slices[0].worker_slice_shape.z, 1); + ASSERT_EQ(worker_slices[0].worker_slice_shape.y, 1); + ASSERT_EQ(worker_slices[0].worker_slice_shape.x, 1); + + ASSERT_EQ(worker_slices[0].worker_slice_offset.w, 0); + ASSERT_EQ(worker_slices[0].worker_slice_offset.z, 0); + ASSERT_EQ(worker_slices[0].worker_slice_offset.y, 0); + ASSERT_EQ(worker_slices[0].worker_slice_offset.x, 0); } -TEST(CclHeightShardedTensorSliceIndexer_Wormhole, SweepWormhole) { - std::size_t max_worker_rows = 10; - std::size_t max_worker_cols = 8; - - for (auto pages_per_shard_y : {1, 2, 5, 8, 256}) { - for (auto pages_per_shard_x : {1, 2, 5, 8, 256}) { - for (auto shard_grid_offset_logical_y : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) { - for (auto shard_grid_offset_logical_x : {0, 1, 2, 3, 4, 5, 6, 7}) { - for (std::size_t shard_grid_height = 1; - shard_grid_height < (max_worker_rows - shard_grid_offset_logical_y); - shard_grid_height++) { - for (std::size_t shard_grid_width = 1; - shard_grid_width < (max_worker_cols - shard_grid_offset_logical_x); - shard_grid_width++) { - for (bool transpose_shard_grid : {false, true}) { - run_height_sharded_tensor_slice_indexer_get_page_location_test( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - shard_grid_offset_logical_y, - shard_grid_offset_logical_x, - transpose_shard_grid); - } - } - } - } - } - } - } +static size_t get_flat_index_from_shape( + const ttnn::ccl::Shape4D& shape, const ttnn::ccl::Shape4D& index) { + std::size_t offset = index.x; + std::size_t inner_volume = shape.x; + offset += index.y * inner_volume; + inner_volume *= shape.y; + offset += index.z * inner_volume; + inner_volume *= shape.z; + offset += index.w * inner_volume; + return offset; } -static void run_block_sharded_tensor_slice_indexer_get_page_location_test( - address_generators::BlockShardedAddressGenerator< - UnharvestedWormholeWorkerToNocLookup, - address_generators::DeviceBlockShardSpec>& addrgen, - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x, - - std::size_t shard_grid_height, - std::size_t shard_grid_width, - - std::size_t worker_shard_cores_start_y, - std::size_t worker_shard_cores_start_x, - - bool is_shard_grid_transposed) { - std::size_t page_id = 0; - - // Takes a long time to sweep really large numbers so instead stride through the range. - // Really the only reason to test larger numbers is to catch overflow issues with smaller - // number types that may be carried around in the addrgen structs - std::size_t py_increment = pages_per_shard_y > 32 ? 7 : 1; - std::size_t px_increment = pages_per_shard_x > 32 ? 7 : 1; - - if (!is_shard_grid_transposed) { - for (std::size_t y_logical = worker_shard_cores_start_y; - y_logical < worker_shard_cores_start_y + shard_grid_height; - y_logical++) { - for (std::size_t py = 0; py < pages_per_shard_y; py++) { - for (std::size_t x_logical = worker_shard_cores_start_x; - x_logical < worker_shard_cores_start_x + shard_grid_width; - x_logical++) { - for (std::size_t px = 0; px < pages_per_shard_x; px++) { - if (px % px_increment == 0 && py % py_increment == 0 || - (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { - auto const& result = addrgen.get_page_location(page_id); - ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); - ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); - ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); - - auto const& result2 = - addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id); - ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x); - ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y); - ASSERT_EQ(result2.page_offset, result.page_offset); - ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px); - } - - page_id++; - } - } - } - } - } else { - ASSERT_EQ(true, false); //"Transposed grid not supported in testing yet" - } -} - -static void run_block_sharded_tensor_slice_indexer_get_page_location_test( - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x, - - std::size_t shard_grid_height, - std::size_t shard_grid_width, - - std::size_t worker_shard_cores_start_y, - std::size_t worker_shard_cores_start_x, - - bool is_shard_grid_transposed) { - const std::size_t global_num_pages = pages_per_shard_y * pages_per_shard_x * shard_grid_width * shard_grid_height; - - auto addrgen = address_generators:: - BlockShardedAddressGenerator( - UnharvestedWormholeWorkerToNocLookup(), - address_generators::DeviceShardSpecTypeGetter::type( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed), - 1024, - 0x0); - - run_block_sharded_tensor_slice_indexer_get_page_location_test( - addrgen, - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed); -} - -TEST(CclBlockShardedTensorSliceIndexer_Wormhole, basic_test_case) { - static constexpr std::size_t pages_per_shard_y = 8; - static constexpr std::size_t pages_per_shard_x = 2; - - static constexpr std::size_t shard_grid_height = 3; - static constexpr std::size_t shard_grid_width = 2; - - static constexpr std::size_t worker_shard_cores_start_y = 0; - static constexpr std::size_t worker_shard_cores_start_x = 0; - - bool is_shard_grid_transposed = false; - - run_block_sharded_tensor_slice_indexer_get_page_location_test( - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); -} - -TEST(CclBlockShardedTensorSliceIndexer_Wormhole, SweepWormhole) { - std::size_t max_worker_rows = 10; - std::size_t max_worker_cols = 8; - - for (auto pages_per_shard_y : {1, 2, 5, 8, 256}) { - for (auto pages_per_shard_x : {1, 2, 5, 8, 256}) { - for (auto shard_grid_offset_logical_y : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) { - for (auto shard_grid_offset_logical_x : {0, 1, 2, 3, 4, 5, 6, 7}) { - for (std::size_t shard_grid_height = 1; - shard_grid_height < (max_worker_rows - shard_grid_offset_logical_y); - shard_grid_height++) { - for (std::size_t shard_grid_width = 1; - shard_grid_width < (max_worker_cols - shard_grid_offset_logical_x); - shard_grid_width++) { - for (bool transpose_shard_grid : - {false}) { // true: transpose mode not yet supported for block sharded indexer - run_block_sharded_tensor_slice_indexer_get_page_location_test( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - shard_grid_offset_logical_y, - shard_grid_offset_logical_x, - transpose_shard_grid); +TEST( + TensorIterationSweep, + advance_worker_global_page__Shape_1_4_4_72__SliceShape_1_4_1_72__SliceOffset_0_0_3_0__WorkerStartPage_0__Stride_1) { + uint32_t stride = 1; + ttnn::ccl::Shape4D tensor_shape{1, 4, 4, 72}; + ttnn::ccl::Shape4D tensor_slice_shape{1, 4, 1, 72}; + const ttnn::ccl::Shape4D tensor_slice_offset_base{0, 0, 3, 0}; + ttnn::ccl::Shape4D tensor_slice_offset_current{0, 0, 3, 0}; + ttnn::ccl::Shape4D start_offset_worker_slice{0, 0, 0, 0}; + size_t worker_slice_volume = tensor_slice_shape.volume(); + uint32_t curr_page_idx = + get_flat_index_from_shape(tensor_shape, tensor_slice_offset_base + start_offset_worker_slice); + uint32_t offset_into_worker_slice = 0; + + tensor_slice_offset_current.w = tensor_slice_offset_base.w; + for (size_t w = 0; w < tensor_slice_shape.w; w++) { + bool last_w = w == tensor_slice_shape.w - 1; + tensor_slice_offset_current.z = tensor_slice_offset_base.z; + for (size_t z = 0; z < tensor_slice_shape.z; z++) { + bool last_z = z == tensor_slice_shape.z - 1; + tensor_slice_offset_current.y = tensor_slice_offset_base.y; + for (size_t y = 0; y < tensor_slice_shape.y; y++) { + bool last_y = y == tensor_slice_shape.y - 1; + tensor_slice_offset_current.x = tensor_slice_offset_base.x; + for (size_t x = 0; x < tensor_slice_shape.x; x++) { + bool last_x = x == tensor_slice_shape.x - 1; + bool end_of_worker_slice = ttnn::ccl::v2::advance_worker_global_page( + curr_page_idx, + offset_into_worker_slice, // local to the worker chunk + start_offset_worker_slice, // local to the tensor slice + + worker_slice_volume, // worker chunk shape + tensor_slice_shape, // tensor slice shape (per device) + tensor_slice_offset_base, + + tensor_shape, // full tensor shape + + stride); + if (tensor_slice_offset_current.x == (tensor_slice_offset_base.x + tensor_slice_shape.x - 1)) { + if (tensor_slice_offset_current.y == (tensor_slice_offset_base.y + tensor_slice_shape.y - 1)) { + if (tensor_slice_offset_current.z == + (tensor_slice_offset_base.z + tensor_slice_shape.z - 1)) { + tensor_slice_offset_current.w = + (tensor_slice_offset_current.w + 1) % tensor_slice_shape.w; + tensor_slice_offset_current.z = tensor_slice_offset_base.z; + } else { + tensor_slice_offset_current.z = tensor_slice_offset_current.z + 1; } + tensor_slice_offset_current.y = tensor_slice_offset_base.y; + } else { + tensor_slice_offset_current.y = tensor_slice_offset_current.y + 1; } + tensor_slice_offset_current.x = tensor_slice_offset_base.x; + } else { + tensor_slice_offset_current.x = tensor_slice_offset_current.x + 1; } + bool last_page = last_w && last_z && last_y && last_x; + ASSERT_TRUE( + last_page || + (curr_page_idx == get_flat_index_from_shape(tensor_shape, tensor_slice_offset_current))); } } } } } -TEST(CclShardedTensorAddrGenBuilder, TestBuildWidthSharded) { - static constexpr std::size_t pages_per_shard_y = 1; - static constexpr std::size_t pages_per_shard_x = 8; - - static constexpr std::size_t shard_grid_height = 2; - static constexpr std::size_t shard_grid_width = 1; - - static constexpr std::size_t worker_shard_cores_start_y = 0; - static constexpr std::size_t worker_shard_cores_start_x = 0; - - bool is_shard_grid_transposed = false; - auto addrgen = build_sharded_addr_gen( - UnharvestedWormholeWorkerToNocLookup(), - address_generators::DeviceShardSpecTypeGetter::type( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed), - 1024, - 0x0); - - run_width_sharded_tensor_slice_indexer_get_page_location_test( - addrgen, - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); -} -TEST(CclShardedTensorAddrGenBuilder, TestBuildHeightSharded) { - static constexpr std::size_t pages_per_shard_y = 8; - static constexpr std::size_t pages_per_shard_x = 1; - - static constexpr std::size_t shard_grid_height = 1; - static constexpr std::size_t shard_grid_width = 2; - - static constexpr std::size_t worker_shard_cores_start_y = 0; - static constexpr std::size_t worker_shard_cores_start_x = 0; - - bool is_shard_grid_transposed = false; - auto addrgen = build_sharded_addr_gen( - UnharvestedWormholeWorkerToNocLookup(), - address_generators::DeviceShardSpecTypeGetter::type( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed), - 1024, - 0x0); - - run_height_sharded_tensor_slice_indexer_get_page_location_test( - addrgen, - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); +TEST( + TensorIteration, + advance_worker_global_page__Shape_1_4_4_72__SliceShape_1_4_1_72__OffsetIntoWorkerSlice_71__CurrPageId_287_last_page_on_plane__Stride_1) { + uint32_t stride = 1; + uint32_t curr_page_idx = 287; + uint32_t offset_into_worker_slice = 71; + ttnn::ccl::Shape4D tensor_shape{1, 4, 4, 72}; + ttnn::ccl::Shape4D tensor_slice_shape{1, 4, 1, 72}; + ttnn::ccl::Shape4D tensor_slice_offset{0, 0, 3, 0}; + ttnn::ccl::Shape4D start_offset_worker_slice{0, 0, 0, 0}; + ttnn::ccl::Shape4D worker_slice_shape{1, 1, 1, 288}; + + auto old_page_id = curr_page_idx; + bool end_of_worker_slice = ttnn::ccl::v2::advance_worker_global_page( + curr_page_idx, + offset_into_worker_slice, // local to the worker chunk + start_offset_worker_slice, // local to the tensor slice + + worker_slice_shape.volume(), // worker chunk shape + tensor_slice_shape, // tensor slice shape (per device) + tensor_slice_offset, + + tensor_shape, // full tensor shape + + stride); + ASSERT_EQ(curr_page_idx, old_page_id + 1 + 72 * 3); + ASSERT_EQ(offset_into_worker_slice, 72); } -TEST(CclShardedTensorAddrGenBuilder, TestBuildBlockSharded) { - static constexpr std::size_t pages_per_shard_y = 8; - static constexpr std::size_t pages_per_shard_x = 2; - - static constexpr std::size_t shard_grid_height = 3; - static constexpr std::size_t shard_grid_width = 2; - - static constexpr std::size_t worker_shard_cores_start_y = 0; - static constexpr std::size_t worker_shard_cores_start_x = 0; - - bool is_shard_grid_transposed = false; - - auto addrgen = build_sharded_addr_gen( - UnharvestedWormholeWorkerToNocLookup(), - address_generators::DeviceShardSpecTypeGetter::type( - pages_per_shard_y, - pages_per_shard_x, - shard_grid_height, - shard_grid_width, - worker_shard_cores_start_y, - worker_shard_cores_start_x, - is_shard_grid_transposed), - 1024, - 0x0); - - run_block_sharded_tensor_slice_indexer_get_page_location_test( - addrgen, - pages_per_shard_y, - pages_per_shard_x, - - shard_grid_height, - shard_grid_width, - - worker_shard_cores_start_y, - worker_shard_cores_start_x, - - is_shard_grid_transposed); -} - -} // namespace tt_metal -} // namespace tt diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index db2910cb801..13e2166bae6 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -3,43 +3,76 @@ // // SPDX-License-Identifier: Apache-2.0 -#include -#include -#include -#include - -#include "umd/device/types/arch.h" -#include "gtest/gtest.h" -// #include "tt_backend_api_types.hpp" +#include "common/logger.hpp" +#include "sub_device/sub_device_types.hpp" #include "tt_metal/common/core_coord.hpp" -#include "tt_metal/common/math.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/impl/kernels/kernel.hpp" -#include "tt_metal/test_utils/comparison.hpp" #include "tt_metal/test_utils/df/df.hpp" #include "tt_metal/test_utils/env_vars.hpp" -#include "tt_metal/test_utils/print_helpers.hpp" -#include "tt_metal/test_utils/stimulus.hpp" +#include "ttnn/common/constants.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" #include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" + +#include "tt_metal/distributed/mesh_device.hpp" +#include "tt_metal/distributed/mesh_device_view.hpp" + +#include "tt_metal/impl/tile/tile.hpp" + +#include "umd/device/types/arch.h" +#include "umd/device/types/cluster_descriptor_types.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include using namespace tt; using namespace tt::test_utils; using namespace tt::test_utils::df; +enum TwoInputReaderKernelWriteMode { LOCAL_WRITEBACK, FABRIC_UNICAST, FABRIC_MULTICAST }; + +static constexpr size_t TEST_WORKERS_SUBDEVICE_INDEX = 0; +static constexpr size_t TEST_EDM_FABRIC_SUBDEVICE_INDEX = 1; + +using subdevice_managers_t = std::unordered_map; +struct SubdeviceInfo { + std::unordered_map sub_device_managers; + std::unordered_map worker_subdevice_id; + std::unordered_map fabric_subdevice_id; +}; + +using tt::tt_metal::distributed::MeshDevice; +using tt::tt_metal::distributed::MeshDeviceConfig; +using tt::tt_metal::distributed::MeshDeviceView; +using tt::tt_metal::distributed::MeshShape; class T3000TestDevice { public: T3000TestDevice() : device_open(false) { + auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE"); + if (slow_dispatch) { + TT_THROW("This suite can only be run without TT_METAL_SLOW_DISPATCH_MODE set"); + } arch_ = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); num_devices_ = tt::tt_metal::GetNumAvailableDevices(); - if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() >= 4 and - tt::tt_metal::GetNumPCIeDevices() >= 1) { + if (arch_ == tt::ARCH::WORMHOLE_B0 and num_devices_ == 8 and tt::tt_metal::GetNumPCIeDevices() == 4) { + mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4})); + std::vector ids(num_devices_, 0); std::iota(ids.begin(), ids.end(), 0); - devices_ = tt::tt_metal::detail::CreateDevices(ids); } else { TT_THROW("This suite can only be run on T3000 Wormhole devices"); @@ -54,14 +87,12 @@ class T3000TestDevice { void TearDown() { device_open = false; - for (auto [device_id, device_ptr] : devices_) { - tt::tt_metal::CloseDevice(device_ptr); - } + mesh_device_->close_devices(); } - std::map devices_; tt::ARCH arch_; size_t num_devices_; + std::shared_ptr mesh_device_; private: bool device_open; @@ -71,9 +102,9 @@ struct BankedConfig { size_t num_pages; size_t size_bytes; size_t page_size_bytes; - BufferType input_buffer_type; // = BufferType::L1; - BufferType output_buffer_type; // = BufferType::L1; - tt::DataFormat l1_data_format; // = tt::DataFormat::Float16_b; + BufferType input_buffer_type; + BufferType output_buffer_type; + tt::DataFormat l1_data_format; }; struct KernelXY { @@ -86,44 +117,75 @@ struct KernelXY { enum Correctness { Correct, Incorrect }; struct EthLinkBuilder { - ttnn::ccl::FabricEriscDatamoverBuilder sender_edm_builder; // chip_0_edm_builder, - ttnn::ccl::FabricEriscDatamoverBuilder receiver_edm_builder; // chip_0_edm_builder, + ttnn::ccl::FabricEriscDatamoverBuilder sender_edm_builder; + ttnn::ccl::FabricEriscDatamoverBuilder receiver_edm_builder; tt_xy_pair sender_core; tt_xy_pair receiver_core; - // size_t downstream_edm_buffer_index_semaphore_id; }; -Correctness run_output_check( - std::vector const& all_zeros, - std::vector const& inputs, - const std::shared_ptr& output_buffer) { +template +Correctness run_output_check(CONTAINER_T const& inputs, CONTAINER_T output_buffer) { constexpr bool debug_mode = true; - std::vector readback_data_vec(all_zeros.size(), 0); // init to 0 data for easier debug - tt_metal::detail::ReadFromBuffer(output_buffer, readback_data_vec); log_info(tt::LogTest, "Checking outputs"); - if (readback_data_vec.size() != inputs.size()) { - log_error(tt::LogTest, "Output size mismatch: expected {} got {}", inputs.size(), readback_data_vec.size()); - return Correctness::Incorrect; - } - bool pass = (readback_data_vec == inputs); - if (not pass) { - log_error("Output mismatch"); - if (debug_mode) { - std::size_t num_printed_mismatches = 0; - for (size_t i = 0; i < readback_data_vec.size() && num_printed_mismatches < 64; i++) { - if (readback_data_vec[i] != inputs[i]) { - log_error("[{}]: expected {} got {}", i, inputs[i], readback_data_vec[i]); - num_printed_mismatches++; + bool pass = true; + + std::size_t num_printed_mismatches = 0; + for (size_t i = 0; i < inputs.size() && num_printed_mismatches < 64; i++) { + if (output_buffer[i] != inputs[i]) { + if (debug_mode) { + if (pass) { + log_error("Output mismatch"); } + log_error("[{}]: expected {} got {}", i, inputs[i], output_buffer[i]); + num_printed_mismatches++; } - log_error("... (remaining mismatches omitted)"); + pass = false; } } - return Correctness::Correct; + if (num_printed_mismatches > 0) { + log_error("... (remaining mismatches omitted)"); + } + + log_info(tt::LogTest, "Output check: {}", pass ? "PASS" : "FAIL"); + return pass ? Correctness::Correct : Correctness::Incorrect; +}; + +static SubdeviceInfo create_subdevices(std::vector const& devices) { + SubdeviceInfo subdevice_info; + std::unordered_map sub_device_manager_ids; + for (auto device : devices) { + const auto& tensix_sub_device = + tt_metal::SubDevice(std::array{device->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0})}); + const auto& eth_sub_device = tt_metal::SubDevice( + std::array{CoreRangeSet(), device->worker_cores(HalProgrammableCoreType::ACTIVE_ETH, SubDeviceId{0})}); + subdevice_info.sub_device_managers.insert( + {device->id(), device->create_sub_device_manager({tensix_sub_device, eth_sub_device}, 0)}); + device->load_sub_device_manager(subdevice_info.sub_device_managers.at(device->id())); + subdevice_info.worker_subdevice_id.insert( + {device->id(), device->get_sub_device_ids().at(TEST_WORKERS_SUBDEVICE_INDEX)}); + subdevice_info.fabric_subdevice_id.insert( + {device->id(), device->get_sub_device_ids().at(TEST_EDM_FABRIC_SUBDEVICE_INDEX)}); + } + + return subdevice_info; +} + +Correctness run_output_check( + std::vector const& all_zeros, + std::vector const& inputs, + std::shared_ptr& output_buffer) { + constexpr bool debug_mode = true; + std::vector readback_data_vec(all_zeros.size(), 0); // init to 0 data for easier debug + + tt_metal::detail::ReadFromBuffer(output_buffer, readback_data_vec); + return run_output_check(inputs, readback_data_vec); }; -void run_programs(std::vector& programs, std::vector const& devices) { +void run_programs( + std::vector& programs, + std::vector const& devices, + std::optional> const& sub_device_ids = std::nullopt) { EXPECT_EQ(programs.size(), devices.size()); const size_t num_programs = programs.size(); try { @@ -152,7 +214,11 @@ void run_programs(std::vector& programs, std::vector const& de log_debug(tt::LogTest, "Calling Finish"); for (size_t i = 0; i < num_programs; i++) { - tt_metal::Finish(devices.at(i)->command_queue()); + if (sub_device_ids.has_value()) { + tt_metal::Finish(devices.at(i)->command_queue(), {sub_device_ids.value().at(devices.at(i)->id())}); + } else { + tt_metal::Finish(devices.at(i)->command_queue()); + } } } } @@ -251,6 +317,7 @@ void generate_sender_worker_kernels( dram_output_buffer_base_addr, local_worker_last_message_semaphore_id, worker_buffer_index_semaphore_id, + worker_fabric_connection.persistent_fabric ? 1 : 0, worker_fabric_connection.buffer_index_semaphore_id}; if (std::holds_alternative(mode)) { @@ -260,20 +327,7 @@ void generate_sender_worker_kernels( sender_worker_writer_runtime_args.push_back(std::get(mode).distance); } - sender_worker_writer_runtime_args.push_back(edm_termination_infos.size()); - for (auto const& info : edm_termination_infos) { - sender_worker_writer_runtime_args.push_back(info.edm_noc_x); - sender_worker_writer_runtime_args.push_back(info.edm_noc_y); - sender_worker_writer_runtime_args.push_back(info.distance); - sender_worker_writer_runtime_args.push_back(info.termination_addr); - log_trace( - tt::LogTest, - "EDM termination info: x={}, y={}, distance={}, termination_addr={}", - info.edm_noc_x, - info.edm_noc_y, - info.distance, - info.termination_addr); - } + get_runtime_args_for_edm_termination_infos(edm_termination_infos, sender_worker_writer_runtime_args); uint32_t src0_cb_index = CBIndex::c_0; log_trace(tt::LogTest, "\tSenderWriter CT Args"); @@ -323,14 +377,15 @@ bool RunLoopbackTest( const uint32_t page_size, const uint32_t num_pages_total, bool src_is_dram, - bool dest_is_dram) { + bool dest_is_dram, + std::vector& programs, + ttnn::ccl::FabricEriscDatamoverBuilder& chip_0_edm_builder, + std::optional& subdevice_managers, + bool enable_persistent_fabric) { + auto& sender_program = programs.at(0); std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader); std::size_t tensor_size_bytes = num_pages_total * page_size; - std::vector programs(2); - auto& sender_program = programs.at(0); - auto& receiver_program = programs.at(1); - std::vector worker_cores = {CoreCoord(0, 0)}; auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0); @@ -366,17 +421,8 @@ bool RunLoopbackTest( //////////////////////////////////////////////////////////////////////////// static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES; - const chip_id_t local_chip_id = 0; - const chip_id_t remote_chip_id = 1; - auto const& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); - auto chip_0_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder::build( - sender_device, sender_program, eth_sender_core, local_chip_id, remote_chip_id, edm_config); - auto chip0_worker_fabric_connection = chip_0_edm_builder.build_connection_to_worker_channel(); - auto chip_1_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder::build( - receiver_device, receiver_program, eth_receiver_core, remote_chip_id, local_chip_id, edm_config); - // Create the loopback connection on the second device - chip_1_edm_builder.connect_to_downstream_edm(chip_1_edm_builder); + auto chip0_worker_fabric_connection = chip_0_edm_builder.build_connection_to_worker_channel(); //////////////////////////////////////////////////////////////////////////// // Build Workers //////////////////////////////////////////////////////////////////////////// @@ -386,22 +432,27 @@ bool RunLoopbackTest( auto const& worker_core = worker_cores.at(0); log_trace(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y); - std::vector const& edm_termination_infos = { - {1, - sender_device->ethernet_core_from_logical_core(eth_receiver_core).x, - sender_device->ethernet_core_from_logical_core(eth_receiver_core).y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}, - {0, - sender_device->ethernet_core_from_logical_core(eth_sender_core).x, - sender_device->ethernet_core_from_logical_core(eth_sender_core).y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}}; - + std::vector const& edm_termination_infos = + enable_persistent_fabric ? std::vector{} + : std::vector{ + {1, + sender_device->ethernet_core_from_logical_core(eth_receiver_core).x, + sender_device->ethernet_core_from_logical_core(eth_receiver_core).y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}, + {0, + sender_device->ethernet_core_from_logical_core(eth_sender_core).x, + sender_device->ethernet_core_from_logical_core(eth_sender_core).y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}}; + + TT_ASSERT( + (enable_persistent_fabric && edm_termination_infos.size() == 0) || + (!enable_persistent_fabric && edm_termination_infos.size() > 0)); generate_sender_worker_kernels( sender_program, sender_device, worker_core, chip0_worker_fabric_connection, - unicast_send{1}, + unicast_send{2}, // 2 hops because we are looping back to ourselves edm_buffer_size, page_plus_header_size, num_pages_total, @@ -416,30 +467,445 @@ bool RunLoopbackTest( edm_termination_infos); //////////////////////////////////////////////////////////////////////////// - // Build EDMs + // Compile and Execute Application //////////////////////////////////////////////////////////////////////////// - auto local_edm_kernel = - ttnn::ccl::generate_edm_kernel(sender_program, sender_device, chip_0_edm_builder, eth_sender_core, NOC::NOC_0); + std::vector devices = {sender_device}; + if (!enable_persistent_fabric) { + devices.push_back(receiver_device); + } + log_trace(tt::LogTest, "{} programs, {} devices", programs.size(), devices.size()); + run_programs( + programs, + devices, + subdevice_managers.has_value() ? subdevice_managers.value().worker_subdevice_id + : std::optional>{std::nullopt}); + log_info(tt::LogTest, "Reading back outputs"); - auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel( - receiver_program, receiver_device, chip_1_edm_builder, eth_receiver_core, NOC::NOC_0); + bool pass = true; + constexpr bool enable_check = true; + if constexpr (enable_check) { + pass &= run_output_check(all_zeros, inputs, local_output_buffer) == Correctness::Correct; + } + return pass; +} + +void generate_multi_input_test_worker_reader_kernel( + Program& program, + std::vector const& cb_indices, + std::vector const& tensors, + Device* device, + uint32_t page_size, + CoreRangeSet const& worker_core_range, + uint32_t num_pages_per_edm_buffer, + ttnn::ccl::v2::TensorSlice const& in0_command_tensor_slice, + ttnn::ccl::v2::TensorSlice const& in1_command_tensor_slice, + ttnn::ccl::cmd::CclCommandCode command_type, + DataMovementConfig const& datamovement_kernel_config, + std::optional const& chip0_worker_forward_fabric_connection, + std::optional const& chip0_worker_backward_fabric_connection, + std::optional const& optional_teardown_sequence, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + bool fabric_enabled = std::holds_alternative(dest_args) || + std::holds_alternative(dest_args); + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd; + log_trace( + tt::LogTest, + "Generating multi input test worker reader kernel for command type: {}", + static_cast(command_type)); + + TT_FATAL( + command_type == ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB || + command_type == ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR, + "Unsupported tensor IO command type"); + + TT_ASSERT(tensors.size() > 0 && tensors.size() <= 2); + TT_ASSERT(cb_indices.size() == tensors.size()); + + auto sender_worker_reader_kernel = ttnn::ccl::worker_detail::generate_multi_command_stream_kernel_ct_args( + program, cb_indices, tensors, worker_core_range, datamovement_kernel_config); + + std::vector ccl_command_stream0; + std::vector ccl_command_stream1; + + // Add the main tensor slice commands + if (command_type == ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB) { + log_trace(tt::LogTest, "Adding local noc read"); + if (fabric_enabled) { + ccl_command_stream0.push_back( + read_tensor_slice_to_cb_for_eventual_fabric_write(in0_command_tensor_slice, cb_indices.at(0))); + ccl_command_stream1.push_back( + read_tensor_slice_to_cb_for_eventual_fabric_write(in1_command_tensor_slice, cb_indices.at(1))); + } else { + ccl_command_stream0.push_back(read_tensor_slice_to_cb(in0_command_tensor_slice, cb_indices.at(0))); + ccl_command_stream1.push_back(read_tensor_slice_to_cb(in1_command_tensor_slice, cb_indices.at(1))); + } + } else { + if (std::holds_alternative(dest_args)) { + log_trace(tt::LogTest, "Adding local noc write"); + ccl_command_stream0.push_back(local_write_cb_to_tensor_slice(in0_command_tensor_slice, cb_indices.at(0))); + ccl_command_stream1.push_back(local_write_cb_to_tensor_slice(in1_command_tensor_slice, cb_indices.at(1))); + } else { + if (std::holds_alternative(dest_args)) { + log_trace( + tt::LogTest, + "Adding fabric unicast write command. Distance: {}. Forward: {}", + std::get(dest_args).distance_in_hops, + std::get(dest_args).is_forward_direction); + ccl_command_stream0.push_back(fabric_write_cb_to_tensor_slice( + in0_command_tensor_slice, + cb_indices.at(0), + UnicastCommandDestArgs{std::get(dest_args)})); + ccl_command_stream1.push_back(fabric_write_cb_to_tensor_slice( + in1_command_tensor_slice, + cb_indices.at(1), + UnicastCommandDestArgs{std::get(dest_args)})); + } else if (std::holds_alternative(dest_args)) { + log_trace( + tt::LogTest, + "Adding fabric multicast write command. Forward: {}. Backward: {}", + std::get(dest_args).num_targets_forward_direction, + std::get(dest_args).num_targets_backward_direction); + ccl_command_stream0.push_back(fabric_write_cb_to_tensor_slice( + in0_command_tensor_slice, + cb_indices.at(0), + MulticastCommandDestArgs{std::get(dest_args)})); + ccl_command_stream1.push_back(fabric_write_cb_to_tensor_slice( + in1_command_tensor_slice, + cb_indices.at(1), + MulticastCommandDestArgs{std::get(dest_args)})); + } else { + log_trace(tt::LogTest, "WTF? Should have been caught earlier"); + TT_FATAL(true, "Unsupported dest args type"); + } + } + } + + // Now, because we are bringing up/tearing down the fabric per op with this program, we need to queue up the + // commands to teardown the fabric + // We need to make sure only one of the command streams is sending out the termination signals, and we + // need to make sure it only does that after the other command stream is done - so what we do is + // make the termination command stream wait for a semaphore value (locally) that the other command stream + // will set after it has finished. + if (optional_teardown_sequence.has_value()) { + std::ranges::copy(optional_teardown_sequence.value(), std::back_inserter(ccl_command_stream0)); + } + + ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( + program, + sender_worker_reader_kernel, + tensors, + {page_size, page_size}, + device, + num_pages_per_edm_buffer, // TODO: get from fabric + worker_core_range, + ccl_command_stream0, + ccl_command_stream1, + chip0_worker_forward_fabric_connection, + chip0_worker_backward_fabric_connection); +} + +void generate_multi_input_test_worker_kernels_for_local_tensor_write( + Program& program, + Device* device, + Tensor& input_tensor0, + Tensor& input_tensor1, + Tensor& output_tensor0, + Tensor& output_tensor1, + size_t first_cb_index, + size_t second_cb_index, + CoreCoord const& worker_core, + const uint32_t page_plus_header_size, + const uint32_t num_pages_per_edm_buffer, + ttnn::ccl::v2::TensorSlice const& in0_tensor_slice, + ttnn::ccl::v2::TensorSlice const& in1_tensor_slice, + ttnn::ccl::v2::TensorSlice const& out0_tensor_slice, + ttnn::ccl::v2::TensorSlice const& out1_tensor_slice, + std::optional const& optional_teardown_sequence, + std::optional& chip0_worker_forward_fabric_connection, + std::optional& chip0_worker_backward_fabric_connection, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + // Just want a dummy DF + tt::DataFormat df = (page_plus_header_size - PACKET_HEADER_SIZE_BYTES) == 1024 ? tt::DataFormat::Bfp8 + : (page_plus_header_size - PACKET_HEADER_SIZE_BYTES) == 2048 ? tt::DataFormat::Float16 + : tt::DataFormat::Float32; + + { + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_plus_header_size, {{first_cb_index, df}}) + .set_page_size(first_cb_index, page_plus_header_size); + CBHandle cb0 = CreateCircularBuffer(program, worker_core, cb_src0_config); + } + { + tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig( + 2 * num_pages_per_edm_buffer * page_plus_header_size, {{second_cb_index, df}}) + .set_page_size(second_cb_index, page_plus_header_size); + CBHandle cb1 = CreateCircularBuffer(program, worker_core, cb_src1_config); + } + + generate_multi_input_test_worker_reader_kernel( + program, + {first_cb_index, second_cb_index}, + {&input_tensor0, &input_tensor1}, + device, + page_plus_header_size - PACKET_HEADER_SIZE_BYTES, + CoreRangeSet({CoreRange(worker_core)}), + num_pages_per_edm_buffer, + in0_tensor_slice, + in1_tensor_slice, + ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB, + tt_metal::ReaderDataMovementConfig{}, + std::nullopt, + std::nullopt, + std::nullopt, + dest_args); + + generate_multi_input_test_worker_reader_kernel( + program, + {first_cb_index, second_cb_index}, + {&output_tensor0, &output_tensor1}, + device, + page_plus_header_size - PACKET_HEADER_SIZE_BYTES, + CoreRangeSet({CoreRange(worker_core)}), + num_pages_per_edm_buffer, + out0_tensor_slice, + out1_tensor_slice, + ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR, + tt_metal::WriterDataMovementConfig{}, + chip0_worker_forward_fabric_connection, + chip0_worker_backward_fabric_connection, + optional_teardown_sequence, + dest_args); +} + +bool RunLocalTestWithMultiInputReaders( + std::vector const& devices, + std::vector& programs, + std::optional& line_fabric, + + Tensor& input_tensor0, + Tensor& input_tensor1, + Tensor& output_tensor0, + Tensor& output_tensor1, + std::vector input0_tensors, // Device + std::vector input1_tensors, // Device + std::vector output0_tensors, // Device + std::vector output1_tensors, // Device + + ttnn::ccl::v2::TensorSlice const& in0_tensor_slice, + ttnn::ccl::v2::TensorSlice const& in1_tensor_slice, + ttnn::ccl::v2::TensorSlice const& out0_tensor_slice, + ttnn::ccl::v2::TensorSlice const& out1_tensor_slice, + + const uint32_t page_size, + TwoInputReaderKernelWriteMode test_mode, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args, + std::optional& subdevice_managers, + bool enable_persistent_fabric) { + const bool fabric_enabled = test_mode != TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK; + tt_metal::Device* device = devices.at(0); + for (size_t i = 0; i < devices.size(); i++) { + log_info(tt::LogTest, "Device[{}] ID: {}", i, devices.at(i)->id()); + } + auto program_ptrs = std::vector(); + program_ptrs.reserve(devices.size()); + std::ranges::transform(programs, std::back_inserter(program_ptrs), [](auto& p) { return &p; }); + + size_t output_tensor_dest_device_index = 0; + if (fabric_enabled) { + if (std::holds_alternative(dest_args)) { + log_info( + tt::LogTest, + "Unicast command dest args. Distance in hops: {}", + std::get(dest_args).distance_in_hops); + output_tensor_dest_device_index = + std::get(dest_args).distance_in_hops; + TT_ASSERT(output_tensor_dest_device_index != 0, "Output tensor destination device index must be non-zero"); + TT_ASSERT(test_mode == TwoInputReaderKernelWriteMode::FABRIC_UNICAST); + } else if (std::holds_alternative(dest_args)) { + log_info( + tt::LogTest, + "Multicast command dest args. Number of targets forward direction: {}", + std::get(dest_args).num_targets_forward_direction); + output_tensor_dest_device_index = + std::get(dest_args).num_targets_forward_direction; + TT_ASSERT(output_tensor_dest_device_index != 0, "Output tensor destination device index must be non-zero"); + TT_ASSERT(test_mode == TwoInputReaderKernelWriteMode::FABRIC_MULTICAST); + } + } else { + log_info(tt::LogTest, "No fabric enabled"); + TT_ASSERT( + std::holds_alternative(dest_args), "Local command dest args expected"); + } + + std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader); + + auto first_cb_index = tt::CB::c_in0; + auto second_cb_index = tt::CB::c_in1; + + auto output_tensor_dest_device = devices.at(output_tensor_dest_device_index); + TT_ASSERT(input_tensor0.get_logical_shape()[-2] != 1); + + bool is_fabric_mcast = std::holds_alternative(dest_args); + + auto input_tensor0_device = input0_tensors.at(0); + auto input_tensor1_device = input1_tensors.at(0); + auto output_tensor0_device = output0_tensors.at(output_tensor_dest_device_index); + auto output_tensor1_device = output1_tensors.at(output_tensor_dest_device_index); + log_info(tt::LogTest, "input_tensor0_device->address(): {}", input_tensor0_device.buffer()->address()); + log_info(tt::LogTest, "input_tensor1_device->address(): {}", input_tensor1_device.buffer()->address()); + log_info( + tt::LogTest, + "output_tensor0_device->address(): {} on device {}", + output_tensor0_device.buffer()->address(), + output_tensor_dest_device->id()); + log_info( + tt::LogTest, + "output_tensor1_device->address(): {} on device {}", + output_tensor1_device.buffer()->address(), + output_tensor_dest_device->id()); + + //////////////////////////////////////////////////////////////////////////// + // Build Workers + //////////////////////////////////////////////////////////////////////////// + auto const& worker_core = CoreCoord(0, 0); + + const size_t num_pages_per_edm_buffer = 2; + + std::optional chip0_worker_forward_fabric_connection = + fabric_enabled ? line_fabric->uniquely_connect_worker(devices[0], ttnn::ccl::EdmLineFabricOpInterface::FORWARD) + : std::optional{std::nullopt}; + + // always at start of line for now + std::optional> edm_termination_infos = + (!fabric_enabled || enable_persistent_fabric) + ? std::optional>{std::nullopt} + : line_fabric->generate_ordered_termination_info_farthest_to_nearest(); + std::optional chip0_worker_backward_fabric_connection = std::nullopt; + + std::optional sync_details; + std::optional teardown_worker_core; + std::optional teardown_command_stream; + if (fabric_enabled && !enable_persistent_fabric) { + teardown_worker_core = worker_core; + + sync_details = ttnn::ccl::SyncModeSpec{}; + sync_details->core = teardown_worker_core.value(); + sync_details->add_signal(tt::tt_metal::CreateSemaphore(programs.at(0), teardown_worker_core.value(), 0), 1); + teardown_command_stream = {ttnn::ccl::cmd::uops::local_core_semaphore_inc(sync_details->sem_ids.at(0), 1)}; + TT_FATAL(edm_termination_infos.has_value(), "EDM termination infos must be set if fabric is enabled"); + ttnn::ccl::cmd::CclHostLowLevelCommandSequence teardown_commands; + + teardown_commands = ttnn::ccl::worker_detail::build_ccl_cmd_proc_teardown_commands( + programs.at(0), + device, + nullptr, // forward device - in this test, we have a single source doing all teardown + devices.size(), + 0, + edm_termination_infos.value(), + sync_details.value(), + line_fabric.value()); + std::ranges::copy(teardown_commands, std::back_inserter(teardown_command_stream.value())); + } + + generate_multi_input_test_worker_kernels_for_local_tensor_write( + programs.at(0), + device, + input_tensor0_device, + input_tensor1_device, + output_tensor0_device, + output_tensor1_device, + first_cb_index, + second_cb_index, + worker_core, + page_plus_header_size, + num_pages_per_edm_buffer, + in0_tensor_slice, + in1_tensor_slice, + out0_tensor_slice, + out1_tensor_slice, + teardown_command_stream, + chip0_worker_forward_fabric_connection, + chip0_worker_backward_fabric_connection, + dest_args); + + if (!enable_persistent_fabric) { + log_info(tt::LogTest, "Building EDM kernels"); + line_fabric->build_kernels(); + } + + log_info(tt::LogTest, "persistent_fabric: {}", enable_persistent_fabric); + log_info(tt::LogTest, "subdevice_managers.has_value(): {}", subdevice_managers.has_value()); //////////////////////////////////////////////////////////////////////////// // Compile and Execute Application //////////////////////////////////////////////////////////////////////////// - run_programs(programs, {sender_device, receiver_device}); - log_info(tt::LogTest, "Reading back outputs"); + run_programs( + programs, + enable_persistent_fabric ? std::vector{devices[0]} : devices, + subdevice_managers.has_value() ? subdevice_managers.value().worker_subdevice_id + : std::optional>{std::nullopt} + + ); + log_info(tt::LogTest, "Finished"); bool pass = true; constexpr bool enable_check = true; if constexpr (enable_check) { - pass &= run_output_check(all_zeros, inputs, local_output_buffer) == Correctness::Correct; + log_info(tt::LogTest, "Reading back outputs"); + std::vector out_tensor_worker_subdevice_id = + subdevice_managers.has_value() ? std::vector{subdevice_managers->worker_subdevice_id.at( + devices.at(output_tensor_dest_device_index)->id())} + : std::vector{}; + auto output0_cpu = output_tensor0_device.cpu(true, ttnn::DefaultQueueId, out_tensor_worker_subdevice_id); + auto output1_cpu = output_tensor1_device.cpu(true, ttnn::DefaultQueueId, out_tensor_worker_subdevice_id); + + auto in_tensor_worker_subdevice_id = + subdevice_managers.has_value() + ? std::vector{subdevice_managers->worker_subdevice_id.at(devices.at(0)->id())} + : std::vector{}; + auto in0_tensor_copyback_cpu = + input_tensor0_device.cpu(true, ttnn::DefaultQueueId, in_tensor_worker_subdevice_id); + auto in1_tensor_copyback_cpu = + input_tensor1_device.cpu(true, ttnn::DefaultQueueId, in_tensor_worker_subdevice_id); + + auto in0_tensor_copyback = tt::tt_metal::owned_buffer::get_as(in0_tensor_copyback_cpu); + auto in1_tensor_copyback = tt::tt_metal::owned_buffer::get_as(in1_tensor_copyback_cpu); + + auto in0_tensor_data = tt::tt_metal::owned_buffer::get_as(input_tensor0); + auto in1_tensor_data = tt::tt_metal::owned_buffer::get_as(input_tensor1); + auto out0_tensor_data = tt::tt_metal::owned_buffer::get_as(output0_cpu); + auto out1_tensor_data = tt::tt_metal::owned_buffer::get_as(output1_cpu); + + bool input0_copyback_check_passed = + run_output_check(in0_tensor_data, in0_tensor_copyback) == Correctness::Correct; + bool input1_copyback_check_passed = + run_output_check(in1_tensor_data, in1_tensor_copyback) == Correctness::Correct; + TT_FATAL(input0_copyback_check_passed, "Input 0 copyback check failed"); + TT_FATAL(input1_copyback_check_passed, "Input 1 copyback check failed"); + + log_info(tt::LogTest, "Comparing outputs"); + pass &= run_output_check(in0_tensor_data, out0_tensor_data) == Correctness::Correct; + if (pass) { + log_info(tt::LogTest, "Output check passed for output 0"); + } else { + log_error(tt::LogTest, "Output check failed for output 0"); + } + pass &= run_output_check(in1_tensor_data, out1_tensor_data) == Correctness::Correct; + if (pass) { + log_info(tt::LogTest, "Output check passed for output 1"); + } else { + log_error(tt::LogTest, "Output check failed for output 1"); + } } + return pass; } bool RunLineFabricTest( std::vector devices, + std::vector& programs, const size_t mcast_first_chip, const size_t mcast_last_chip, @@ -447,19 +913,20 @@ bool RunLineFabricTest( const uint32_t page_size, const uint32_t num_pages_total, bool src_is_dram, - bool dest_is_dram) { + bool dest_is_dram, + + std::optional& subdevice_managers, + ttnn::ccl::EdmLineFabricOpInterface& line_fabric, + bool enable_persistent_fabric) { std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader); std::size_t tensor_size_bytes = num_pages_total * page_size; static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES; const size_t local_chip_id = 0; const size_t remote_chip_id = 1; - auto programs = std::vector(devices.size()); auto program_ptrs = std::vector(devices.size()); std::transform(programs.begin(), programs.end(), program_ptrs.begin(), [](auto& program) { return &program; }); - auto line_fabric = ttnn::ccl::EdmLineFabricOpInterface(devices, program_ptrs, 1); - std::vector worker_cores = {CoreCoord(0, 0)}; // Generate inputs @@ -480,8 +947,12 @@ bool RunLineFabricTest( std::vector all_zeros(inputs.size(), 0); // output buffers - TT_ASSERT(mcast_first_chip <= mcast_last_chip, "mcast_first_chip must be less than or equal to mcast_last_chip"); - TT_ASSERT(mcast_last_chip < devices.size(), "mcast_last_chip must be less than the number of devices"); + TT_ASSERT( + enable_persistent_fabric || mcast_first_chip <= mcast_last_chip, + "mcast_first_chip must be less than or equal to mcast_last_chip"); + TT_ASSERT( + enable_persistent_fabric || mcast_last_chip < devices.size(), + "mcast_last_chip must be less than the number of devices"); std::vector> output_buffers; output_buffers.reserve(devices.size()); for (size_t i = 0; i < devices.size(); i++) { @@ -516,7 +987,9 @@ bool RunLineFabricTest( auto const& worker_core = worker_cores.at(0); log_trace(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y); - const auto edm_termination_infos = line_fabric.generate_ordered_termination_info_farthest_to_nearest(); + const auto edm_termination_infos = enable_persistent_fabric + ? std::vector{} + : line_fabric.generate_ordered_termination_info_farthest_to_nearest(); auto chip0_worker_fabric_connection = line_fabric.uniquely_connect_worker(devices[0], ttnn::ccl::EdmLineFabricOpInterface::FORWARD); @@ -528,7 +1001,7 @@ bool RunLineFabricTest( devices[0], worker_core, chip0_worker_fabric_connection, - mcast_send{mcast_first_chip - 1, mcast_last_chip - mcast_first_chip}, + mcast_send{mcast_first_chip, mcast_last_chip - mcast_first_chip + 1}, edm_buffer_size, page_plus_header_size, num_pages_total, @@ -545,13 +1018,19 @@ bool RunLineFabricTest( //////////////////////////////////////////////////////////////////////////// // Build EDM Kernels //////////////////////////////////////////////////////////////////////////// - line_fabric.build_kernels(); + if (!enable_persistent_fabric) { + line_fabric.build_kernels(); + } //////////////////////////////////////////////////////////////////////////// // Compile and Execute Application //////////////////////////////////////////////////////////////////////////// - run_programs(programs, devices); + run_programs( + programs, + devices, + subdevice_managers.has_value() ? subdevice_managers.value().worker_subdevice_id + : std::optional>{std::nullopt}); log_info(tt::LogTest, "Reading back outputs"); bool pass = true; @@ -572,6 +1051,66 @@ bool RunLineFabricTest( return pass; } +void persistent_fabric_teardown_sequence( + std::vector const& devices, + std::optional& subdevice_managers, + ttnn::ccl::EdmLineFabricOpInterface& line_fabric, + tt::fabric::TerminationSignal termination_mode = tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE) { + log_info("Tearing down fabric"); + + // Wait for workers to finish + auto d0_worker_subdevice = devices[0]->get_sub_device_ids()[TEST_WORKERS_SUBDEVICE_INDEX]; + tt_metal::Finish(devices[0]->command_queue(), {subdevice_managers->worker_subdevice_id.at(devices[0]->id())}); + + // Teardown the fabric + line_fabric.teardown_from_host(termination_mode); + + // wait for fabric teardown to finish + std::ranges::for_each(devices, [&](Device* d) { + tt_metal::Finish(d->command_queue(), {subdevice_managers->fabric_subdevice_id.at(d->id())}); + }); +} + +void setup_test_with_persistent_fabric( + std::vector const& devices, + std::vector& programs, + std::optional& subdevice_managers, + std::optional>& fabric_programs, + std::vector& fabric_program_ptrs, + std::optional& line_fabric, + bool enable_persistent_fabric, + std::optional num_links = std::nullopt) { + if (enable_persistent_fabric) { + log_info(tt::LogTest, "Enabling persistent fabric"); + fabric_programs = std::vector(devices.size()); + subdevice_managers = create_subdevices(devices); + std::transform( + fabric_programs->begin(), fabric_programs->end(), std::back_inserter(fabric_program_ptrs), [](auto& p) { + return &p; + }); + } else { + std::transform( + programs.begin(), programs.end(), std::back_inserter(fabric_program_ptrs), [](auto& p) { return &p; }); + } + + line_fabric = ttnn::ccl::EdmLineFabricOpInterface( + devices, fabric_program_ptrs, enable_persistent_fabric, num_links.value_or(1)); + + if (enable_persistent_fabric) { + TT_FATAL(fabric_programs.has_value(), "Fabric programs must be set if fabric is enabled"); + TT_FATAL(devices.size() == fabric_programs->size(), "Number of devices must match number of programs"); + + log_info(tt::LogTest, "Building EDM kernels"); + line_fabric->build_kernels(); + for (size_t i = 0; i < devices.size(); i++) { + tt::tt_metal::detail::CompileProgram(devices[i], fabric_programs->at(i)); + } + for (size_t i = 0; i < devices.size(); i++) { + tt_metal::EnqueueProgram(devices[i]->command_queue(), fabric_programs->at(i), false); + } + } +} + // RESUME HERE AND IMPLEMENT MCAST TEST int TestLineFabricEntrypoint( const size_t mcast_first_chip, @@ -579,7 +1118,8 @@ int TestLineFabricEntrypoint( const uint32_t page_size, const uint32_t num_pages_total, const bool src_is_dram, - const bool dest_is_dram) { + const bool dest_is_dram, + bool enable_persistent_fabric) { // argv[0]: program // argv[1]: buffer_size_bytes // argv[2]: num_loops @@ -587,7 +1127,7 @@ int TestLineFabricEntrypoint( auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); auto num_devices = tt::tt_metal::GetNumAvailableDevices(); if (num_devices < 4) { - log_info("This test can only be run on N300 devices"); + log_info("This test can only be run on T3000 devices"); return 0; } if (arch == tt::ARCH::GRAYSKULL) { @@ -596,32 +1136,59 @@ int TestLineFabricEntrypoint( } T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); // build a line of devices std::vector devices = { - test_fixture.devices_.at(0), - test_fixture.devices_.at(1), - test_fixture.devices_.at(2), - test_fixture.devices_.at(3)}; - - bool success = false; - try { - success = RunLineFabricTest( - devices, - // fabric_hops, - - mcast_first_chip, - mcast_last_chip, - - page_size, - num_pages_total, - src_is_dram, - dest_is_dram); - - } catch (std::exception& e) { - log_error("Caught exception: {}", e.what()); - test_fixture.TearDown(); - return -1; + view.get_device(0, 0), view.get_device(0, 1), view.get_device(0, 2), view.get_device(0, 3)}; + std::vector programs(enable_persistent_fabric ? 1 : devices.size()); + std::optional subdevice_managers = std::nullopt; + std::optional> fabric_programs; + std::vector fabric_program_ptrs; + std::optional line_fabric; + setup_test_with_persistent_fabric( + devices, + programs, + subdevice_managers, + fabric_programs, + fabric_program_ptrs, + line_fabric, + enable_persistent_fabric); + + auto launch_workers = [&](std::vector& _programs) -> bool { + bool success = false; + try { + success = RunLineFabricTest( + enable_persistent_fabric ? std::vector{devices[0]} : devices, + _programs, + // fabric_hops, + + mcast_first_chip, + mcast_last_chip, + + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + + subdevice_managers, + line_fabric.value(), + enable_persistent_fabric); + + } catch (std::exception& e) { + log_error("Caught exception: {}", e.what()); + test_fixture.TearDown(); + return false; + } + return success; + }; + bool success = launch_workers(programs); + + if (enable_persistent_fabric) { + std::vector second_run_programs(1); + success = launch_workers(second_run_programs); + persistent_fabric_teardown_sequence( + devices, subdevice_managers, line_fabric.value(), tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); } test_fixture.TearDown(); @@ -630,15 +1197,20 @@ int TestLineFabricEntrypoint( } int TestLoopbackEntrypoint( - const uint32_t page_size, const uint32_t num_pages_total, const bool src_is_dram, const bool dest_is_dram) { + const uint32_t page_size, + const uint32_t num_pages_total, + const bool src_is_dram, + const bool dest_is_dram, + bool enable_persistent_fabric) { // argv[0]: program // argv[1]: buffer_size_bytes // argv[2]: num_loops + std::optional subdevice_managers = std::nullopt; auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); auto num_devices = tt::tt_metal::GetNumAvailableDevices(); if (num_devices < 4) { - log_info("This test can only be run on N300 devices"); + log_info("This test can only be run on T3000 devices"); return 0; } if (arch == tt::ARCH::GRAYSKULL) { @@ -647,8 +1219,10 @@ int TestLoopbackEntrypoint( } T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); - const auto& device_0 = test_fixture.devices_.at(0); + const auto& device_0 = view.get_device(0, 0); + const auto& device_1 = view.get_device(0, 1); auto const& active_eth_cores = device_0->get_active_ethernet_cores(true); auto eth_sender_core_iter = active_eth_cores.begin(); @@ -662,10 +1236,58 @@ int TestLoopbackEntrypoint( std::tie(device_id, eth_receiver_core) = device_0->get_connected_ethernet_core(*eth_sender_core_iter); eth_sender_core = *eth_sender_core_iter; eth_sender_core_iter++; - } while (device_id != 1); - TT_ASSERT(device_id == 1); - const auto& device_1 = test_fixture.devices_.at(device_id); + } while (device_id != device_1->id()); + TT_ASSERT(device_id == device_1->id()); + // const auto& device_1 = test_fixture.mesh_device_->get_device(device_id); + std::vector programs(enable_persistent_fabric ? 1 : 2); + std::optional> fabric_programs; + auto& sender_program = programs.at(0); + if (enable_persistent_fabric) { + log_info(tt::LogTest, "Enabling persistent fabric"); + fabric_programs = std::vector(2); + subdevice_managers = create_subdevices({device_0, device_1}); + } + + auto& fabric_sender_program = enable_persistent_fabric ? fabric_programs->at(0) : sender_program; + auto& fabric_receiver_program = enable_persistent_fabric ? fabric_programs->at(1) : programs.at(1); + Device* sender_device = device_0; + Device* receiver_device = device_1; + + static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES; + const chip_id_t local_chip_id = 0; + const chip_id_t remote_chip_id = 1; + auto const& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); + auto chip_0_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder::build( + sender_device, + fabric_sender_program, + eth_sender_core, + local_chip_id, + remote_chip_id, + edm_config, + enable_persistent_fabric); + auto chip_1_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder::build( + receiver_device, + fabric_receiver_program, + eth_receiver_core, + remote_chip_id, + local_chip_id, + edm_config, + enable_persistent_fabric); + // Create the loopback connection on the second device + chip_1_edm_builder.connect_to_downstream_edm(chip_1_edm_builder); + auto local_edm_kernel = ttnn::ccl::generate_edm_kernel( + fabric_sender_program, sender_device, chip_0_edm_builder, eth_sender_core, NOC::NOC_0); + auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel( + fabric_receiver_program, receiver_device, chip_1_edm_builder, eth_receiver_core, NOC::NOC_0); + + if (enable_persistent_fabric) { + tt::tt_metal::detail::CompileProgram(sender_device, fabric_sender_program); + tt::tt_metal::detail::CompileProgram(receiver_device, fabric_receiver_program); + tt_metal::EnqueueProgram(sender_device->command_queue(), fabric_sender_program, false); + tt_metal::EnqueueProgram(receiver_device->command_queue(), fabric_receiver_program, false); + } + log_trace(tt::LogTest, "{} programs ", programs.size()); bool success = false; try { success = RunLoopbackTest( @@ -678,18 +1300,202 @@ int TestLoopbackEntrypoint( page_size, num_pages_total, src_is_dram, - dest_is_dram); + dest_is_dram, + programs, + chip_0_edm_builder, + subdevice_managers, + enable_persistent_fabric); } catch (std::exception& e) { log_error("Caught exception: {}", e.what()); test_fixture.TearDown(); return -1; } + if (enable_persistent_fabric) { + // Run the test twice with a single fabric invocation + + std::vector second_programs(1); + try { + success = RunLoopbackTest( + device_0, + device_1, + + eth_sender_core, + eth_receiver_core, + + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + second_programs, + chip_0_edm_builder, + subdevice_managers, + enable_persistent_fabric); + } catch (std::exception& e) { + log_error("Caught exception: {}", e.what()); + test_fixture.TearDown(); + return -1; + } + // Wait for worker programs to finish + + auto d0_worker_subdevice = device_0->get_sub_device_ids()[TEST_WORKERS_SUBDEVICE_INDEX]; + auto d1_worker_subdevice = device_1->get_sub_device_ids()[TEST_WORKERS_SUBDEVICE_INDEX]; + auto d0_fabric_subdevice = device_0->get_sub_device_ids()[TEST_EDM_FABRIC_SUBDEVICE_INDEX]; + auto d1_fabric_subdevice = device_1->get_sub_device_ids()[TEST_EDM_FABRIC_SUBDEVICE_INDEX]; + // Teardown the fabric + tt_metal::Finish(sender_device->command_queue(), {d0_worker_subdevice}); + // tt_metal::Finish(receiver_device->command_queue(), {d1_worker_subdevice}); + + // Notify fabric of teardown + chip_1_edm_builder.teardown_from_host(receiver_device); + chip_0_edm_builder.teardown_from_host(sender_device); + + // wait for fabric finish + tt_metal::Finish(sender_device->command_queue(), {d0_fabric_subdevice}); + tt_metal::Finish(receiver_device->command_queue(), {d1_fabric_subdevice}); + } + test_fixture.TearDown(); return success ? 0 : -1; } +bool TestMultiInputReaderKernel( + size_t fabric_num_devices, + Tensor& input_tensor0, + MemoryConfig const& input_tensor0_mem_config, + Tensor& input_tensor1, + MemoryConfig const& input_tensor1_mem_config, + Tensor& output_tensor0, + MemoryConfig const& output_tensor0_mem_config, + Tensor& output_tensor1, + MemoryConfig const& output_tensor1_mem_config, + + ttnn::ccl::v2::TensorSlice const& in0_tensor_slice, + ttnn::ccl::v2::TensorSlice const& in1_tensor_slice, + ttnn::ccl::v2::TensorSlice const& out0_tensor_slice, + ttnn::ccl::v2::TensorSlice const& out1_tensor_slice, + + const uint32_t page_size, + + TwoInputReaderKernelWriteMode test_mode, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args, + bool enable_persistent_fabric) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 4) { + log_info("This test can only be run on T3000 devices"); + return true; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return true; + } + T3000TestDevice test_fixture; + + TT_FATAL( + !enable_persistent_fabric || test_mode != TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK, + "Test configuration issue. Set local writeback mode with persistent fabric"); + + auto view = test_fixture.mesh_device_->get_view(); + + std::vector devices; + devices.reserve(fabric_num_devices); + for (size_t i = 0; i < fabric_num_devices; i++) { + devices.push_back(view.get_device(0, i)); + } + + std::vector programs(enable_persistent_fabric ? 1 : devices.size()); + std::optional subdevice_managers = std::nullopt; + std::optional> fabric_programs; + std::vector fabric_program_ptrs; + std::optional line_fabric; + setup_test_with_persistent_fabric( + devices, + programs, + subdevice_managers, + fabric_programs, + fabric_program_ptrs, + line_fabric, + enable_persistent_fabric); + + std::vector input0_tensors_device; + std::vector input1_tensors_device; + std::vector output0_tensors_device; + std::vector output1_tensors_device; + + // All this garbage is to make sure the test sets up buffer addresses correctly so we can safely + // multicast to a consistent destination address + for (size_t i = 0; i < devices.size(); i++) { + std::vector subdevice_target = + subdevice_managers.has_value() + ? std::vector{subdevice_managers->worker_subdevice_id.at(devices[i]->id())} + : std::vector{}; + input0_tensors_device.push_back( + input_tensor0.to(devices.at(i), input_tensor0_mem_config, ttnn::DefaultQueueId, subdevice_target)); + input1_tensors_device.push_back( + input_tensor1.to(devices.at(i), input_tensor1_mem_config, ttnn::DefaultQueueId, subdevice_target)); + output0_tensors_device.push_back( + output_tensor0.to(devices.at(i), output_tensor0_mem_config, ttnn::DefaultQueueId, subdevice_target)); + output1_tensors_device.push_back( + output_tensor1.to(devices.at(i), output_tensor1_mem_config, ttnn::DefaultQueueId, subdevice_target)); + } + TT_FATAL( + !enable_persistent_fabric || subdevice_managers.has_value(), + "Subdevice managers must be set if fabric is enabled"); + auto launch_ccl_command_interpreter_workers = [&](std::vector& _programs) { + return RunLocalTestWithMultiInputReaders( + devices, + _programs, + line_fabric, + + input_tensor0, + input_tensor1, + output_tensor0, + output_tensor1, + + input0_tensors_device, + input1_tensors_device, + output0_tensors_device, + output1_tensors_device, + + in0_tensor_slice, + in1_tensor_slice, + out0_tensor_slice, + out1_tensor_slice, + + page_size, + test_mode, + dest_args, + subdevice_managers, + enable_persistent_fabric); + }; + + auto pass = launch_ccl_command_interpreter_workers(programs); + if (enable_persistent_fabric) { + std::vector second_run_programs(1); + // It looks suspicious that we are dropping the first result but there are two reasons we do this + // 1) We really only care that we can run back to back safely + // 2) The first run will end up racing with host and copy-back because there is no + // receiver on the destination that can signal to us when we are done. We need to add this + // to the test to make it more robust but that is future work + pass = launch_ccl_command_interpreter_workers(second_run_programs); + pass = true; + + // Due to race between host and device some packets are in flight by the time host sends shutdown signals so + // some get shutdown in between any packets in the pipeline. This can only be fixed by having a "drainer" op to + // make sure it receives all writes before exiting + persistent_fabric_teardown_sequence( + devices, subdevice_managers, line_fabric.value(), tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); + + log_info(tt::LogTest, "Finished"); + for (auto d : devices) { + tt_metal::Synchronize(d, ttnn::DefaultQueueId); + } + } + return pass; +} + //////////////////////////////////////////////////////////////////// /// MESSAGE COUNT TERMINATION MODE //////////////////////////////////////////////////////////////////// @@ -700,7 +1506,7 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_SingleMessage) { const bool src_is_dram = true; const bool dest_is_dram = true; - auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, false); ASSERT_EQ(result, 0); } @@ -711,7 +1517,7 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_2_messages) { const bool src_is_dram = true; const bool dest_is_dram = true; - auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, false); ASSERT_EQ(result, 0); } // Will wrapp sender but not receiver buffers @@ -721,7 +1527,7 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_10_messages) { const bool src_is_dram = true; const bool dest_is_dram = true; - auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, false); ASSERT_EQ(result, 0); } @@ -732,21 +1538,108 @@ TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_20_messages) { const bool src_is_dram = true; const bool dest_is_dram = true; - auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, false); ASSERT_EQ(result, 0); } TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers) { const uint32_t page_size = 2048; - const uint32_t num_pages_total = 100000; + const uint32_t num_pages_total = 10000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, false); + ASSERT_EQ(result, 0); +} + +// ------------------------- +// Persistent Fabric +// ------------------------- + +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_SingleMessage_PersistentFabric) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 1; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, true); + ASSERT_EQ(result, 0); +} + +// Will wrapp sender but not receiver buffers +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_2_messages_PersistentFabric) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 2; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, true); + ASSERT_EQ(result, 0); +} +// Will wrapp sender but not receiver buffers +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_10_messages_PersistentFabric) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 10; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, true); + ASSERT_EQ(result, 0); +} + +// Will wrapp sender and receiver buffers +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_20_messages_PersistentFabric) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 20; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, true); + ASSERT_EQ(result, 0); +} + +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_PersistentFabric) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 10000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram, true); + ASSERT_EQ(result, 0); +} + +//////////////////////////////// + +TEST(WorkerFabricEdmDatapath, LineFabricMcast_SingleMessage_SingleSource) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 1; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const size_t mcast_first_chip = 1; + const size_t mcast_last_chip = 3; + + auto result = TestLineFabricEntrypoint( + mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram, false); + + ASSERT_EQ(result, 0); +} + +// Non-functional on harvested parts. Needs testing on unharvested parts. +TEST(WorkerFabricEdmDatapath, LineFabricMcast_ManyMessages_SingleSource) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 10000; const bool src_is_dram = true; const bool dest_is_dram = true; + const size_t mcast_first_chip = 1; + const size_t mcast_last_chip = 3; + + auto result = TestLineFabricEntrypoint( + mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram, false); - auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); ASSERT_EQ(result, 0); } -TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast_SingleMessage_SingleSource) { +TEST(WorkerFabricEdmDatapath, LineFabricMcast_SingleMessage_SingleSource_PersistentFabric) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 1; const bool src_is_dram = true; @@ -755,13 +1648,13 @@ TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast_SingleMessage_SingleSourc const size_t mcast_last_chip = 3; auto result = TestLineFabricEntrypoint( - mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram); + mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram, true); ASSERT_EQ(result, 0); } // Non-functional on harvested parts. Needs testing on unharvested parts. -TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast_ManyMessages_SingleSource) { +TEST(WorkerFabricEdmDatapath, LineFabricMcast_ManyMessages_SingleSource_PersistentFabric) { const uint32_t page_size = 2048; const uint32_t num_pages_total = 10000; const bool src_is_dram = true; @@ -770,9 +1663,1391 @@ TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast_ManyMessages_SingleSource const size_t mcast_last_chip = 3; auto result = TestLineFabricEntrypoint( - mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram); + mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram, true); ASSERT_EQ(result, 0); } -// EnablePersistentKernelCache +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" + +//////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////// +//// LOCAL CHIP TENSOR READ?WRITE (2 INPUT) +//////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////// + +ttnn::ccl::Shape4D shape_to_shape_in_tiles(ttnn::Shape const& shape) { + auto logical_shape = shape.logical_shape(); + logical_shape[-2] /= tt::constants::TILE_HEIGHT; + logical_shape[-1] /= tt::constants::TILE_WIDTH; + EXPECT_TRUE(logical_shape.size() == 4); + ttnn::ccl::Shape4D shape_in_tiles = { + logical_shape[0], logical_shape[1], logical_shape[2], logical_shape[3]}; + return shape_in_tiles; +} + +bool RunMultiInputReaderTestPropagateFullTensorIn( + ttnn::Shape const& tensor_shape, + Layout const& layout, + MemoryConfig const& in0_memory_config, + MemoryConfig const& in1_memory_config, + MemoryConfig const& out0_memory_config, + MemoryConfig const& out1_memory_config, + TwoInputReaderKernelWriteMode test_writeback_mode) { + auto logical_shape = tensor_shape.logical_shape(); + auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies()); + Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor output_tensor0 = ttnn::ones(tensor_shape, DataType::UINT32, layout).reshape(tensor_shape); + Tensor output_tensor1 = ttnn::ones(tensor_shape, DataType::UINT32, layout).reshape(tensor_shape); + input_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config))); + input_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in1_memory_config))); + output_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out0_memory_config))); + output_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out1_memory_config))); + + size_t page_size = tile_size(DataFormat::RawUInt32); + + ttnn::ccl::Shape4D tensor_shape_in_pages = shape_to_shape_in_tiles(tensor_shape); + ttnn::ccl::Shape4D tensor_slice_shape_in_pages = tensor_shape_in_pages; + ttnn::ccl::Shape4D tensor_slice_offset = {0, 0, 0, 0}; + ttnn::ccl::Shape4D worker_slice_shape = tensor_shape_in_pages; + ttnn::ccl::Shape4D worker_slice_offset = {0, 0, 0, 0}; + + ttnn::ccl::v2::TensorSlice tensor_slice{ + tensor_shape_in_pages, + tensor_slice_shape_in_pages, + tensor_slice_offset, + worker_slice_shape, + worker_slice_offset}; + + auto const in0_tensor_slice = tensor_slice; + auto const in1_tensor_slice = tensor_slice; + auto const out0_tensor_slice = tensor_slice; + auto const out1_tensor_slice = tensor_slice; + + auto pass = TestMultiInputReaderKernel( + 1, + input_tensor0, + in0_memory_config, + input_tensor1, + in1_memory_config, + output_tensor0, + out0_memory_config, + output_tensor1, + out1_memory_config, + + in0_tensor_slice, + in1_tensor_slice, + out0_tensor_slice, + out1_tensor_slice, + + page_size, + test_writeback_mode, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs{}, + false); + + return pass; +} + +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_SinglePageTile) { + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + ttnn::Shape{1, 1, 32, 32}, + Layout::TILE, + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} + +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0) { + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + ttnn::Shape{1, 1, 32, 64}, + Layout::TILE, + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} + +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded) { + ttnn::Shape tensor_shape = {1, 1, 32, 64}; + auto logical_shape = tensor_shape.logical_shape(); + auto mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{0, 0}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3]}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + mem_config, + mem_config, + mem_config, + mem_config, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded1) { + ttnn::Shape tensor_shape = {1, 1, 32, 128}; + auto logical_shape = tensor_shape.logical_shape(); + auto mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{0, 0}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3]}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + mem_config, + mem_config, + mem_config, + mem_config, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded2) { + ttnn::Shape tensor_shape = {1, 1, 32, 128}; + auto logical_shape = tensor_shape.logical_shape(); + auto mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{3, 0}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3] / 4}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + mem_config, + mem_config, + mem_config, + mem_config, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded3) { + ttnn::Shape tensor_shape = {1, 1, 32, 8192}; + auto logical_shape = tensor_shape.logical_shape(); + size_t ncores_x = 8; + size_t ncores_y = 4; + auto mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{ncores_x - 1, ncores_y - 1}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3] / (ncores_x * ncores_y)}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + mem_config, + mem_config, + mem_config, + mem_config, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded4) { + ttnn::Shape tensor_shape = {1, 1, 32, 1024}; + auto logical_shape = tensor_shape.logical_shape(); + size_t ncores_x = 8; + size_t ncores_y = 4; + auto mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{ncores_x - 1, ncores_y - 1}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3] / (ncores_x * ncores_y)}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + mem_config, + mem_config, + mem_config, + mem_config, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} + +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded_WithReshard0) { + ttnn::Shape tensor_shape = {1, 1, 32, 128}; + auto logical_shape = tensor_shape.logical_shape(); + Layout const layout = Layout::TILE; + auto input_mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{0, 0}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3]}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto output_mem_config = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{3, 0}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], logical_shape[3] / 4}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + input_mem_config, + input_mem_config, + output_mem_config, + output_mem_config, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} + +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage0_Sharded_WithReshard0_UniquePerStream) { + ttnn::Shape tensor_shape = {1, 1, 32, 128}; + auto logical_shape = tensor_shape.logical_shape(); + Layout const layout = Layout::TILE; + size_t in_shard_grid_x = 1; + size_t in_shard_grid_y = 1; + size_t out_shard_grid_x = 4; + size_t out_shard_grid_y = 1; + auto mem_config0 = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{ + std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{in_shard_grid_x - 1, in_shard_grid_y - 1}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], + logical_shape[3] / (in_shard_grid_x * in_shard_grid_y)}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto mem_config1 = MemoryConfig( + TensorMemoryLayout::WIDTH_SHARDED, + BufferType::L1, + ShardSpec( + CoreRangeSet{ + std::set{CoreRange{CoreCoord{0, 0}, CoreCoord{out_shard_grid_x - 1, out_shard_grid_y - 1}}}}, + {logical_shape[0] * logical_shape[1] * logical_shape[2], + logical_shape[3] / (out_shard_grid_x * out_shard_grid_y)}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL)); + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + mem_config0, + mem_config1, + mem_config1, + mem_config0, + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} + +// Copying even slightly large tensors exposes issues in underlying tensor code +// that isn't under test here +TEST(WorkerCclCommandProcessingKernelLocalMode, MultiInputReader_MultiPage1) { + ttnn::Shape tensor_shape = {1, 1, 256, 256}; + auto pass = RunMultiInputReaderTestPropagateFullTensorIn( + tensor_shape, + Layout::TILE, + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + TwoInputReaderKernelWriteMode::LOCAL_WRITEBACK); + ASSERT_TRUE(pass); +} + +// TODO: update the test infra to be able to properly compare tensors if we are only +// doing a slice of the larger tensor + +// //////////////////////////////////////////////////////////////////// +// //////////////////////////////////////////////////////////////////// +// //// FABRIC UNICAST TENSOR WRITE (2 INPUT) +// //////////////////////////////////////////////////////////////////// +// //////////////////////////////////////////////////////////////////// + +TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SinglePageTile_OneHop) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 1; + constexpr size_t num_devices = 4; + auto logical_shape = tensor_shape.logical_shape(); + Layout const layout = Layout::TILE; + MemoryConfig const in0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const in1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const out0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + + auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies()); + Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor output_tensor0 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape); + Tensor output_tensor1 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape); + + input_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config))); + input_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in1_memory_config))); + output_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out0_memory_config))); + output_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out1_memory_config))); + + size_t page_size = tile_size(DataFormat::RawUInt32); + + ttnn::ccl::Shape4D tensor_shape_in_pages = shape_to_shape_in_tiles(tensor_shape); + ttnn::ccl::Shape4D tensor_slice_shape_in_pages = tensor_shape_in_pages; + ttnn::ccl::Shape4D tensor_slice_offset = {0, 0, 0, 0}; + ttnn::ccl::Shape4D worker_slice_shape = tensor_shape_in_pages; + ttnn::ccl::Shape4D worker_slice_offset = {0, 0, 0, 0}; + + ttnn::ccl::v2::TensorSlice tensor_slice{ + tensor_shape_in_pages, + tensor_slice_shape_in_pages, + tensor_slice_offset, + worker_slice_shape, + worker_slice_offset}; + + auto const in0_tensor_slice = tensor_slice; + auto const in1_tensor_slice = tensor_slice; + auto const out0_tensor_slice = tensor_slice; + auto const out1_tensor_slice = tensor_slice; + + ttnn::ccl::cmd::CclCommandDestArgs dest_args = ttnn::ccl::cmd::UnicastCommandDestArgs{distance_dest_device, true}; + auto pass = TestMultiInputReaderKernel( + num_devices, + input_tensor0, + in0_memory_config, + input_tensor1, + in1_memory_config, + output_tensor0, + out0_memory_config, + output_tensor1, + out1_memory_config, + + in0_tensor_slice, + in1_tensor_slice, + out0_tensor_slice, + out1_tensor_slice, + + page_size, + TwoInputReaderKernelWriteMode::FABRIC_UNICAST, + dest_args, + false); + + ASSERT_TRUE(pass); +} + +TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SinglePageTile_OneHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 1; + constexpr size_t num_devices = 4; + auto logical_shape = tensor_shape.logical_shape(); + Layout const layout = Layout::TILE; + MemoryConfig const in0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const in1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const out0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + + auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies()); + Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor output_tensor0 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape); + Tensor output_tensor1 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape); + + input_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config))); + input_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in1_memory_config))); + output_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out0_memory_config))); + output_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out1_memory_config))); + + size_t page_size = tile_size(DataFormat::RawUInt32); + + ttnn::ccl::Shape4D tensor_shape_in_pages = shape_to_shape_in_tiles(tensor_shape); + ttnn::ccl::Shape4D tensor_slice_shape_in_pages = tensor_shape_in_pages; + ttnn::ccl::Shape4D tensor_slice_offset = {0, 0, 0, 0}; + ttnn::ccl::Shape4D worker_slice_shape = tensor_shape_in_pages; + ttnn::ccl::Shape4D worker_slice_offset = {0, 0, 0, 0}; + + ttnn::ccl::v2::TensorSlice tensor_slice{ + tensor_shape_in_pages, + tensor_slice_shape_in_pages, + tensor_slice_offset, + worker_slice_shape, + worker_slice_offset}; + + auto const in0_tensor_slice = tensor_slice; + auto const in1_tensor_slice = tensor_slice; + auto const out0_tensor_slice = tensor_slice; + auto const out1_tensor_slice = tensor_slice; + + ttnn::ccl::cmd::CclCommandDestArgs dest_args = ttnn::ccl::cmd::UnicastCommandDestArgs{distance_dest_device, true}; + auto pass = TestMultiInputReaderKernel( + num_devices, + input_tensor0, + in0_memory_config, + input_tensor1, + in1_memory_config, + output_tensor0, + out0_memory_config, + output_tensor1, + out1_memory_config, + + in0_tensor_slice, + in1_tensor_slice, + out0_tensor_slice, + out1_tensor_slice, + + page_size, + TwoInputReaderKernelWriteMode::FABRIC_UNICAST, + dest_args, + true); + + ASSERT_TRUE(pass); +} + +// //////////////////////////////////////////////////////////////////// +// //////////////////////////////////////////////////////////////////// +// //// FABRIC MCAST TENSOR WRITE (2 INPUT) +// //////////////////////////////////////////////////////////////////// +// //////////////////////////////////////////////////////////////////// + +void RunFabricMcastFullTensorPropagateTest( + ttnn::Shape const& tensor_shape, size_t distance_dest_device, size_t num_devices, bool enable_persistent_fabric) { + auto logical_shape = tensor_shape.logical_shape(); + Layout const layout = Layout::TILE; + MemoryConfig const in0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const in1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const out0_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + + auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies()); + Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout); + Tensor output_tensor1 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape); + Tensor output_tensor0 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape); + input_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config))); + input_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in1_memory_config))); + output_tensor0.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out0_memory_config))); + output_tensor1.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), out1_memory_config))); + ASSERT_EQ(input_tensor0.get_logical_shape(), tensor_shape.logical_shape()); + ASSERT_EQ(input_tensor1.get_logical_shape(), tensor_shape.logical_shape()); + ASSERT_EQ(output_tensor0.get_logical_shape(), tensor_shape.logical_shape()); + ASSERT_EQ(output_tensor1.get_logical_shape(), tensor_shape.logical_shape()); + + size_t page_size = tile_size(DataFormat::RawUInt32); + + ttnn::ccl::Shape4D tensor_shape_in_pages = shape_to_shape_in_tiles(tensor_shape); + ttnn::ccl::Shape4D tensor_slice_shape_in_pages = tensor_shape_in_pages; + ttnn::ccl::Shape4D tensor_slice_offset = {0, 0, 0, 0}; + ttnn::ccl::Shape4D worker_slice_shape = tensor_shape_in_pages; + ttnn::ccl::Shape4D worker_slice_offset = {0, 0, 0, 0}; + + ttnn::ccl::v2::TensorSlice tensor_slice{ + tensor_shape_in_pages, + tensor_slice_shape_in_pages, + tensor_slice_offset, + worker_slice_shape, + worker_slice_offset}; + + auto const in0_tensor_slice = tensor_slice; + auto const in1_tensor_slice = tensor_slice; + auto const out0_tensor_slice = tensor_slice; + auto const out1_tensor_slice = tensor_slice; + + ttnn::ccl::cmd::CclCommandDestArgs dest_args = ttnn::ccl::cmd::MulticastCommandDestArgs{distance_dest_device, 0}; + auto pass = TestMultiInputReaderKernel( + num_devices, + input_tensor0, + in0_memory_config, + input_tensor1, + in1_memory_config, + output_tensor0, + out0_memory_config, + output_tensor1, + out1_memory_config, + + in0_tensor_slice, + in1_tensor_slice, + out0_tensor_slice, + out1_tensor_slice, + + page_size, + TwoInputReaderKernelWriteMode::FABRIC_MULTICAST, + dest_args, + enable_persistent_fabric); + + ASSERT_TRUE(pass); +} + +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_SingleHop) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 1; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_TwoHop) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 2; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_ThreeHop) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 3; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} + +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_4PageTile_SingleHop) { + ttnn::Shape tensor_shape = {1, 1, 32, 128}; + constexpr size_t distance_dest_device = 1; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, DMultiInputReader_4PageTile_TwoHop) { + ttnn::Shape tensor_shape = {1, 1, 128, 32}; + constexpr size_t distance_dest_device = 2; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_4PageTile_ThreeHop) { + ttnn::Shape tensor_shape = {1, 1, 64, 64}; + constexpr size_t distance_dest_device = 3; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_lotsPageTile_ThreeHop) { + ttnn::Shape tensor_shape = {1, 1, 64, 16384}; + constexpr size_t distance_dest_device = 3; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, false); +} + +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_SingleHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 1; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} + +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_TwoHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 2; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_SinglePageTile_ThreeHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 32, 32}; + constexpr size_t distance_dest_device = 3; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} + +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_4PageTile_SingleHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 32, 128}; + constexpr size_t distance_dest_device = 1; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, DMultiInputReader_4PageTile_TwoHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 128, 32}; + constexpr size_t distance_dest_device = 2; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_4PageTile_ThreeHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 64, 64}; + constexpr size_t distance_dest_device = 3; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} +TEST(WorkerCclCommandProcessingKernelFabricMulticastMode, MultiInputReader_lotsPageTile_ThreeHop_PersistentFabric) { + ttnn::Shape tensor_shape = {1, 1, 64, 16384}; + constexpr size_t distance_dest_device = 3; + constexpr size_t num_devices = 4; + RunFabricMcastFullTensorPropagateTest(tensor_shape, distance_dest_device, num_devices, true); +} + +bool RunPipelinedWorkersTest( + + ttnn::Shape tensor_shape, + const size_t split_dim, + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + const size_t num_stages, + std::vector num_workers_per_stage, + const size_t slices_per_stage, + const tt::DataFormat data_format, + const size_t page_size_bytes, + const size_t cb_packet_size_in_pages, + const size_t num_packets_per_cb, + auto layout, + + std::vector> worker_chunk_read_order, + std::vector mem_configs) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 4) { + log_info("This test can only be run on T3000 devices"); + return true; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return true; + } + + auto logical_shape = tensor_shape.logical_shape(); + auto const cb_index = tt::CB::c_in0; + + auto programs = std::vector(1); + Program& program = programs[0]; + + T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); + + Device* device = view.get_device(0, 0); + ; + + // General setup is as follows: + // Worker 1 reads input tensor as a sequence of slices - it forwards to an output tensor and after each slice, it + // writes a semaphore increment to some known semaphore address on the destination worker so the destination worker + // knows it's safe to read that slice. + // HOWEVER. the reader will be programmed to read the chunks in a different order than they were written, this way + // we can identify synchronization related bugs (e.g. if sender semaphore increments before writes flush) + + TT_FATAL(num_workers_per_stage.size() == num_stages, "Must have a read order for each stage"); + TT_FATAL(worker_chunk_read_order.size() == num_stages, "Must have a read order for each stage"); + for (size_t i = 0; i < num_stages; ++i) { + TT_FATAL(worker_chunk_read_order[i].size() == slices_per_stage, "Must have a read order for each slice"); + } + + // Validate the test setup + TT_FATAL(num_stages > 1, "Must have at least 2 stages"); + TT_FATAL(num_stages < 8, "Must have at most 8 stages"); + for (size_t i = 0; i < num_stages; ++i) { + TT_FATAL(num_workers_per_stage[i] > 0, "Must have at least 1 worker per stage"); + TT_FATAL(num_workers_per_stage[i] < 8, "Must have at most 8 workers per stage"); + } + + std::vector tensor_specs; + tensor_specs.reserve(num_stages + 1); + for (size_t i = 0; i < num_stages + 1; ++i) { + tensor_specs.push_back(TensorSpec( + logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), mem_configs[i]))); + } + + // Allocate the tensors - pull to function + const size_t num_tensors = num_stages + 1; + std::vector host_tensors; + std::vector device_tensors; + host_tensors.reserve(num_tensors); + device_tensors.reserve(num_tensors); + auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies()); + host_tensors.push_back(ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout)); + for (size_t i = 1; i < num_tensors; ++i) { + host_tensors.push_back(ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape)); + } + TT_FATAL(mem_configs.size() == num_tensors, "Must have a memory config for each tensor"); + for (size_t i = 0; i < num_tensors; i++) { + host_tensors[i].set_tensor_spec(tensor_specs[i]); + device_tensors.push_back(host_tensors[i].to(device, mem_configs[i])); + log_info("Tensor[{}] allocated starting at address {}", i, device_tensors[i].buffer()->address()); + } + TT_ASSERT(device_tensors.size() == num_tensors); + TT_ASSERT(device_tensors.size() == host_tensors.size()); + + // MAIN STUFF + + // Initial setup like worker core assignment, chunk read order, etc. + + std::vector pipeline_stage_worker_cores = {}; + for (size_t i = 0; i < num_stages; ++i) { + pipeline_stage_worker_cores.push_back( + CoreRangeSet(CoreRange(CoreCoord(0, i), CoreCoord(num_workers_per_stage[i] - 1, i)))); + } + CoreRangeSet all_workers_cores = CoreRangeSet(); + for (size_t i = 0; i < num_stages; ++i) { + } + + // Create circular buffers + for (size_t stage = 0; stage < num_stages; stage++) { + const size_t cb_packet_size_in_pages = 4; + const size_t num_packets_per_cb = 4; + tt_metal::CircularBufferConfig cb_config = + tt_metal::CircularBufferConfig( + cb_packet_size_in_pages * num_packets_per_cb * page_size_bytes, {{cb_index, data_format}}) + .set_page_size(cb_index, page_size_bytes); + CBHandle sender_workers_cb = CreateCircularBuffer(program, pipeline_stage_worker_cores[stage], cb_config); + } + + // Generate the reader semaphores + std::vector> input_tensor_semaphores; + input_tensor_semaphores.reserve(num_stages); + for (size_t stage = 0; stage < num_stages; stage++) { + input_tensor_semaphores.push_back({}); + for (size_t j = 0; j < slices_per_stage; j++) { + input_tensor_semaphores[stage].push_back(CreateSemaphore(program, pipeline_stage_worker_cores[stage], 0)); + } + } + + constexpr size_t num_command_streams = 1; + std::vector reader_kernels; + std::vector writer_kernels; + // Create the kernel handles for each pipeline stage + for (size_t stage = 0; stage < num_stages; stage++) { + auto reader_kernel = ttnn::ccl::worker_detail::generate_multi_command_stream_kernel_ct_args( + program, + {tt::CB::c_in0}, + {&device_tensors[stage]}, + pipeline_stage_worker_cores[stage], + tt_metal::ReaderDataMovementConfig{}, + num_command_streams); + reader_kernels.push_back(reader_kernel); + auto writer_kernel = ttnn::ccl::worker_detail::generate_multi_command_stream_kernel_ct_args( + program, + {tt::CB::c_in0}, + {&device_tensors[stage + 1]}, + pipeline_stage_worker_cores[stage], + tt_metal::WriterDataMovementConfig{}, + num_command_streams); + writer_kernels.push_back(writer_kernel); + } + + // Generate the tensor slices for each tensor/worker + std::vector> tensor_slices; + tensor_slices.reserve(num_stages + 1); + for (size_t t = 0; t < num_tensors; t++) { + tensor_slices.push_back( + ttnn::ccl::cmd::builder::generate_tensor_slices(slices_per_stage, device_tensors[t], split_dim)); + } + std::vector>> per_stage_worker_reader_tensor_slices; + std::vector>> per_stage_worker_writer_tensor_slices; + per_stage_worker_reader_tensor_slices.reserve(num_tensors); + per_stage_worker_writer_tensor_slices.reserve(num_tensors); + for (size_t stage = 0; stage < num_stages; stage++) { + per_stage_worker_reader_tensor_slices.push_back( + ttnn::ccl::cmd::builder::split_tensor_slices_across_workers_page_aligned( + num_workers_per_stage[stage], tensor_slices[stage])); + // We could compute this once and reuse it but I am generating it twice so I can have size mismatches + per_stage_worker_writer_tensor_slices.push_back( + ttnn::ccl::cmd::builder::split_tensor_slices_across_workers_page_aligned( + num_workers_per_stage[stage], tensor_slices[stage + 1])); + TT_FATAL( + per_stage_worker_reader_tensor_slices.back().size() == num_workers_per_stage[stage], + "Mismatch in tensor slices. Got {} but expected {}", + per_stage_worker_reader_tensor_slices.back().size(), + num_workers_per_stage[stage]); + TT_FATAL( + per_stage_worker_writer_tensor_slices.back().size() == num_workers_per_stage[stage], + "Mismatch in tensor slices. Got {} but expected {}", + per_stage_worker_writer_tensor_slices.back().size(), + num_workers_per_stage[stage]); + } + + // Build the command stream for each stage/worker + // Seminc example + // - local_core_semaphore_inc(second_command_stream_done_semaphore_id, 1); + // semwait example + // - local_semaphore_wait(second_command_stream_done_semaphore_id, 1) + // read tensor slice to cb example + // - read_tensor_slice_to_cb(in0_command_tensor_slice, cb_indices.at(0)) + // write tensor slice to cb example + // - build_write_tensor_slice_to_cb(out0_command_tensor_slice, cb_indices.at(0)) + TT_FATAL(per_stage_worker_reader_tensor_slices.size() == num_stages, "Mismatch in tensor slices"); + for (size_t stage = 0; stage < num_stages; stage++) { + bool last_stage = stage == num_stages - 1; + bool first_stage = stage == 0; + + const auto worker_cores = corerange_to_cores(pipeline_stage_worker_cores[stage]); + TT_FATAL(worker_cores.size() == num_workers_per_stage[stage], "Mismatch in worker cores"); + std::optional> next_worker_cores = + !last_stage ? corerange_to_cores(pipeline_stage_worker_cores[stage + 1]) + : std::optional>(std::nullopt); + + TT_FATAL( + per_stage_worker_reader_tensor_slices[stage].size() == num_workers_per_stage[stage], + "Mismatch in tensor slices"); + TT_FATAL( + per_stage_worker_writer_tensor_slices[stage].size() == num_workers_per_stage[stage], + "Mismatch in tensor slices"); + for (size_t worker = 0; worker < num_workers_per_stage[stage]; worker++) { + std::vector reader_cmd_stream; + std::vector writer_cmd_stream; + TT_FATAL( + per_stage_worker_reader_tensor_slices[stage][worker].size() == slices_per_stage, + "Mismatch in tensor slices"); + TT_FATAL( + per_stage_worker_writer_tensor_slices[stage][worker].size() == slices_per_stage, + "Mismatch in tensor slices"); + for (size_t slice_logical = 0; slice_logical < slices_per_stage; slice_logical++) { + const auto slice_actual = worker_chunk_read_order[stage][slice_logical]; + // reader + if (!first_stage) { + reader_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_semaphore_wait( + input_tensor_semaphores[stage][slice_actual], num_workers_per_stage[stage - 1])); + } + reader_cmd_stream.push_back(ttnn::ccl::cmd::uops::read_tensor_slice_to_cb( + per_stage_worker_reader_tensor_slices[stage][worker][slice_actual], cb_index)); + log_info(tt::LogTest, "Worker {} reading/writing slice {}", worker, slice_actual); + + // writer + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_write_cb_to_tensor_slice( + per_stage_worker_writer_tensor_slices[stage][worker][slice_actual], cb_index)); + if (not last_stage) { + for (auto next_worker_xy : next_worker_cores.value()) { + log_info( + tt::LogTest, + "Stage {} Worker {} noc seminc to core (logical) x={},y={}", + stage, + worker, + next_worker_xy.x, + next_worker_xy.y); + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_chip_noc_semaphore_inc( + device->worker_core_from_logical_core(next_worker_xy).x, + device->worker_core_from_logical_core(next_worker_xy).y, + input_tensor_semaphores[stage + 1][slice_actual], + 1)); + } + } + } + ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( + program, + reader_kernels[stage], + {&device_tensors[stage]}, + {page_size_bytes}, + device, + cb_packet_size_in_pages, + {worker_cores.at(worker)}, + reader_cmd_stream, + std::nullopt, + std::nullopt, + std::nullopt); + ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( + program, + writer_kernels[stage], + {&device_tensors[stage + 1]}, + {page_size_bytes}, + device, + cb_packet_size_in_pages, + {worker_cores.at(worker)}, + writer_cmd_stream, + std::nullopt, + std::nullopt, + std::nullopt); + } + } + + run_programs(programs, {device}); + + bool pass = true; + constexpr bool enable_check = true; + if constexpr (enable_check) { + log_info(tt::LogTest, "Reading back outputs"); + auto input_cpu = device_tensors[0].cpu(); + auto final_out_cpu = device_tensors.back().cpu(); + + auto in_tensor_copyback = tt::tt_metal::owned_buffer::get_as(input_cpu); + auto out_tensor_copyback = tt::tt_metal::owned_buffer::get_as(final_out_cpu); + + auto in_tensor_data = tt::tt_metal::owned_buffer::get_as(host_tensors[0]); + + bool input_copyback_check_passed = run_output_check(in_tensor_data, in_tensor_copyback) == Correctness::Correct; + TT_FATAL(input_copyback_check_passed, "Input 0 copyback check failed"); + + log_info(tt::LogTest, "Comparing outputs"); + + pass &= run_output_check(in_tensor_data, out_tensor_copyback) == Correctness::Correct; + if (pass) { + log_info(tt::LogTest, "Output check passed for output 0"); + } else { + log_error(tt::LogTest, "Output check failed for output 0"); + } + } + + return pass; +} + +TEST(WorkerCclCommandProcessingKernels, ChainOfCommandProcessorsWithVaryingDataReadOrders_LocalOnly0) { + ttnn::Shape tensor_shape = {1, 1, 64, 16384}; + auto logical_shape = tensor_shape.logical_shape(); + const size_t split_dim = 3; + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + constexpr size_t num_stages = 4; + const size_t slices_per_stage = 4; + const size_t cb_packet_size_in_pages = 4; + const size_t num_packets_per_cb = 4; + auto layout = Layout::TILE; + const tt::DataFormat data_format = tt::DataFormat::RawUInt32; + const size_t page_size_bytes = tile_size(DataFormat::RawUInt32); + std::vector num_workers_per_stage = {1, 1, 1, 1}; + + std::vector> worker_chunk_read_order = { + {0, 1, 2, 3}, // first input + {3, 2, 1, 0}, // read in reverse order + {2, 0, 3, 1}, // read in non-sequential order + {1, 2, 3, 0} // read in non-sequential order + }; + std::vector mem_configs{ + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM)}; + + auto pass = RunPipelinedWorkersTest( + + tensor_shape, + split_dim, + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + num_stages, + num_workers_per_stage, + slices_per_stage, + data_format, + page_size_bytes, + cb_packet_size_in_pages, + num_packets_per_cb, + layout, + + worker_chunk_read_order, + mem_configs); + + ASSERT_TRUE(pass); +} +TEST(WorkerCclCommandProcessingKernels, ChainOfCommandProcessorsWithVaryingDataReadOrders_LocalOnly1) { + ttnn::Shape tensor_shape = {1, 1, 64, 128}; + auto logical_shape = tensor_shape.logical_shape(); + const size_t split_dim = 3; + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + constexpr size_t num_stages = 4; + const size_t slices_per_stage = 4; + const size_t cb_packet_size_in_pages = 4; + const size_t num_packets_per_cb = 4; + auto layout = Layout::TILE; + const tt::DataFormat data_format = tt::DataFormat::RawUInt32; + const size_t page_size_bytes = tile_size(DataFormat::RawUInt32); + std::vector num_workers_per_stage = {1, 1, 1, 1}; + + std::vector> worker_chunk_read_order = { + {0, 1, 2, 3}, // first input + {3, 2, 1, 0}, // read in reverse order + {2, 0, 3, 1}, // read in non-sequential order + {1, 2, 3, 0} // read in non-sequential order + }; + std::vector mem_configs{ + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM)}; + + auto pass = RunPipelinedWorkersTest( + + tensor_shape, + split_dim, + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + num_stages, + num_workers_per_stage, + slices_per_stage, + data_format, + page_size_bytes, + cb_packet_size_in_pages, + num_packets_per_cb, + layout, + + worker_chunk_read_order, + mem_configs); + + ASSERT_TRUE(pass); +} +TEST(WorkerCclCommandProcessingKernels, ChainOfCommandProcessorsWithVaryingDataReadOrders_LocalOnly2) { + ttnn::Shape tensor_shape = {1, 1, 64, 8192}; + auto logical_shape = tensor_shape.logical_shape(); + const size_t split_dim = 3; + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + constexpr size_t num_stages = 4; + const size_t slices_per_stage = 2; + const size_t cb_packet_size_in_pages = 4; + const size_t num_packets_per_cb = 4; + auto layout = Layout::TILE; + const tt::DataFormat data_format = tt::DataFormat::RawUInt32; + const size_t page_size_bytes = tile_size(DataFormat::RawUInt32); + std::vector num_workers_per_stage = {1, 1, 1, 1}; + + std::vector> worker_chunk_read_order = { + {0, 1}, // first input + {1, 0}, // read in reverse order + {1, 0}, // read in non-sequential order + {0, 1} // read in non-sequential order + }; + std::vector mem_configs{ + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM)}; + + auto pass = RunPipelinedWorkersTest( + + tensor_shape, + split_dim, + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + num_stages, + num_workers_per_stage, + slices_per_stage, + data_format, + page_size_bytes, + cb_packet_size_in_pages, + num_packets_per_cb, + layout, + + worker_chunk_read_order, + mem_configs); + + ASSERT_TRUE(pass); +} + +// Hits issues with input tensor copy-back +TEST( + WorkerCclCommandProcessingKernels, + DISABLED_ChainOfCommandProcessorsWithVaryingDataReadOrders_LocalOnly_SmallSweep) { + std::vector tensor_shapes = { + {1, 1, 64, 8192}, {1, 4, 64, 768}, {4, 1, 64, 768}, {4, 4, 64, 768}, {1, 1, 64, 768}, {5, 3, 64, 768}}; + + const size_t split_dim = 3; + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be configurable) + constexpr size_t num_stages = 4; + const std::vector slices_per_stage_sweep = {2, 3, 4}; + const size_t cb_packet_size_in_pages = 4; + const size_t num_packets_per_cb = 4; + auto layout = Layout::TILE; + const tt::DataFormat data_format = tt::DataFormat::RawUInt32; + const size_t page_size_bytes = tile_size(DataFormat::RawUInt32); + std::vector> num_workers_per_stage_sweep = { + {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}, {4, 4, 4, 4}}; + + std::vector>> worker_chunk_read_order = { + {{}}, + { + {0}, + {0}, + {0}, + {0}, + }, + { + {0, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + }, + { + {2, 0, 1}, + {1, 0, 2}, + {0, 1, 2}, + {2, 1, 0}, + }, + { + {0, 1, 2, 3}, // first input + {3, 2, 1, 0}, // read in reverse order + {2, 0, 3, 1}, // read in non-sequential order + {1, 2, 3, 0} // read in non-sequential order + }}; + std::vector> mem_configs_sweep = { + { + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + }, + {MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1)}, + {MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM)}, + {MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::L1), + MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM)}, + }; + + for (auto& tensor_shape : tensor_shapes) { + for (auto& num_workers_per_stage : num_workers_per_stage_sweep) { + for (size_t slices_per_stage : slices_per_stage_sweep) { + for (auto& mem_configs : mem_configs_sweep) { + log_info( + tt::LogTest, + "tensor shape {} and workers stage {} slices_per_stage {}", + tensor_shape, + num_workers_per_stage, + slices_per_stage); + auto pass = RunPipelinedWorkersTest( + + tensor_shape, + split_dim, + + // In this test we will have n stages with anywhere from 1 to 8 workers per stage (this will be + // configurable) + num_stages, + num_workers_per_stage, + slices_per_stage, + data_format, + page_size_bytes, + cb_packet_size_in_pages, + num_packets_per_cb, + layout, + + worker_chunk_read_order[slices_per_stage], + mem_configs); + + ASSERT_TRUE(pass); + } + } + } + } +} + +#include "ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp" +#include "tt_metal/common/bfloat16.hpp" +TEST(CclAsyncOp, ReduceScatterSmall_PersistentFabric) { + const size_t dim = 3; + const size_t num_links = 1; + constexpr auto layout = Layout::TILE; + // DEVICES setup + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + constexpr size_t test_expected_num_devices = 4; + if (tt::tt_metal::GetNumAvailableDevices() < test_expected_num_devices) { + log_info("This test can only be run on T3000 devices"); + return; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return; + } + T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); + + // build a line of devices + std::vector devices = { + view.get_device(0, 1), view.get_device(1, 1), view.get_device(1, 2), view.get_device(0, 2)}; + const size_t num_devices = devices.size(); + TT_FATAL( + test_expected_num_devices == num_devices, + "Expected {} devices but got {}", + test_expected_num_devices, + num_devices); + const ttnn::Shape input_shape = ttnn::Shape{1, 1, 32, 32 * num_devices}; + const MemoryConfig in_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + auto const logical_shape = input_shape.logical_shape(); + const auto num_elems = logical_shape.volume(); + + // INPUT TENSOR setup + size_t page_size = tile_size(DataFormat::Float16); + std::vector device_input_tensors; + for (size_t i = 0; i < num_devices; i++) { + // host_input_tensors.push_back(ttnn::numpy::random::uniform(bfloat16(-1.0f), bfloat16(1.0f) , + // {logical_shape[0],logical_shape[1],logical_shape[2],logical_shape[3]}, layout).to(devices[i])); + auto t = ttnn::arange(0, num_elems, 1, DataType::BFLOAT16).reshape(input_shape).to(layout); + t.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::BFLOAT16, PageConfig(layout, tt_metal::Tile()), in_memory_config))); + + device_input_tensors.push_back(t.to(devices[i])); + } + // Need to make it a mesh tensor for use with the op + const Tensor input_mesh_tensor = ttnn::distributed::aggregate_as_tensor(device_input_tensors, AllGatherTensor{}); + + // FABRIC setup + const bool enable_persistent_fabric = true; + + std::vector dummy_worker_programs; + std::optional subdevice_managers = std::nullopt; + std::optional> fabric_programs; + std::vector fabric_program_ptrs; + std::optional fabric_handle; + setup_test_with_persistent_fabric( + devices, + dummy_worker_programs, + subdevice_managers, + fabric_programs, + fabric_program_ptrs, + fabric_handle, + enable_persistent_fabric, + num_links); + + auto output_tensor = ttnn::operations::experimental::ccl::reduce_scatter( + input_mesh_tensor, + dim, + ttnn::operations::reduction::ReduceType::Sum, + operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + ttnn::ccl::Topology::Linear, + num_links, + subdevice_managers->worker_subdevice_id.at(devices[0]->id()), + true, + fabric_handle); + + // wait for op completion + log_info(tt::LogTest, "Waiting for Op finish"); + std::ranges::for_each(devices, [&](Device* d) { + tt_metal::Finish(d->command_queue(), {subdevice_managers->worker_subdevice_id.at(d->id())}); + }); + log_info(tt::LogTest, "Main op done"); + + log_info(tt::LogTest, "Fabric teardown"); + persistent_fabric_teardown_sequence( + devices, subdevice_managers, fabric_handle.value(), tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE); + + log_info(tt::LogTest, "Waiting for teardown completion"); + for (auto d : devices) { + tt_metal::Synchronize(d, ttnn::DefaultQueueId); + } + log_info(tt::LogTest, "Finished"); +} + +#include "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_links, ttnn::Shape const& input_shape) { + log_info(tt::LogTest, "entering test"); + constexpr auto layout = Layout::TILE; + // DEVICES setuip + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + constexpr size_t test_expected_num_devices = 4; + if (tt::tt_metal::GetNumAvailableDevices() < test_expected_num_devices) { + log_info("This test can only be run on T3000 devices"); + return; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return; + } + T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); + + // build a line of devices + std::vector devices = { + view.get_device(0, 0), view.get_device(0, 1), view.get_device(0, 2), view.get_device(0, 3)}; + const size_t num_devices = devices.size(); + TT_FATAL( + test_expected_num_devices == num_devices, + "Expected {} devices but got {}", + test_expected_num_devices, + num_devices); + const MemoryConfig in_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM); + auto const logical_shape = input_shape.logical_shape(); + const auto num_elems = logical_shape.volume(); + + // INPUT TENSOR setup + log_info(tt::LogTest, "setting up input tensors"); + size_t page_size = tile_size(DataFormat::Float16); + std::vector device_input_tensors; + for (size_t i = 0; i < num_devices; i++) { + auto t = ttnn::arange(0, num_elems, 1).reshape(input_shape).to(layout); + t.set_tensor_spec(TensorSpec( + logical_shape, TensorLayout(DataType::BFLOAT16, PageConfig(layout, tt_metal::Tile()), in_memory_config))); + + device_input_tensors.push_back(t.to(devices[i])); + } + // Need to make it a mesh tensor for use with the op + const Tensor input_mesh_tensor = ttnn::distributed::aggregate_as_tensor(device_input_tensors, AllGatherTensor{}); + + // FABRIC setup + const bool enable_persistent_fabric = true; + + std::vector dummy_worker_programs; + std::optional subdevice_managers = std::nullopt; + std::optional> fabric_programs; + std::vector fabric_program_ptrs; + std::optional fabric_handle; + setup_test_with_persistent_fabric( + devices, + dummy_worker_programs, + subdevice_managers, + fabric_programs, + fabric_program_ptrs, + fabric_handle, + enable_persistent_fabric, + num_links); + log_info(tt::LogTest, "Lauching op"); + + auto output_tensor = ttnn::operations::experimental::ccl::all_gather_async( + input_mesh_tensor, + dim, + num_links, + operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + ttnn::ccl::Topology::Linear, + SubDeviceId(0), + true); + + // wait for op completion + log_info(tt::LogTest, "Waiting for Op finish"); + std::ranges::for_each(devices, [&](Device* d) { + tt_metal::Finish(d->command_queue(), {subdevice_managers->worker_subdevice_id.at(d->id())}); + }); + log_info(tt::LogTest, "Main op done"); + + log_info(tt::LogTest, "Fabric teardown"); + persistent_fabric_teardown_sequence( + devices, subdevice_managers, fabric_handle.value(), tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); + + log_info(tt::LogTest, "Waiting for teardown completion"); + for (auto d : devices) { + tt_metal::Synchronize(d, ttnn::DefaultQueueId); + } + log_info(tt::LogTest, "Finished"); +} + +TEST(CclAsyncOp, AllGather_PersistentFabric_Dim3_Links1_Shape1_1_32_128) { + run_all_gather_with_persistent_fabric(3, 1, ttnn::Shape{1, 1, 32, 128}); +} +TEST(CclAsyncOp, AllGather_PersistentFabric_Dim3_Links1_Shape1_1_32_8192) { + run_all_gather_with_persistent_fabric(3, 1, ttnn::Shape{1, 1, 32, 8192}); +} +// Mesh device setup seems to not provide the correct configuration for multi-link? To be investigated +TEST(CclAsyncOp, DISABLED_AllGather_PersistentFabric_Dim3_Links2_Shape1_1_32_128) { + run_all_gather_with_persistent_fabric(3, 2, ttnn::Shape{1, 1, 32, 128}); +} +// Mesh device setup seems to not provide the correct configuration for multi-link? To be investigated +TEST(CclAsyncOp, DISABLED_AllGather_PersistentFabric_Dim3_Links2_Shape1_1_32_8192) { + run_all_gather_with_persistent_fabric(3, 2, ttnn::Shape{1, 1, 32, 8192}); +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators.cpp new file mode 100644 index 00000000000..1f56107bc7c --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators.cpp @@ -0,0 +1,641 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "gtest/gtest.h" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" + +static constexpr std::array worker_to_routing_x_wormhole = {1, 2, 3, 4, 6, 7, 8, 9}; + +static constexpr std::array worker_to_routing_y_wormhole = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11}; + +namespace tt { +namespace tt_metal { + +struct UnharvestedWormholeWorkerToNocLookup + : address_generators::WorkerToNocCoordLookup { + static constexpr std::array worker_to_routing_x = {1, 2, 3, 4, 6, 7, 8, 9}; + static constexpr std::array worker_to_routing_y = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11}; + + noc_grid_index_t get_noc_x_from_worker_x(noc_grid_index_t worker_x) const { + // ASSERT worker_x < worker_to_routing_x_wormhole.size() + return worker_to_routing_x[worker_x]; + } + + noc_grid_index_t get_noc_y_from_worker_y(noc_grid_index_t worker_y) const { + // ASSERT worker_y < worker_to_routing_y_wormhole.size() + return worker_to_routing_y[worker_y]; + } +}; + +static void run_width_sharded_tensor_slice_indexer_get_page_location_test( + address_generators::WidthShardedAddressGenerator< + UnharvestedWormholeWorkerToNocLookup, + address_generators::DeviceWidthShardSpec>& addrgen, + + std::size_t pages_per_shard_y, + std::size_t pages_per_shard_x, + + std::size_t shard_grid_height, + std::size_t shard_grid_width, + + std::size_t worker_shard_cores_start_y, + std::size_t worker_shard_cores_start_x, + bool is_shard_grid_transposed) { + std::size_t page_id = 0; + // Takes a long time to sweep really large numbers so instead stride through the range. + // Really the only reason to test larger numbers is to catch overflow issues with smaller + // number types that may be carried around in the addrgen structs + std::size_t py_increment = pages_per_shard_y > 32 ? 7 : 1; + std::size_t px_increment = pages_per_shard_x > 32 ? 7 : 1; + + if (!is_shard_grid_transposed) { + for (std::size_t py = 0; py < pages_per_shard_y; py++) { + for (std::size_t y_logical = worker_shard_cores_start_y; + y_logical < worker_shard_cores_start_y + shard_grid_height; + y_logical++) { + for (std::size_t x_logical = worker_shard_cores_start_x; + x_logical < worker_shard_cores_start_x + shard_grid_width; + x_logical++) { + for (std::size_t px = 0; px < pages_per_shard_x; px++) { + if (px % px_increment == 0 && py % py_increment == 0 || + (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { + const auto& result = addrgen.get_page_location(page_id); + ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); + ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); + ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); + + const auto& result2 = + addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id); + ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x); + ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y); + ASSERT_EQ(result2.page_offset, result.page_offset); + ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px); + } + + page_id++; + } + } + } + } + } else { + for (std::size_t py = 0; py < pages_per_shard_y; py++) { + for (std::size_t x_logical = worker_shard_cores_start_x; + x_logical < worker_shard_cores_start_x + shard_grid_width; + x_logical++) { + for (std::size_t y_logical = worker_shard_cores_start_y; + y_logical < worker_shard_cores_start_y + shard_grid_height; + y_logical++) { + for (std::size_t px = 0; px < pages_per_shard_x; px++) { + if (px % px_increment == 0 && py % py_increment == 0 || + (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { + const auto& result = addrgen.get_page_location(page_id); + ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); + ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); + ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); + + const auto& result2 = + addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id); + ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x); + ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y); + ASSERT_EQ(result2.page_offset, result.page_offset); + ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px); + } + page_id++; + } + } + } + } + } +} + +static void run_width_sharded_tensor_slice_indexer_get_page_location_test( + std::size_t pages_per_shard_y, + std::size_t pages_per_shard_x, + + std::size_t shard_grid_height, + std::size_t shard_grid_width, + + std::size_t worker_shard_cores_start_y, + std::size_t worker_shard_cores_start_x, + + bool is_shard_grid_transposed) { + const std::size_t global_num_pages = pages_per_shard_y * pages_per_shard_x * shard_grid_width * shard_grid_height; + + auto addrgen = address_generators:: + WidthShardedAddressGenerator( + UnharvestedWormholeWorkerToNocLookup(), + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed), + 1024, + 0x0); + + run_width_sharded_tensor_slice_indexer_get_page_location_test( + addrgen, + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} + +TEST(CclWidthShardedTensorSliceIndexer_Wormhole, basic_test_case) { + static constexpr std::size_t pages_per_shard_y = 1; + static constexpr std::size_t pages_per_shard_x = 8; + + static constexpr std::size_t shard_grid_height = 2; + static constexpr std::size_t shard_grid_width = 1; + + static constexpr std::size_t worker_shard_cores_start_y = 0; + static constexpr std::size_t worker_shard_cores_start_x = 0; + + bool is_shard_grid_transposed = false; + + run_width_sharded_tensor_slice_indexer_get_page_location_test( + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} + +TEST(CclWidthShardedTensorSliceIndexer_Wormhole, SweepWormhole) { + std::size_t max_worker_rows = 10; + std::size_t max_worker_cols = 8; + + for (auto pages_per_shard_y : {1, 2, 5, 8, 256}) { + for (auto pages_per_shard_x : {1, 2, 5, 8, 256}) { + for (auto shard_grid_offset_logical_y : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) { + for (auto shard_grid_offset_logical_x : {0, 1, 2, 3, 4, 5, 6, 7}) { + for (std::size_t shard_grid_height = 1; + shard_grid_height < (max_worker_rows - shard_grid_offset_logical_y); + shard_grid_height++) { + for (std::size_t shard_grid_width = 1; + shard_grid_width < (max_worker_cols - shard_grid_offset_logical_x); + shard_grid_width++) { + for (bool transpose_shard_grid : {false, true}) { + run_width_sharded_tensor_slice_indexer_get_page_location_test( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + shard_grid_offset_logical_y, + shard_grid_offset_logical_x, + transpose_shard_grid); + } + } + } + } + } + } + } +} + +static void run_height_sharded_tensor_slice_indexer_get_page_location_test( + address_generators::HeightShardedAddressGenerator< + UnharvestedWormholeWorkerToNocLookup, + address_generators::DeviceHeightShardSpec>& addrgen, + std::size_t pages_per_shard_y, + std::size_t pages_per_shard_x, + + std::size_t shard_grid_height, + std::size_t shard_grid_width, + + std::size_t worker_shard_cores_start_y, + std::size_t worker_shard_cores_start_x, + + bool is_shard_grid_transposed) { + std::size_t page_id = 0; + + // Takes a long time to sweep really large numbers so instead stride through the range. + // Really the only reason to test larger numbers is to catch overflow issues with smaller + // number types that may be carried around in the addrgen structs + std::size_t py_increment = pages_per_shard_y > 32 ? 7 : 1; + std::size_t px_increment = pages_per_shard_x > 32 ? 7 : 1; + + if (!is_shard_grid_transposed) { + for (std::size_t x_logical = worker_shard_cores_start_x; + x_logical < worker_shard_cores_start_x + shard_grid_width; + x_logical++) { + for (std::size_t y_logical = worker_shard_cores_start_y; + y_logical < worker_shard_cores_start_y + shard_grid_height; + y_logical++) { + for (std::size_t py = 0; py < pages_per_shard_y; py++) { + for (std::size_t px = 0; px < pages_per_shard_x; px++) { + if (px % px_increment == 0 && py % py_increment == 0 || + (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { + const auto& result = addrgen.get_page_location(page_id); + ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); + ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); + ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); + } + + page_id++; + } + } + } + } + } else { + for (std::size_t y_logical = worker_shard_cores_start_y; + y_logical < worker_shard_cores_start_y + shard_grid_height; + y_logical++) { + for (std::size_t x_logical = worker_shard_cores_start_x; + x_logical < worker_shard_cores_start_x + shard_grid_width; + x_logical++) { + for (std::size_t py = 0; py < pages_per_shard_y; py++) { + for (std::size_t px = 0; px < pages_per_shard_x; px++) { + if (px % px_increment == 0 && py % py_increment == 0 || + (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { + const auto& result = addrgen.get_page_location(page_id); + ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); + ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); + ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); + } + page_id++; + } + } + } + } + } +} + +static void run_height_sharded_tensor_slice_indexer_get_page_location_test( + std::size_t pages_per_shard_y, + std::size_t pages_per_shard_x, + + std::size_t shard_grid_height, + std::size_t shard_grid_width, + + std::size_t worker_shard_cores_start_y, + std::size_t worker_shard_cores_start_x, + + bool is_shard_grid_transposed) { + const std::size_t global_num_pages = pages_per_shard_y * pages_per_shard_x * shard_grid_width * shard_grid_height; + + auto addrgen = address_generators:: + HeightShardedAddressGenerator( + UnharvestedWormholeWorkerToNocLookup(), + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed), + 1024, + 0x0); + + run_height_sharded_tensor_slice_indexer_get_page_location_test( + addrgen, + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} + +TEST(CclHeightShardedTensorSliceIndexer_Wormhole, basic_test_case) { + static constexpr std::size_t pages_per_shard_y = 8; + static constexpr std::size_t pages_per_shard_x = 1; + + static constexpr std::size_t shard_grid_height = 1; + static constexpr std::size_t shard_grid_width = 2; + + static constexpr std::size_t worker_shard_cores_start_y = 0; + static constexpr std::size_t worker_shard_cores_start_x = 0; + + bool is_shard_grid_transposed = false; + + run_height_sharded_tensor_slice_indexer_get_page_location_test( + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} + +TEST(CclHeightShardedTensorSliceIndexer_Wormhole, SweepWormhole) { + std::size_t max_worker_rows = 10; + std::size_t max_worker_cols = 8; + + for (auto pages_per_shard_y : {1, 2, 5, 8, 256}) { + for (auto pages_per_shard_x : {1, 2, 5, 8, 256}) { + for (auto shard_grid_offset_logical_y : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) { + for (auto shard_grid_offset_logical_x : {0, 1, 2, 3, 4, 5, 6, 7}) { + for (std::size_t shard_grid_height = 1; + shard_grid_height < (max_worker_rows - shard_grid_offset_logical_y); + shard_grid_height++) { + for (std::size_t shard_grid_width = 1; + shard_grid_width < (max_worker_cols - shard_grid_offset_logical_x); + shard_grid_width++) { + for (bool transpose_shard_grid : {false, true}) { + run_height_sharded_tensor_slice_indexer_get_page_location_test( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + shard_grid_offset_logical_y, + shard_grid_offset_logical_x, + transpose_shard_grid); + } + } + } + } + } + } + } +} + +static void run_block_sharded_tensor_slice_indexer_get_page_location_test( + address_generators::BlockShardedAddressGenerator< + UnharvestedWormholeWorkerToNocLookup, + address_generators::DeviceBlockShardSpec>& addrgen, + std::size_t pages_per_shard_y, + std::size_t pages_per_shard_x, + + std::size_t shard_grid_height, + std::size_t shard_grid_width, + + std::size_t worker_shard_cores_start_y, + std::size_t worker_shard_cores_start_x, + + bool is_shard_grid_transposed) { + std::size_t page_id = 0; + + // Takes a long time to sweep really large numbers so instead stride through the range. + // Really the only reason to test larger numbers is to catch overflow issues with smaller + // number types that may be carried around in the addrgen structs + std::size_t py_increment = pages_per_shard_y > 32 ? 7 : 1; + std::size_t px_increment = pages_per_shard_x > 32 ? 7 : 1; + + if (!is_shard_grid_transposed) { + for (std::size_t y_logical = worker_shard_cores_start_y; + y_logical < worker_shard_cores_start_y + shard_grid_height; + y_logical++) { + for (std::size_t py = 0; py < pages_per_shard_y; py++) { + for (std::size_t x_logical = worker_shard_cores_start_x; + x_logical < worker_shard_cores_start_x + shard_grid_width; + x_logical++) { + for (std::size_t px = 0; px < pages_per_shard_x; px++) { + if (px % px_increment == 0 && py % py_increment == 0 || + (py == (pages_per_shard_y - 1) || px != (pages_per_shard_x - 1))) { + const auto& result = addrgen.get_page_location(page_id); + ASSERT_EQ(result.core_location.noc_x, worker_to_routing_x_wormhole.at(x_logical)); + ASSERT_EQ(result.core_location.noc_y, worker_to_routing_y_wormhole.at(y_logical)); + ASSERT_EQ(result.page_offset, px + (py * pages_per_shard_x)); + + const auto& result2 = + addrgen.get_page_location_with_contiguous_pages_in_row_in_bank(page_id); + ASSERT_EQ(result2.core_location.noc_x, result.core_location.noc_x); + ASSERT_EQ(result2.core_location.noc_y, result.core_location.noc_y); + ASSERT_EQ(result2.page_offset, result.page_offset); + ASSERT_EQ(result2.contig_pages_in_row, pages_per_shard_x - px); + } + + page_id++; + } + } + } + } + } else { + ASSERT_EQ(true, false); //"Transposed grid not supported in testing yet" + } +} + +static void run_block_sharded_tensor_slice_indexer_get_page_location_test( + std::size_t pages_per_shard_y, + std::size_t pages_per_shard_x, + + std::size_t shard_grid_height, + std::size_t shard_grid_width, + + std::size_t worker_shard_cores_start_y, + std::size_t worker_shard_cores_start_x, + + bool is_shard_grid_transposed) { + const std::size_t global_num_pages = pages_per_shard_y * pages_per_shard_x * shard_grid_width * shard_grid_height; + + auto addrgen = address_generators:: + BlockShardedAddressGenerator( + UnharvestedWormholeWorkerToNocLookup(), + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed), + 1024, + 0x0); + + run_block_sharded_tensor_slice_indexer_get_page_location_test( + addrgen, + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed); +} + +TEST(CclBlockShardedTensorSliceIndexer_Wormhole, basic_test_case) { + static constexpr std::size_t pages_per_shard_y = 8; + static constexpr std::size_t pages_per_shard_x = 2; + + static constexpr std::size_t shard_grid_height = 3; + static constexpr std::size_t shard_grid_width = 2; + + static constexpr std::size_t worker_shard_cores_start_y = 0; + static constexpr std::size_t worker_shard_cores_start_x = 0; + + bool is_shard_grid_transposed = false; + + run_block_sharded_tensor_slice_indexer_get_page_location_test( + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} + +TEST(CclBlockShardedTensorSliceIndexer_Wormhole, SweepWormhole) { + std::size_t max_worker_rows = 10; + std::size_t max_worker_cols = 8; + + for (auto pages_per_shard_y : {1, 2, 5, 8, 256}) { + for (auto pages_per_shard_x : {1, 2, 5, 8, 256}) { + for (auto shard_grid_offset_logical_y : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) { + for (auto shard_grid_offset_logical_x : {0, 1, 2, 3, 4, 5, 6, 7}) { + for (std::size_t shard_grid_height = 1; + shard_grid_height < (max_worker_rows - shard_grid_offset_logical_y); + shard_grid_height++) { + for (std::size_t shard_grid_width = 1; + shard_grid_width < (max_worker_cols - shard_grid_offset_logical_x); + shard_grid_width++) { + for (bool transpose_shard_grid : + {false}) { // true: transpose mode not yet supported for block sharded indexer + run_block_sharded_tensor_slice_indexer_get_page_location_test( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + shard_grid_offset_logical_y, + shard_grid_offset_logical_x, + transpose_shard_grid); + } + } + } + } + } + } + } +} + +TEST(CclShardedTensorAddrGenBuilder, TestBuildWidthSharded) { + static constexpr std::size_t pages_per_shard_y = 1; + static constexpr std::size_t pages_per_shard_x = 8; + + static constexpr std::size_t shard_grid_height = 2; + static constexpr std::size_t shard_grid_width = 1; + + static constexpr std::size_t worker_shard_cores_start_y = 0; + static constexpr std::size_t worker_shard_cores_start_x = 0; + + bool is_shard_grid_transposed = false; + auto addrgen = build_sharded_addr_gen( + UnharvestedWormholeWorkerToNocLookup(), + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed), + 1024, + 0x0); + + run_width_sharded_tensor_slice_indexer_get_page_location_test( + addrgen, + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} +TEST(CclShardedTensorAddrGenBuilder, TestBuildHeightSharded) { + static constexpr std::size_t pages_per_shard_y = 8; + static constexpr std::size_t pages_per_shard_x = 1; + + static constexpr std::size_t shard_grid_height = 1; + static constexpr std::size_t shard_grid_width = 2; + + static constexpr std::size_t worker_shard_cores_start_y = 0; + static constexpr std::size_t worker_shard_cores_start_x = 0; + + bool is_shard_grid_transposed = false; + auto addrgen = build_sharded_addr_gen( + UnharvestedWormholeWorkerToNocLookup(), + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed), + 1024, + 0x0); + + run_height_sharded_tensor_slice_indexer_get_page_location_test( + addrgen, + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} +TEST(CclShardedTensorAddrGenBuilder, TestBuildBlockSharded) { + static constexpr std::size_t pages_per_shard_y = 8; + static constexpr std::size_t pages_per_shard_x = 2; + + static constexpr std::size_t shard_grid_height = 3; + static constexpr std::size_t shard_grid_width = 2; + + static constexpr std::size_t worker_shard_cores_start_y = 0; + static constexpr std::size_t worker_shard_cores_start_x = 0; + + bool is_shard_grid_transposed = false; + + auto addrgen = build_sharded_addr_gen( + UnharvestedWormholeWorkerToNocLookup(), + address_generators::DeviceShardSpecTypeGetter::type( + pages_per_shard_y, + pages_per_shard_x, + shard_grid_height, + shard_grid_width, + worker_shard_cores_start_y, + worker_shard_cores_start_x, + is_shard_grid_transposed), + 1024, + 0x0); + + run_block_sharded_tensor_slice_indexer_get_page_location_test( + addrgen, + pages_per_shard_y, + pages_per_shard_x, + + shard_grid_height, + shard_grid_width, + + worker_shard_cores_start_y, + worker_shard_cores_start_x, + + is_shard_grid_transposed); +} + +} // namespace tt_metal +} // namespace tt diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py new file mode 100644 index 00000000000..052a0863b63 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -0,0 +1,601 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import ttnn +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc +from models.utility_functions import skip_for_grayskull +from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_async import ( + create_and_load_sub_device_manager_with_fabric_interface, + teardown_fabric_interface, +) + + +def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout): + if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b: + return True, "Invalid combination" + + if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0): + return True, "Unsupported test case" + + ## Check that we can readback results + fast_dispatch_page_size_limit = 55 * 1024 + elem_size = 2 if input_dtype == ttnn.bfloat16 else 1 + if layout == ttnn.ROW_MAJOR_LAYOUT and (input_shape[dim] * elem_size) > fast_dispatch_page_size_limit: + # Fast dispatch currently can't breakup readback of large pages into multiple smaller pages and is + # limited to ~55K pages. + return True, "Fast dispatch can't support reading back this page size in one shot" + + # Check that we can fit in L1 (if L1 config) + tensor_size_bytes = elem_size + for i in input_shape: + tensor_size_bytes *= i + num_l1_banks = 64 + if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024: + return True, "L1 buffer can't support large tensor sizes" + + # Check that each chip has a non-zero amount of data available + min_sized_chunks_on_dim = input_shape[dim] + if dim == 3: + min_sized_chunks_on_dim //= 32 + if dim == 2: + if layout == ttnn.TILE_LAYOUT: + min_sized_chunks_on_dim //= 32 + if min_sized_chunks_on_dim < num_devices: + return ( + True, + f"Input shape {input_shape} incompatible with {num_devices} on dim {dim} because some chips will have no tensor", + ) + + if input_shape == [8, 8, 256, 384] and dim == 1 and layout == ttnn.TILE_LAYOUT and input_dtype == ttnn.bfloat8_b: + return True, "Known failure" + + return False, "" + + +def run_with_trace( + mesh_device, + all_gather_topology, + input_tensor_mesh, + dim, + num_links, + output_mem_config, + num_iter=20, + subdevice_id=None, +): + # Compile Run + logger.info("Compiling model") + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor_mesh, + dim, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + subdevice_id=subdevice_id, + create_semaphore_handles=True, + ) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + + # Capture trace + logger.info("Capturing trace") + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + for i in range(num_iter): + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor_mesh, + dim, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + subdevice_id=subdevice_id, + create_semaphore_handles=False, + ) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + + # Run the op + logger.info("Starting Trace perf test...") + ttnn.execute_trace(mesh_device, trace_id, blocking=False) + ttnn.release_trace(mesh_device, trace_id) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + + return tt_out_tensor + + +def run_all_gather_impl( + mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + use_program_cache, + function_level_defaults, + all_gather_topology, + num_iters=1, + enable_async=False, + trace_mode=False, + rand_tensor=True, + mem_config=None, + input_shard_shape=None, + shard_grid=None, + tensor_mem_layout=None, + use_cluster_axis_api=False, + cluster_axis=None, + create_persistent_fabric=True, + teardown_persistent_fabric=True, +): + enable_persistent_fabric = True + if num_iters < 1: + pytest.fail("num_iters must be >= 1") + # Use Async mode based on test input config + mesh_device.enable_async(enable_async) + + if enable_async: + logger.info(f"Using Async Mode for All Gather Op Dispatch") + + logger.info(f"Output shape: {output_shape}") + logger.info(f"dim: {dim}") + logger.info(f"input_shard_shape: {input_shard_shape}") + logger.info(f"shard_grid: {shard_grid}") + + ### For sharded all gather only + if bool(input_shard_shape) != bool(shard_grid) and bool(tensor_mem_layout) != bool(shard_grid): + pytest.fail( + "Both input_shard_shape, shard_grid, and tensor_mem_layout must be provided together or all must be None" + ) + if input_shard_shape and shard_grid: + input_shard_spec = ttnn.ShardSpec( + shard_grid, + input_shard_shape, + ttnn.ShardOrientation.ROW_MAJOR, + False, + ) + input_mem_config = ttnn.MemoryConfig( + tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=input_shard_spec + ) + output_shard_shape = list(input_shard_shape) + if dim == 3: + output_shard_shape[1] *= num_devices + else: + output_shard_shape[0] *= num_devices + output_shard_spec = ttnn.ShardSpec( + shard_grid, + output_shard_shape, + ttnn.ShardOrientation.ROW_MAJOR, + False, + ) + output_mem_config = ttnn.MemoryConfig( + tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=output_shard_spec + ) + else: + assert mem_config is not None + input_mem_config = mem_config + output_mem_config = mem_config + ### + + if rand_tensor: + output_tensor = torch.rand(output_shape).bfloat16() + else: + output_tensor = torch.zeros(output_shape) + tile_id = 1 + for w in range(output_shape[0]): + for z in range(output_shape[1]): + for y in range(0, output_shape[2], 32): + for x in range(0, output_shape[3], 32): + output_tensor[w, z, y : y + 32, x : x + 32] = tile_id + tile_id += 1 + + input_tensors = torch.chunk(output_tensor, num_devices, dim) + tt_input_tensors = [] + for i, t in enumerate(input_tensors): + tt_input_tensors.append( + ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], input_mem_config) + ) + logger.info(f"using device {mesh_device.get_devices()[i].id()}") + + input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + + compute_grid_size = mesh_device.compute_with_storage_grid_size() + worker_sub_device = ttnn.SubDevice( + [ + ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1))} + ) + ] + ) + worker_sub_device_id = ttnn.SubDeviceId(0) + if create_persistent_fabric: + mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( + mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric + ) + + if trace_mode: + tt_out_tensor = run_with_trace( + mesh_device, + all_gather_topology, + input_tensor_mesh, + dim, + num_links, + output_mem_config, + num_iter=num_iters, + subdevice_id=worker_sub_device_id, + ) + else: + for i in range(num_iters): + if use_cluster_axis_api: + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor_mesh, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + memory_config=output_mem_config, + topology=all_gather_topology, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + num_preferred_links=num_links, + create_semaphore_handles=True, + ) + + else: + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor_mesh, + dim, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + ) + + logger.info(f"Waiting for op {i}") + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d, sub_device_ids=[worker_sub_device_id]) + logger.info(f"Done iteration {i}") + + if enable_persistent_fabric and teardown_persistent_fabric: + teardown_fabric_interface(mesh_device) + + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + logger.info(f"Checking for device {t.device().id()}") + + if input_dtype == ttnn.bfloat16: + eq, output = comp_equal(tt_output_tensor, output_tensor) + else: + eq, output = comp_pcc(tt_output_tensor, output_tensor) + if not eq: + logger.error(f"output mismatch for tensor {i}") + assert eq, f"{i} FAILED: {output}" + + +# Enumerate the post-commit cases explicitly +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "num_devices, num_links, output_shape, dim, layout", + [ + (4, 1, [1, 1, 64, 512], 3, ttnn.TILE_LAYOUT), + # (4, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), + # (4, 1, [1, 1, 2048, 16384], 3, ttnn.TILE_LAYOUT), + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + ], +) +@pytest.mark.parametrize( + "mem_config", + [ + ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), + ], +) +@pytest.mark.parametrize("num_iters", [1]) +@pytest.mark.parametrize("enable_async", [False]) +def test_all_gather( + t3k_mesh_device, + # pcie_mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + mem_config, + num_iters, + use_program_cache, + function_level_defaults, + enable_async, +): + run_all_gather_impl( + t3k_mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + use_program_cache, + function_level_defaults, + all_gather_topology=ttnn.Topology.Ring, + num_iters=num_iters, + enable_async=enable_async, + rand_tensor=True, + mem_config=mem_config, + ) + + +# Enumerate the post-commit cases explicitly +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "num_devices, output_shape, dim, layout, input_shard_shape, shard_grid, tensor_mem_layout", + [ + ( + 2, + [1, 1, 32, 256], + 3, + ttnn.TILE_LAYOUT, + (32, 32), + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ), + ( + 2, + [1, 1, 32, 256], + 3, + ttnn.TILE_LAYOUT, + (32, 64), + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ), + ( + 2, + [1, 1, 32, 256], + 3, + ttnn.TILE_LAYOUT, + (32, 128), + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0))}), + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ), + ( + 2, + [1, 1, 64, 256], + 2, + ttnn.TILE_LAYOUT, + (32, 128), + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ), + ( + 2, + [1, 4, 32, 256], + 3, + ttnn.TILE_LAYOUT, + (32, 128), + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ), + ], +) +@pytest.mark.parametrize("num_links", [1]) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + ], +) +@pytest.mark.parametrize("num_iters", [1]) +@pytest.mark.parametrize("enable_async", [False]) +def test_all_gather_sharded( + t3k_mesh_device, + # pcie_mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + num_iters, + use_program_cache, + function_level_defaults, + enable_async, + input_shard_shape, + shard_grid, + tensor_mem_layout, +): + run_all_gather_impl( + t3k_mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + use_program_cache, + function_level_defaults, + all_gather_topology=ttnn.Topology.Ring, + num_iters=num_iters, + enable_async=enable_async, + rand_tensor=True, + input_shard_shape=input_shard_shape, + shard_grid=shard_grid, + tensor_mem_layout=tensor_mem_layout, + ) + + +# # Enumerate the post-commit cases explicitly +# @skip_for_grayskull("Requires eth connected devices to run") +# @pytest.mark.parametrize( +# "row_num_devices, row_num_links, row_output_shape, row_gather_dim, row_tensor_mem_layout", +# [ +# (4, 1, [1, 1, 64, 1024], 3, ttnn.TILE_LAYOUT), +# # (4, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), +# # (4, 1, [1, 1, 2048, 16384], 3, ttnn.TILE_LAYOUT), +# ], +# ) +# @pytest.mark.parametrize( +# "col_num_devices, col_num_links, col_output_shape, col_gather_dim, col_tensor_mem_layout", +# [ +# (8, 1, [1, 1, 64, 1024], 3, ttnn.TILE_LAYOUT), +# ], +# ) +# @pytest.mark.parametrize( +# "input_dtype", +# [ +# ttnn.bfloat16, +# ], +# ) +# @pytest.mark.parametrize( +# "mem_config", +# [ +# ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), +# ], +# ) +# @pytest.mark.parametrize("num_iters", [1]) +# @pytest.mark.parametrize("enable_async", [False]) +# def test_back_to_back_row_and_col_all_gathers_on_galaxy_mesh_fabric( +# tg_mesh_device, +# num_devices, +# row_num_devices, row_num_links, row_output_shape, row_gather_dim, row_tensor_mem_layout, +# col_num_devices, col_num_links, col_output_shape, col_gather_dim, col_tensor_mem_layout, +# mem_config, +# input_dtype, +# num_iters, +# use_program_cache, +# function_level_defaults, +# enable_async, +# tensor_mem_layout, +# ): + +# run_all_gather_impl( +# tg_mesh_device, +# row_num_devices, +# row_output_shape, +# row_gather_dim, +# row_num_links, +# input_dtype, +# use_program_cache, +# function_level_defaults, +# mem_config=mem_config, +# all_gather_topology=ttnn.Topology.Linear, +# num_iters=num_iters, +# enable_async=enable_async, +# rand_tensor=True, +# tensor_mem_layout=row_tensor_mem_layout, +# use_cluster_axis_api=True, +# cluster_axis=1, +# create_persistent_fabric=True, +# teardown_persistent_fabric=False +# ) + +# run_all_gather_impl( +# tg_mesh_device, +# col_num_devices, +# col_output_shape, +# col_gather_dim, +# col_num_links, +# input_dtype, +# use_program_cache, +# function_level_defaults, +# mem_config=mem_config, +# all_gather_topology=ttnn.Topology.Linear, +# num_iters=num_iters, +# enable_async=enable_async, +# rand_tensor=True, +# tensor_mem_layout=col_tensor_mem_layout, +# use_cluster_axis_api=True, +# cluster_axis=0, +# create_persistent_fabric=False, +# teardown_persistent_fabric=True +# ) + + +# # Enumerate the post-commit cases explicitly +# @skip_for_grayskull("Requires eth connected devices to run") +# @pytest.mark.parametrize( +# "row_num_devices, row_num_links, row_output_shape, row_gather_dim, row_layout", +# [ +# (4, 1, [1, 1, 64, 1024], 3, ttnn.TILE_LAYOUT), +# ], +# ) +# @pytest.mark.parametrize( +# "col_num_devices, col_num_links, col_output_shape, col_gather_dim, col_layout", +# [ +# (2, 1, [1, 1, 64, 1024], 3, ttnn.TILE_LAYOUT), +# ], +# ) +# @pytest.mark.parametrize( +# "input_dtype", +# [ +# ttnn.bfloat16, +# ], +# ) +# @pytest.mark.parametrize( +# "mem_config", +# [ +# ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), +# ], +# ) + +# @pytest.mark.parametrize("num_iters", [1]) +# @pytest.mark.parametrize("enable_async", [False]) +# def test_back_to_back_row_and_col_all_gathers_on_t3k_mesh_fabric( +# t3k_mesh_device, +# row_num_devices, row_num_links, row_output_shape, row_gather_dim, row_layout, +# col_num_devices, col_num_links, col_output_shape, col_gather_dim, col_layout, +# mem_config, +# input_dtype, +# num_iters, +# use_program_cache, +# function_level_defaults, +# enable_async, +# ): + +# run_all_gather_impl( +# t3k_mesh_device, +# row_num_devices, +# row_output_shape, +# row_gather_dim, +# row_num_links, +# input_dtype, +# row_layout, +# use_program_cache, +# function_level_defaults, +# mem_config=mem_config, +# all_gather_topology=ttnn.Topology.Linear, +# num_iters=num_iters, +# enable_async=enable_async, +# rand_tensor=True, +# use_cluster_axis_api=True, +# cluster_axis=1, +# create_persistent_fabric=True, +# teardown_persistent_fabric=False +# ) + +# run_all_gather_impl( +# t3k_mesh_device, +# col_num_devices, +# col_output_shape, +# col_gather_dim, +# col_num_links, +# input_dtype, +# col_layout, +# use_program_cache, +# function_level_defaults, +# mem_config=mem_config, +# all_gather_topology=ttnn.Topology.Linear, +# num_iters=num_iters, +# enable_async=enable_async, +# rand_tensor=True, +# use_cluster_axis_api=True, +# cluster_axis=0, +# create_persistent_fabric=False, +# teardown_persistent_fabric=True +# ) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py new file mode 100644 index 00000000000..0bb22b6a104 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py @@ -0,0 +1,342 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import ttnn +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc +from models.utility_functions import skip_for_grayskull + + +def create_and_load_sub_device_manager_with_fabric_interface( + mesh_device, worker_sub_devices, ccl_worker_sub_device_id, local_allocator_size, enable_persistent_fabric=True +): + assert ccl_worker_sub_device_id < len(worker_sub_devices) + mesh_sub_device_manager_id, fabric_subdevice_id = mesh_device.create_sub_device_manager_with_fabric( + worker_sub_devices, local_allocator_size + ) + # fabric sub-device id can also be queried from device, no need to explicitly pass it in + mesh_device.load_sub_device_manager(mesh_sub_device_manager_id) + if enable_persistent_fabric: + ttnn.initialize_edm_fabric(mesh_device) + return mesh_sub_device_manager_id + + +def teardown_fabric_interface(mesh_device): + ttnn.teardown_edm_fabric(mesh_device) + for device_id in mesh_device.get_device_ids(): + ttnn.synchronize_device(mesh_device.get_device(device_id)) + + +def is_unsupported_case(input_shape, dim, math_op, mem_config, num_devices, num_links, input_dtype, layout): + elem_size = 2 if input_dtype == ttnn.bfloat16 else 1 + tensor_size_bytes = elem_size + for i in input_shape: + tensor_size_bytes *= i + num_l1_banks = 64 + if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024: + return True, "L1 buffer can't support large tensor sizes" + + # if input_dtype == ttnn.bfloat8_b and tuple(input_shape) == (1, 1, 2048, 1024) and dim == 3: + # return True, "Known failure with bfp8_b data format" + + return False, "" + + +def run_with_trace( + t3k_mesh_device, + input_tensor_mesh, + dim, + num_links, + math_op, + output_mem_config, + num_iters=40, + topology=ttnn.Topology.Ring, + subdevice_id=None, +): + # Compile Run + logger.info("Compiling model") + output_tensor_mesh = ttnn.reduce_scatter_async( + input_tensor_mesh, + dim=dim, + math_op=math_op, + num_links=num_links, + memory_config=output_mem_config, + topology=topology, + subdevice_id=subdevice_id, + create_semaphore_handles=True, + ) + for device_id in t3k_mesh_device.get_device_ids(): + ttnn.synchronize_device(t3k_mesh_device.get_device(device_id)) + + # Capture trace + logger.info("Capturing trace") + trace_id = ttnn.begin_trace_capture(t3k_mesh_device, cq_id=0) + for i in range(num_iters): + output_tensor_mesh = ttnn.reduce_scatter_async( + input_tensor_mesh, + dim=dim, + math_op=math_op, + num_links=num_links, + memory_config=output_mem_config, + topology=topology, + subdevice_id=subdevice_id, + create_semaphore_handles=False, + ) + ttnn.end_trace_capture(t3k_mesh_device, trace_id, cq_id=0) + for device_id in t3k_mesh_device.get_device_ids(): + ttnn.synchronize_device(t3k_mesh_device.get_device(device_id)) + + # Run the op + logger.info("Starting Trace perf test...") + ttnn.execute_trace(t3k_mesh_device, trace_id, blocking=False) + ttnn.release_trace(t3k_mesh_device, trace_id) + for device_id in t3k_mesh_device.get_device_ids(): + ttnn.synchronize_device(t3k_mesh_device.get_device(device_id)) + + return output_tensor_mesh + + +def run_reduce_scatter_test( + mesh_device, + num_devices, + per_chip_output_shape, + dim, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + enable_async=True, + num_iters=1, + topology=ttnn.Topology.Ring, + trace_mode=False, +): + enable_persistent_fabric = True + if len(mesh_device.get_device_ids()) < num_devices: + pytest.skip( + f"Not enough devices on machine to implement test case. Wanted {num_devices} but found {len(mesh_device.get_device_ids())}" + ) + + debug = False + + (is_known_failure, message) = is_unsupported_case( + per_chip_output_shape, dim, math_op, mem_config, num_devices, num_links, input_dtype, layout + ) + if is_known_failure: + pytest.skip(f"Skipping unsupported case {message}.") + + mesh_device.enable_async(enable_async) + if enable_async: + logger.info(f"Using Async Mode for Reduce Scatter Op Dispatch") + + logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, dim: {dim}") + + # Generate input tensors + canonical_input_shape = per_chip_output_shape.copy() + canonical_input_shape[dim] *= num_devices + tt_input_tensors = [] + + numel = canonical_input_shape[0] * canonical_input_shape[1] * canonical_input_shape[2] * canonical_input_shape[3] + input_tensors = [ + torch.rand(canonical_input_shape).bfloat16() if not debug else torch.ones(canonical_input_shape).bfloat16() + for _ in range(num_devices) + ] + if debug: + tile_id = 0 + for w in range(input_tensors[-1].shape[0]): + for z in range(input_tensors[-1].shape[1]): + for y in range(0, input_tensors[-1].shape[2], 32): + for x in range(0, input_tensors[-1].shape[3], 32): + for yy in range(32): + for xx in range(32): + input_tensors[-1][w, z, y + yy, x + xx] = tile_id + # input_tensors[-1][w,z,y:y+32,x:x+32] = tile_id + tile_id += 1 + for i, canonical_input_tensor in enumerate(input_tensors): + logger.info(f"Creating input tensor on device {mesh_device.get_device_ids()[i]}") + tt_input_tensors.append( + ttnn.Tensor(canonical_input_tensor, input_dtype) + .to(layout) + .to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config) + ) + + assert len(tt_input_tensors) == num_devices + + input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + + compute_grid_size = mesh_device.compute_with_storage_grid_size() + worker_sub_device = ttnn.SubDevice( + [ + ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1))} + ) + ] + ) + worker_sub_device_id = ttnn.SubDeviceId(0) + mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( + mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric + ) + + # Run the op + if trace_mode: + output_tensor_mesh = run_with_trace( + mesh_device, + input_tensor_mesh, + dim, + num_links, + math_op, + mem_config, + num_iters=num_iters, + topology=topology, + subdevice_id=ttnn.SubDeviceId(0), + ) + else: + for i in range(num_iters): + output_tensor_mesh = ttnn.reduce_scatter_async( + input_tensor_mesh, + dim=dim, + math_op=math_op, + num_links=num_links, + memory_config=mem_config, + topology=topology, + subdevice_id=worker_sub_device_id, + ) + + logger.info(f"Waiting for op {i}") + for device_id in mesh_device.get_device_ids(): + ttnn.synchronize_device(mesh_device.get_device(device_id), sub_device_ids=[worker_sub_device_id]) + logger.info(f"Done iteration {i}") + + teardown_fabric_interface(mesh_device) + # Compute golden + # TODO: Make it model how reduce scatter actually works for numerical correctness/ordering + golden_canonical_out_tensor = torch.zeros(canonical_input_shape).bfloat16() + for i, t in enumerate(input_tensors): + golden_canonical_out_tensor = torch.add(golden_canonical_out_tensor, t).bfloat16() + + golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, dim) + + tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh) + logger.info(f"Compare") + # Compare + assert len(golden_output_tensors) == len(tt_out_tensors) + mismatch = False + for i, t in enumerate(tt_out_tensors): + logger.info(f"DEVICE {i}") + logger.info(f"Checking output from device {t.device().id()}") + tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + eq, output = comp_pcc(tt_output_tensor, golden_output_tensors[i]) + mismatch = mismatch or not eq + if not eq: + logger.error(f"output mismatch for tensor {i}. Mesh device ID: {mesh_device.get_devices()[i].id()}") + if debug: + logger.info(f"FINAL OUTPUT TENSOR {tt_output_tensor}") + mismatch_tensor_shape = [ + tt_output_tensor.shape[0], + tt_output_tensor.shape[1], + tt_output_tensor.shape[2] // 32, + tt_output_tensor.shape[3] // 32, + ] + mismatch_tensor = torch.zeros(mismatch_tensor_shape).bfloat16() + for w in range(tt_output_tensor.shape[0]): + for z in range(tt_output_tensor.shape[1]): + for y in range(0, tt_output_tensor.shape[2], 32): + for x in range(0, tt_output_tensor.shape[3], 32): + if tt_output_tensor[w, z, y, x] != golden_output_tensors[i][w, z, y, x]: + mismatch_tensor[w, z, y // 32, x // 32] = 1 + logger.error( + f"mismatch at {w}, {z}, {y}, {x}: {tt_output_tensor[w, z, y, x]} != {golden_output_tensors[i][w, z, y, x]}" + ) + logger.error(f"MISMATCH TENSOR {mismatch_tensor}") + + else: + logger.info(f"output match for tensor {i}") + assert not mismatch, f"{i} FAILED: {output}" + + +# ~2:45 extra time in the current state +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + "num_devices, num_links", + [ + (4, 1), + ], +) +@pytest.mark.parametrize( + "per_chip_output_shape, dim, layout", + [ + # ([1, 1, 32, 32], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 32, 32], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 32, 32 * 2], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 64, 32], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 64, 64], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 128, 128], 0, ttnn.TILE_LAYOUT), + # ([1, 1, 128, 128], 1, ttnn.TILE_LAYOUT), + # ([1, 1, 128, 128], 2, ttnn.TILE_LAYOUT), + # ([1, 1, 128, 128], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 32, 32], 2, ttnn.TILE_LAYOUT), + # ([1, 1, 32, 64], 2, ttnn.TILE_LAYOUT), + # ([1, 1, 32, 32 * 4], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 128, 4096], 3, ttnn.TILE_LAYOUT), + ([1, 4, 32, 2304], 2, ttnn.TILE_LAYOUT), + # ([1, 2, 224, 32 * 8], 3, ttnn.TILE_LAYOUT), + # ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT), + # ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT), + # ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT), + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + ], +) +@pytest.mark.parametrize( + "mem_config", + [ + ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), + ], +) +@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) +@pytest.mark.parametrize("enable_async", [False]) +@pytest.mark.parametrize("trace_mode", [False]) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 27648}], indirect=True) +def test_line_reduce_scatter_async_post_commit( + t3k_mesh_device, + num_devices, + per_chip_output_shape, + dim, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + enable_async, + trace_mode, + num_iters=1, +): + run_reduce_scatter_test( + t3k_mesh_device, + num_devices, + per_chip_output_shape, + dim, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + num_iters=num_iters, + enable_async=enable_async, + topology=ttnn.Topology.Linear, + trace_mode=trace_mode, + ) diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index b2bcb41671b..b4522d99b5b 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -11,9 +11,7 @@ #include "pybind11/operations/core.hpp" #include "pybind11/operations/creation.hpp" #include "ttnn/operations/bernoulli/bernoulli_pybind.hpp" -#include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp" -#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp" -#include "ttnn/operations/ccl/barrier/barrier_pybind.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp" #include "ttnn/operations/conv/conv_pybind.hpp" #include "ttnn/operations/data_movement/data_movement_pybind.hpp" #include "ttnn/operations/eltwise/binary/binary_pybind.hpp" @@ -93,9 +91,7 @@ void py_module(py::module& module) { complex_unary_backward::py_module(m_complex_unary_backward); auto m_ccl = module.def_submodule("ccl", "collective communication operations"); - ccl::py_bind_all_gather(m_ccl); - ccl::py_bind_reduce_scatter(m_ccl); - ccl::py_bind_barrier(m_ccl); + ccl::py_module(m_ccl); auto m_creation = module.def_submodule("creation", "creation operations"); creation::py_module(m_creation); diff --git a/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt index 148d928be91..37b365f248a 100644 --- a/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt +++ b/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt @@ -6,7 +6,11 @@ set(CCL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/ccl_host_datastructures.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/types/ccl_types_args_emitters.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_command.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_host_commands.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/host/ccl_worker_builder.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/host/ccl_command_stream_builders.cpp # CCL Ops + ${CMAKE_CURRENT_SOURCE_DIR}/ccl_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/all_gather_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/device/all_gather_op.cpp diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp index 210ac9d243a..b8404b749aa 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp @@ -9,7 +9,6 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/ccl/all_gather/all_gather.hpp" -#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "ttnn/distributed/types.hpp" namespace ttnn::operations::ccl { @@ -18,10 +17,6 @@ namespace detail { template void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation, const char* doc) { - py::enum_(module, "Topology") - .value("Ring", ttnn::ccl::Topology::Ring) - .value("Linear", ttnn::ccl::Topology::Linear); - bind_registered_operation( module, operation, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp index 75606f8b23c..675de4fef24 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp @@ -8,6 +8,9 @@ #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" using ttnn::ccl::ShardType; using ttnn::ccl::UNINITIALIZED_VALUE_U16; @@ -39,7 +42,6 @@ FORCE_INLINE void write_and_send_chunk_write_to_tensor_segment( uint64_t dst_noc_addr = get_noc_addr(output_page_idx, d); noc_async_write(l1_read_addr, dst_noc_addr, page_size); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset] = d.get_page_location(output_page_idx); uint64_t dst_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, d.bank_base_address + (page_offset * d.page_size) + 0); ASSERT(false); // untested && unimplemented @@ -54,8 +56,6 @@ FORCE_INLINE void write_and_send_chunk_write_to_tensor_segment( #ifdef INTERLEAVED_MEM_LAYOUT noc_async_write_tile(output_page_idx, d, l1_read_addr); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device - // auto const&[noc_yx, page_offset] = d.get_page_location(output_page_idx); auto [noc_yx, page_offset, contig_pages_] = d.get_page_location_with_contiguous_pages_in_row_in_bank(output_page_idx); contig_pages = std::min(pages_remaining, std::min(contig_pages_, num_cols - col_idx)); uint64_t dst_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, d.bank_base_address + (page_offset * d.page_size) + 0); @@ -173,7 +173,6 @@ FORCE_INLINE void write_chunk( uint64_t dst_noc_addr = get_noc_addr(output_page_idx, d); noc_async_write(l1_read_addr, dst_noc_addr, page_size); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset] = d.get_page_location(output_page_idx); uint64_t dst_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, d.bank_base_address + (page_offset * d.page_size) + 0); ASSERT(false); // untested && unimplemented @@ -236,7 +235,6 @@ FORCE_INLINE void read_chunk_from_input_tensor( #ifdef INTERLEAVED_MEM_LAYOUT noc_async_read_tile(input_page_idx, s, local_l1_read_addr); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset, contig_pages_] = s.get_page_location_with_contiguous_pages_in_row_in_bank(input_page_idx); contig_pages = std::min(pages_remaining, contig_pages_); uint64_t src_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), static_cast(noc_yx.noc_y), s.bank_base_address + (page_offset * s.page_size) + 0); @@ -274,7 +272,6 @@ FORCE_INLINE void read_chunk_from_output_tensor( uint64_t src_noc_addr = get_noc_addr(input_page_idx, s); noc_async_read(src_noc_addr, local_l1_read_addr, page_size); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset] = s.get_page_location(input_page_idx); uint64_t src_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, s.bank_base_address + (page_offset * s.page_size) + 0); ASSERT(false); // unimplemented @@ -290,7 +287,6 @@ FORCE_INLINE void read_chunk_from_output_tensor( #ifdef INTERLEAVED_MEM_LAYOUT noc_async_read_tile(input_page_idx, s, local_l1_read_addr); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto [noc_yx, page_offset, contig_pages_] = s.get_page_location_with_contiguous_pages_in_row_in_bank(input_page_idx); contig_pages = std::min(pages_remaining, std::min(contig_pages_, num_cols - col_idx)); uint64_t src_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), static_cast(noc_yx.noc_y), s.bank_base_address + (page_offset * s.page_size) + 0); @@ -345,7 +341,6 @@ FORCE_INLINE void read_chunk_from_output_tensor_v2( #ifdef INTERLEAVED_MEM_LAYOUT noc_async_read_tile(curr_page_idx, s, local_l1_read_addr); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset] = s.get_page_location(curr_page_idx); uint64_t src_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, s.bank_base_address + (page_offset * s.page_size) + 0); noc_async_read(src_noc_addr, local_l1_read_addr, page_size); @@ -401,7 +396,6 @@ FORCE_INLINE void write_chunk_v2( #ifdef INTERLEAVED_MEM_LAYOUT noc_async_write_tile(curr_page_idx, d, l1_read_addr); #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset] = d.get_page_location(curr_page_idx); uint64_t dst_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, d.bank_base_address + (page_offset * d.page_size) + 0); noc_async_write(l1_read_addr, dst_noc_addr, page_size); @@ -463,7 +457,6 @@ FORCE_INLINE void read_wrapped_chunk_from_output_tensor_to_address( noc_async_read_tile(curr_page_idx, s, local_l1_read_addr); // common with `write_chunk_v2` #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset, contig_pages_] = s.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); /* * num_pages - i: check if we are outside the number of pages remaining @@ -531,11 +524,122 @@ FORCE_INLINE void read_wrapped_chunk_from_output_tensor( cb_push_back(cb_id, num_pages); } + + +template +FORCE_INLINE std::pair get_noc_addr_and_contiguous_pages( + uint32_t curr_page_idx, + const uint32_t offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const AddrGen& address_generator, + const ttnn::ccl::Shape4D& tensor_slice_shape, + uint8_t noc_id = noc_index) { + if constexpr (TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + static constexpr uint32_t offset = 0; + uint64_t dst_noc_addr = get_noc_addr(curr_page_idx, address_generator, offset, noc_id); + return {dst_noc_addr, 1}; + } else { + static_assert( + TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED || + TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED); + if constexpr (MEM_LAYOUT == tt::tt_metal::Layout::ROW_MAJOR) { + ASSERT(false); // unimplemented + return {0, 1}; + } else { + static_assert(MEM_LAYOUT == tt::tt_metal::Layout::TILE); + auto const&[noc_yx, page_offset, contig_pages_] = address_generator.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); + /* + * Shared with `read_wrapped_chunk_from_output_tensor` + */ + uint32_t flattened_offset_worker_slice = ttnn::ccl::v2::flattened_index(tensor_slice_shape, offset_worker_slice); + uint32_t contig_until_edge_of_tensor_slice = tensor_slice_shape.x - ((flattened_offset_worker_slice + offset_into_worker_slice) % tensor_slice_shape.x); + + size_t contig_pages = std::min(contig_pages_, contig_until_edge_of_tensor_slice); + uint64_t dst_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, address_generator.bank_base_address + (page_offset * address_generator.page_size) + 0, noc_id); + return {dst_noc_addr, contig_pages}; + } + } +} + +template +FORCE_INLINE std::pair get_noc_addr_and_contiguous_pages_for_fabric_write( + uint32_t curr_page_idx, + const uint32_t offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const AddrGen& address_generator, + const ttnn::ccl::Shape4D& tensor_slice_shape) { + return get_noc_addr_and_contiguous_pages( + curr_page_idx, offset_into_worker_slice, offset_worker_slice, address_generator, tensor_slice_shape, 0); +} + +namespace v2 { + template +FORCE_INLINE void write_wrapped_chunk( + uint32_t& curr_page_idx, + uint32_t& offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const ttnn::ccl::Shape4D& worker_slice_shape, + + // In tiles for tile layout + const ttnn::ccl::Shape4D& tensor_shape, + const ttnn::ccl::Shape4D& tensor_slice_shape, + uint32_t cb_id, + const AddrGen& d, + const uint32_t num_pages, + const uint32_t page_size, + bool& last_page_of_worker) { + + uint32_t l1_read_addr = get_read_ptr(cb_id); + + int32_t contig_pages = 1; + for (uint32_t i = 0; i < num_pages; i+= contig_pages) { + contig_pages = 1; + if constexpr (MEM_LAYOUT == tt::tt_metal::Layout::ROW_MAJOR) { + if constexpr (TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + uint64_t dst_noc_addr = get_noc_addr(curr_page_idx, d); + noc_async_write(l1_read_addr, dst_noc_addr, page_size); + ASSERT(false); // unimplemented + } else { + ASSERT(false); // unimplemented + } + } else if constexpr (MEM_LAYOUT == tt::tt_metal::Layout::TILE) { + if constexpr (TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + noc_async_write_tile(curr_page_idx, d, l1_read_addr); + } else { + auto const&[noc_yx, page_offset, contig_pages_] = d.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); + /* + * Shared with `read_wrapped_chunk_from_output_tensor` + */ + uint32_t flattened_offset_worker_slice = ttnn::ccl::v2::flattened_index(tensor_slice_shape, offset_worker_slice); + uint32_t contig_edge_of_tensor_slice = tensor_slice_shape.x - ((flattened_offset_worker_slice + offset_into_worker_slice) % tensor_slice_shape.x); + + contig_pages = std::min(num_pages - i, std::min(contig_pages_, contig_edge_of_tensor_slice)); + uint64_t dst_noc_addr = get_noc_addr(static_cast(noc_yx.noc_x), noc_yx.noc_y, d.bank_base_address + (page_offset * d.page_size) + 0); + noc_async_write(l1_read_addr, dst_noc_addr, page_size * contig_pages); + } + } + // Update the curr_page_idx based on how the worker chunks + tensor slice is laid out in global tensor + bool end_of_worker_slice_row = ttnn::ccl::v2::advance_worker_global_page( + curr_page_idx, // Updated internally + offset_into_worker_slice, + offset_worker_slice, + worker_slice_shape, + tensor_slice_shape, + tensor_shape, + contig_pages + ); + + l1_read_addr += page_size * contig_pages; + } +} +} + template FORCE_INLINE void write_wrapped_chunk( uint32_t& curr_page_idx, uint32_t& offset_into_worker_slice, - ttnn::ccl::coord_t& offset_worker_slice, + const ttnn::ccl::coord_t& offset_worker_slice, const ttnn::ccl::coord_t& worker_slice_shape, // In tiles for tile layout @@ -567,7 +671,6 @@ FORCE_INLINE void write_wrapped_chunk( noc_async_write_tile(curr_page_idx, d, l1_read_addr); // Common with `read_chunk_from_output_tensor_v2` #elif defined SHARDED_MEM_LAYOUT - // TODO: Make d.get_noc_addr work on host + device auto const&[noc_yx, page_offset, contig_pages_] = d.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); /* * Shared with `read_wrapped_chunk_from_output_tensor` diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 85b94a0505f..ed896ea266e 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" #include #include @@ -13,27 +13,66 @@ namespace ttnn { namespace ccl { -std::tuple, std::optional> get_device_index_and_sender_receiver_ids( - const Tensor& input_tensor, - const std::vector& devices, - const ttnn::ccl::Topology& topology) { +void SyncModeSpec::add_signal(uint32_t sem_id, uint32_t wait_count) { + this->sem_ids.push_back(sem_id); + this->wait_counts.push_back(wait_count); + this->num_signals++; +} + +LineTopology::LineTopology(size_t line_size, size_t line_index) : _line_size(line_size), _line_index(line_index) {} + +bool LineTopology::is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction direction) const { + if (direction == ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD) { + return _line_index == 0; + } else { + TT_ASSERT(direction == ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); + return _line_index == _line_size - 1; + } +} +bool LineTopology::is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction direction) const { + if (direction == ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) { + return _line_index == 0; + } else { + TT_ASSERT(direction == ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); + return _line_index == _line_size - 1; + } +} + +bool LineTopology::is_at_end_of_line() const { return _line_index == 0 || _line_index == _line_size - 1; } + +size_t LineTopology::line_size() const { return _line_size; } + +size_t LineTopology::line_index() const { return _line_index; } + +size_t LineTopology::get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction direction) const { + if (direction == ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD) { + return (_line_size - _line_index) - 1; + } else { + return _line_index; + } +} + +ttnn::ccl::Topology LineTopology::topology() const { return ttnn::ccl::Topology::Linear; } +std::tuple, std::optional> get_device_index_and_sender_receiver_ids( + const Tensor& input_tensor, const std::vector& devices, const ttnn::ccl::Topology& topology) { uint32_t num_devices = devices.size(); bool is_linear = topology == ttnn::ccl::Topology::Linear; - uint32_t device_index = 0; // Initialize device index + uint32_t device_index = 0; // Initialize device index for (uint32_t i = 0; i < num_devices; ++i) { if (devices.at(i) == input_tensor.device()) { device_index = i; bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1); bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0; - std::optional receiver_device_id = is_last_chip_in_clockwise_direction ? - std::nullopt : - std::optional(devices.at((i + 1) % num_devices)->id()); + std::optional receiver_device_id = + is_last_chip_in_clockwise_direction ? std::nullopt + : std::optional(devices.at((i + 1) % num_devices)->id()); - std::optional sender_device_id = is_last_chip_in_counter_clockwise_direction ? - std::nullopt : - std::optional(devices.at((i + num_devices - 1) % num_devices)->id()); + std::optional sender_device_id = + is_last_chip_in_counter_clockwise_direction + ? std::nullopt + : std::optional(devices.at((i + num_devices - 1) % num_devices)->id()); return {device_index, sender_device_id, receiver_device_id}; } @@ -50,7 +89,11 @@ RingTopology::RingTopology( uint32_t num_links, uint32_t ring_size, uint32_t ring_index) : - device(device), num_links(num_links), ring_size(ring_size), ring_index(ring_index), is_linear(topology == Topology::Linear) { + device(device), + num_links(num_links), + ring_size(ring_size), + ring_index(ring_index), + is_linear(topology == Topology::Linear) { eth_sender_cores.reserve(num_links); eth_receiver_cores.reserve(num_links); @@ -71,8 +114,7 @@ RingTopology::RingTopology( auto const& sockets = device->get_ethernet_sockets(receiver_device); auto eth_sender_core = sockets.at(sender_socket_idx); eth_sender_cores.push_back(eth_sender_core); - log_trace( - tt::LogOp, "\teth_sender_core on link {}: (x={},y={})", l, eth_sender_core.x, eth_sender_core.y); + log_trace(tt::LogOp, "\teth_sender_core on link {}: (x={},y={})", l, eth_sender_core.x, eth_sender_core.y); } if (!is_linear || ring_index != 0) { uint32_t sender_device = sender_device_id.value(); @@ -80,11 +122,7 @@ RingTopology::RingTopology( auto eth_receiver_core = sockets.at(receiver_socket_idx); eth_receiver_cores.push_back(eth_receiver_core); log_trace( - tt::LogOp, - "\teth_receiver_core on link {}: (x={},y={})", - l, - eth_receiver_core.x, - eth_receiver_core.y); + tt::LogOp, "\teth_receiver_core on link {}: (x={},y={})", l, eth_receiver_core.x, eth_receiver_core.y); } if (receiver_device_id == sender_device_id) { @@ -111,10 +149,10 @@ CclOpTensorConfig::CclOpTensorConfig(Tensor const& tensor) : df(tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype())) { if (tensor.get_layout() == Layout::TILE) { this->tile = tensor.get_tensor_spec().tile(); - this->page_size =this->tile.get_tile_size(this->df); + this->page_size = this->tile.get_tile_size(this->df); this->tile_size = this->tile.get_tile_hw(); } else { - this->tile = Tile({32,32}); + this->tile = Tile({32, 32}); this->page_size = tensor.buffer()->page_size(); this->tile_size = 1024; } @@ -125,16 +163,14 @@ Tile CclOpTensorConfig::get_tile() const { return this->tile; } uint32_t CclOpTensorConfig::get_buffer_start_address() const { return this->buffer_start_address; } - -CclOpInterleavedTensorConfig::CclOpInterleavedTensorConfig(Tensor const& input_tensor) : CclOpTensorConfig(input_tensor) {} - +CclOpInterleavedTensorConfig::CclOpInterleavedTensorConfig(Tensor const& input_tensor) : + CclOpTensorConfig(input_tensor) {} CclOpShardedTensorConfig::CclOpShardedTensorConfig(Tensor const& tensor) : CclOpTensorConfig(tensor), shard_spec(tensor.shard_spec().value()) {} ShardSpec const& CclOpShardedTensorConfig::get_shard_spec() const { return this->shard_spec; } - std::unique_ptr CclOpTensorConfig::build_all_gather_tensor_config(Tensor const& tensor) { if (tensor.is_sharded()) { return std::make_unique(tensor); @@ -143,11 +179,8 @@ std::unique_ptr CclOpTensorConfig::build_all_gather_tensor_co } } - - - void generate_edm_kernels_for_ring_or_linear_topology( - tt::tt_metal::Program& program, + tt::tt_metal::Program& program, Device const* device, RingTopology const& topology_config, std::vector const& clockwise_edm_builders, @@ -275,14 +308,18 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); TT_ASSERT(edm_sem_addr > 0); TT_ASSERT(edm_buffer_addr > 0); - const uint32_t channel_buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, num_buffers_per_channel, page_size); + const uint32_t channel_buffer_size = + ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, num_buffers_per_channel, page_size); for (std::size_t c = 0; c < num_channels; ++c) { edm_sem_addresses.at(c) = edm_sem_addr; edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; TT_ASSERT(edm_buffer_addr % EriscDatamoverConfig::get_eth_word_size() == 0); edm_buffer_addresses.at(c) = edm_buffer_addr; log_trace(tt::LogOp, " edm_buffer_addresses({}) = {}", c, edm_buffer_addr); - edm_buffer_addr += num_buffers_per_channel * (channel_buffer_size + (ccl::EriscDatamoverConfig::enable_merged_payload_and_channel_sync ? ccl::EriscDatamoverConfig::get_eth_channel_sync_size_bytes() : 0)); + edm_buffer_addr += num_buffers_per_channel * + (channel_buffer_size + (ccl::EriscDatamoverConfig::enable_merged_payload_and_channel_sync + ? ccl::EriscDatamoverConfig::get_eth_channel_sync_size_bytes() + : 0)); TT_ASSERT((c == 0) || (edm_buffer_addresses.back() != edm_buffer_addresses.front())); TT_ASSERT((c == 0) || (edm_sem_addresses.back() != edm_sem_addresses.front())); } @@ -323,14 +360,15 @@ RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensor this->num_rows = std::accumulate(input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()); this->row_offset = - std::accumulate( - output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - + std::accumulate(output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - num_rows; } else { auto input_tile = input_tensor.get_tensor_spec().tile(); const uint32_t num_tiles_x = input_tensor.get_legacy_shape()[-1] / input_tile.get_width(); uint32_t num_tiles_y = (input_tensor.get_legacy_shape()[-2] / input_tile.get_height()); - for (std::size_t i = 0; input_tensor.get_legacy_shape().rank() > 2 && i < input_tensor.get_legacy_shape().rank() - 2; i++) { + for (std::size_t i = 0; + input_tensor.get_legacy_shape().rank() > 2 && i < input_tensor.get_legacy_shape().rank() - 2; + i++) { num_tiles_y *= input_tensor.get_legacy_shape()[i]; } TT_ASSERT(num_tiles_x >= ring_size); @@ -374,11 +412,12 @@ RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensor this->flattened_tensor_shape = tt_xy_pair{ input_tensor.get_legacy_shape()[3] / input_tile.get_width(), (input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * - input_tensor.get_legacy_shape()[2]) / - input_tile.get_height()}; + input_tensor.get_legacy_shape()[2]) / + input_tile.get_height()}; } - this->worker_slice_offsets = DERIVED_SLICER_T::compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); + this->worker_slice_offsets = + DERIVED_SLICER_T::compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); TT_ASSERT(this->worker_slice_offsets.size() == this->worker_slice_shapes.size()); } @@ -390,10 +429,16 @@ RingReduceScatterTensorSlicer::RingReduceScatterTensorSlicer( uint32_t ring_size, uint32_t total_num_workers, uint32_t max_slice_size_in_bytes, - uint32_t half_cb_n_pages): - RingReduceScatterBaseTensorSlicer - (input_tensor, output_tensor, slice_dim, ring_index, ring_size, total_num_workers, max_slice_size_in_bytes, half_cb_n_pages) {}; - + uint32_t half_cb_n_pages) : + RingReduceScatterBaseTensorSlicer( + input_tensor, + output_tensor, + slice_dim, + ring_index, + ring_size, + total_num_workers, + max_slice_size_in_bytes, + half_cb_n_pages){}; RingReduceScatterWrappedTensorSlicer::RingReduceScatterWrappedTensorSlicer( Tensor const& input_tensor, @@ -403,9 +448,16 @@ RingReduceScatterWrappedTensorSlicer::RingReduceScatterWrappedTensorSlicer( uint32_t ring_size, uint32_t total_num_workers, uint32_t max_slice_size_in_bytes, - uint32_t half_cb_n_pages): - RingReduceScatterBaseTensorSlicer - (input_tensor, output_tensor, slice_dim, ring_index, ring_size, total_num_workers, max_slice_size_in_bytes, half_cb_n_pages) {}; + uint32_t half_cb_n_pages) : + RingReduceScatterBaseTensorSlicer( + input_tensor, + output_tensor, + slice_dim, + ring_index, + ring_size, + total_num_workers, + max_slice_size_in_bytes, + half_cb_n_pages){}; std::vector RingReduceScatterTensorSlicer::compute_worker_slice_offsets( std::vector const& worker_slice_shapes, tt_xy_pair const& tensor_slice_shape) { @@ -444,7 +496,6 @@ static std::vector compute_worker_slice_offsets_for_wrapped_tensor_s std::uint32_t flattened_idx = 0; for (tt_xy_pair const& worker_slice_shape : worker_slice_shapes) { - // Convert from flat to (x, y) coordinates std::size_t offset_x = flattened_idx % tensor_slice_shape.x; std::size_t offset_y = flattened_idx / tensor_slice_shape.x; @@ -462,11 +513,12 @@ static std::vector compute_worker_slice_offsets_for_wrapped_tensor_s std::vector RingReduceScatterWrappedTensorSlicer::compute_worker_slice_offsets( std::vector const& worker_slice_shapes, tt_xy_pair const& tensor_slice_shape) { - return compute_worker_slice_offsets_for_wrapped_tensor_slicer(worker_slice_shapes, tensor_slice_shape); + return compute_worker_slice_offsets_for_wrapped_tensor_slicer(worker_slice_shapes, tensor_slice_shape); } template -std::vector RingReduceScatterBaseTensorSlicer::create_worker_slice_shapes_for_row_major_layout( +std::vector +RingReduceScatterBaseTensorSlicer::create_worker_slice_shapes_for_row_major_layout( tt_xy_pair const& tensor_slice_shape_in_elems, uint32_t num_workers, uint32_t max_slice_size_in_elements) { std::vector worker_slice_shapes; worker_slice_shapes.reserve(num_workers); @@ -490,8 +542,7 @@ std::vector RingReduceScatterBaseTensorSlicer::cre // For now we don't support row splitting but we will in the future const uint32_t min_rows_per_worker = tensor_slice_shape_in_elems.y / num_workers; const uint32_t num_workers_with_max_rows = tensor_slice_shape_in_elems.y % num_workers; - const uint32_t max_rows_per_worker = - num_workers_with_max_rows != 0 ? min_rows_per_worker + 1 : min_rows_per_worker; + const uint32_t max_rows_per_worker = num_workers_with_max_rows != 0 ? min_rows_per_worker + 1 : min_rows_per_worker; for (uint32_t w = 0; w < num_workers_with_max_rows; w++) { worker_slice_shapes.emplace_back(tensor_slice_shape_in_elems.x, max_rows_per_worker); num_elems_accounted_for += tensor_slice_shape_in_elems.x * max_rows_per_worker; @@ -510,12 +561,11 @@ std::vector RingReduceScatterBaseTensorSlicer::cre } std::vector RingReduceScatterTensorSlicer::create_worker_slice_shapes_for_tile_layout( - tt::tt_metal::LegacyShape const& tensor_shape, - tt_xy_pair const& tensor_slice_shape_in_tiles, - uint32_t num_workers, - uint32_t max_slice_size_in_pages, - uint32_t half_cb_n_pages) -{ + tt::tt_metal::LegacyShape const& tensor_shape, + tt_xy_pair const& tensor_slice_shape_in_tiles, + uint32_t num_workers, + uint32_t max_slice_size_in_pages, + uint32_t half_cb_n_pages) { log_trace(tt::LogOp, "\tmax_slice_size_in_pages={}", max_slice_size_in_pages); TT_ASSERT(max_slice_size_in_pages > 0); std::vector worker_slice_shapes; @@ -538,8 +588,6 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape std::size_t max_slice_size_in_tiles = max_slice_size_in_pages; // Add padding for filler pages - - TT_ASSERT(max_slice_size_in_tiles > 0); std::size_t max_width_in_tiles = std::min(max_slice_size_in_tiles, tensor_slice_shape_in_tiles.x); std::size_t max_height_in_tiles = std::min(max_slice_size_in_tiles, tensor_slice_shape_in_tiles.y); @@ -586,13 +634,11 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape // 3. Row with min num workers and max columns wide per worker (first part of rows with min num workers) // 4. Row with min num workers and min columns wide per worker (second part of rows with min num workers) // Depending on specific numbers, some of the above "quadrants" might be 0 sized - const uint32_t max_workers_row_min_cols_per_worker = - tensor_slice_shape_in_tiles.x / max_num_workers_per_row; - const uint32_t max_workers_row_max_col_worker_count = - tensor_slice_shape_in_tiles.x % max_num_workers_per_row; + const uint32_t max_workers_row_min_cols_per_worker = tensor_slice_shape_in_tiles.x / max_num_workers_per_row; + const uint32_t max_workers_row_max_col_worker_count = tensor_slice_shape_in_tiles.x % max_num_workers_per_row; const uint32_t max_workers_row_max_cols_per_worker = max_workers_row_max_col_worker_count != 0 - ? max_workers_row_min_cols_per_worker + 1 - : max_workers_row_min_cols_per_worker; + ? max_workers_row_min_cols_per_worker + 1 + : max_workers_row_min_cols_per_worker; TT_ASSERT(max_workers_row_min_cols_per_worker > 0); TT_ASSERT(max_workers_row_max_cols_per_worker >= max_workers_row_min_cols_per_worker); for (uint32_t w_r = 0; w_r < num_rows_with_max_workers; w_r++) { @@ -606,13 +652,11 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape } } - const uint32_t min_workers_row_min_cols_per_worker = - tensor_slice_shape_in_tiles.x / min_num_workers_per_row; - const uint32_t min_workers_row_max_col_worker_count = - tensor_slice_shape_in_tiles.x % min_num_workers_per_row; + const uint32_t min_workers_row_min_cols_per_worker = tensor_slice_shape_in_tiles.x / min_num_workers_per_row; + const uint32_t min_workers_row_max_col_worker_count = tensor_slice_shape_in_tiles.x % min_num_workers_per_row; const uint32_t min_workers_row_max_cols_per_worker = min_workers_row_max_col_worker_count != 0 - ? min_workers_row_min_cols_per_worker + 1 - : min_workers_row_min_cols_per_worker; + ? min_workers_row_min_cols_per_worker + 1 + : min_workers_row_min_cols_per_worker; for (uint32_t w_r = num_rows_with_max_workers; w_r < tensor_slice_shape_in_tiles.y; w_r++) { for (uint32_t w_c = 0; w_c < min_workers_row_max_cols_per_worker; w_c++) { @@ -631,8 +675,7 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape // the max size and then apply that shape to all workers slice shapes tt_xy_pair largest_worker_slice_shape = {0, 0}; for (auto const& worker_slice_shape : worker_slice_shapes) { - if (largest_worker_slice_shape.x * largest_worker_slice_shape.y < - worker_slice_shape.x * worker_slice_shape.y) { + if (largest_worker_slice_shape.x * largest_worker_slice_shape.y < worker_slice_shape.x * worker_slice_shape.y) { largest_worker_slice_shape = worker_slice_shape; } } @@ -646,7 +689,8 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape largest_worker_slice_shape.y = 1; } - bool do_truncation = ((largest_worker_slice_shape.x * largest_worker_slice_shape.y) > max_slice_size_in_tiles) || has_gt_1_depth_size; + bool do_truncation = ((largest_worker_slice_shape.x * largest_worker_slice_shape.y) > max_slice_size_in_tiles) || + has_gt_1_depth_size; if (do_truncation) { log_trace(tt::LogOp, "Truncating worker slice shapes to fit max slice size in tiles"); } @@ -660,7 +704,8 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape return tt::round_up(worker_slice_shape.x * worker_slice_shape.y, half_cb_n_pages); }; - while (get_padded_worker_slice_size_in_tiles(largest_worker_slice_shape, half_cb_n_pages) > max_slice_size_in_tiles) { + while (get_padded_worker_slice_size_in_tiles(largest_worker_slice_shape, half_cb_n_pages) > + max_slice_size_in_tiles) { log_trace(tt::LogOp, "Loop Head"); // truncate the largest dim first uint32_t delta = (largest_worker_slice_shape.x * largest_worker_slice_shape.y) - max_slice_size_in_tiles; @@ -670,9 +715,9 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape uint32_t rows_removed_if_y_truncated = std::max(1, largest_worker_slice_shape.y / delta); uint32_t tiles_removed_if_y_truncated = rows_removed_if_y_truncated * largest_worker_slice_shape.x; uint32_t difference_x = tiles_removed_if_x_truncated > delta ? tiles_removed_if_x_truncated - delta - : delta - tiles_removed_if_x_truncated; + : delta - tiles_removed_if_x_truncated; uint32_t difference_y = tiles_removed_if_y_truncated > delta ? tiles_removed_if_y_truncated - delta - : delta - tiles_removed_if_y_truncated; + : delta - tiles_removed_if_y_truncated; log_trace(tt::LogOp, "-- cols_removed_if_x_truncated: {}", cols_removed_if_x_truncated); log_trace(tt::LogOp, "-- tiles_removed_if_x_truncated: {}", tiles_removed_if_x_truncated); log_trace(tt::LogOp, "-- rows_removed_if_y_truncated: {}", rows_removed_if_y_truncated); @@ -697,7 +742,10 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape largest_worker_slice_shape.x, largest_worker_slice_shape.y); if (!(largest_worker_slice_shape.x * largest_worker_slice_shape.y > 0)) { - log_warning(tt::LogOp, "Computing worker slice shape for reduce scatter resulted in 0 sized slice. Defaulting to 1x1 page per worker, which is likely to lead to suboptimal performance"); + log_warning( + tt::LogOp, + "Computing worker slice shape for reduce scatter resulted in 0 sized slice. Defaulting to 1x1 page per " + "worker, which is likely to lead to suboptimal performance"); largest_worker_slice_shape.x = 1; largest_worker_slice_shape.y = 1; } @@ -707,8 +755,7 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape } } - TT_ASSERT( - num_tiles_accounted_for == total_num_tiles, "All tiles must be accounted for in the worker slice shapes"); + TT_ASSERT(num_tiles_accounted_for == total_num_tiles, "All tiles must be accounted for in the worker slice shapes"); TT_ASSERT(worker_slice_shapes.size() == num_workers, "Worker slice shapes must match the number of workers"); std::for_each( worker_slice_shapes.begin(), @@ -720,12 +767,11 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape } std::vector RingReduceScatterWrappedTensorSlicer::create_worker_slice_shapes_for_tile_layout( - tt::tt_metal::LegacyShape const& tensor_shape, - tt_xy_pair const& tensor_slice_shape_in_tiles, - uint32_t num_workers, - uint32_t max_slice_size_in_pages, - uint32_t half_cb_n_pages) -{ + tt::tt_metal::LegacyShape const& tensor_shape, + tt_xy_pair const& tensor_slice_shape_in_tiles, + uint32_t num_workers, + uint32_t max_slice_size_in_pages, + uint32_t half_cb_n_pages) { log_trace(tt::LogOp, "\tmax_slice_size_in_pages={}", max_slice_size_in_pages); TT_ASSERT(max_slice_size_in_pages > 0); std::vector worker_slice_shapes; @@ -748,13 +794,14 @@ std::vector RingReduceScatterWrappedTensorSlicer::create_worker_slic std::size_t max_slice_size_in_tiles = max_slice_size_in_pages; // Assign slices by assuming that the input tensor is flattened into a 1D Shape - std::size_t optim_worker_slice_len_tiles = ceil(total_num_tiles / num_workers); // Ceil so that the remainder worker will have a smaller slice + std::size_t optim_worker_slice_len_tiles = + ceil(total_num_tiles / num_workers); // Ceil so that the remainder worker will have a smaller slice - if (max_slice_size_in_tiles < optim_worker_slice_len_tiles) { // Each worker will have a full slice + if (max_slice_size_in_tiles < optim_worker_slice_len_tiles) { // Each worker will have a full slice for (uint32_t w = 0; w < num_workers; ++w) { worker_slice_shapes.emplace_back(max_slice_size_in_tiles, 1); } - } else { // Each worker will only have one slice + } else { // Each worker will only have one slice uint32_t remainder_worker_len_tiles = total_num_tiles % optim_worker_slice_len_tiles; for (uint32_t w = 0; w < num_workers; ++w) { @@ -769,10 +816,9 @@ std::vector RingReduceScatterWrappedTensorSlicer::create_worker_slic return worker_slice_shapes; } - /* - * @brief: Given a tensor shape, evenly break it into pieces along a given dimension and generate the slices accordingly. - * This can be fed into a CCL Send command generator + * @brief: Given a tensor shape, evenly break it into pieces along a given dimension and generate the slices + * accordingly. This can be fed into a CCL Send command generator */ std::vector generate_slice_sequence_on_dim( TensorSlice::ords_t tensor_shape, @@ -781,11 +827,12 @@ std::vector generate_slice_sequence_on_dim( std::size_t num_slices, std::int64_t start_slice_index, std::int64_t end_slice_index_exclusive, - std::size_t worker_index -) { - static_assert(std::is_same_v, "generate_slice_sequence_on_dim not yet implemented for type not of tt_xy_pair"); - // We don't support 4D shapes in the CCL kernels yet, which are needed for proper reduction/concatenation in some cases - // so for now we subtract the outer dims from the fracture_dim since we only support 2D at the moment. + std::size_t worker_index) { + static_assert( + std::is_same_v, + "generate_slice_sequence_on_dim not yet implemented for type not of tt_xy_pair"); + // We don't support 4D shapes in the CCL kernels yet, which are needed for proper reduction/concatenation in some + // cases so for now we subtract the outer dims from the fracture_dim since we only support 2D at the moment. if (fracture_dim == 3) { fracture_dim -= 2; } else { @@ -799,21 +846,170 @@ std::vector generate_slice_sequence_on_dim( auto dim_size = fracture_dim == 1 ? tensor_shape.x : tensor_shape.y; TT_ASSERT(dim_size % num_slices == 0); auto slice_size_on_dim = dim_size / num_slices; - auto slice_shape = fracture_dim == 0 ? tt_xy_pair{tensor_shape.x, slice_size_on_dim} : tt_xy_pair{slice_size_on_dim, tensor_shape.y}; + auto slice_shape = fracture_dim == 0 ? tt_xy_pair{tensor_shape.x, slice_size_on_dim} + : tt_xy_pair{slice_size_on_dim, tensor_shape.y}; auto dim_start_offset = start_slice_index * slice_size_on_dim; - TensorSlice::ords_t tensor_slice_offset = fracture_dim == 0 ? tt_xy_pair{0, dim_start_offset} : tt_xy_pair{dim_start_offset, 0}; + TensorSlice::ords_t tensor_slice_offset = + fracture_dim == 0 ? tt_xy_pair{0, dim_start_offset} : tt_xy_pair{dim_start_offset, 0}; - bool forward_direction = start_slice_index > end_slice_index_exclusive; // only for debug + bool forward_direction = start_slice_index > end_slice_index_exclusive; // only for debug auto incr = start_slice_index < end_slice_index_exclusive ? 1 : -1; if (forward_direction) { log_trace(tt::LogOp, "slice_size_on_dim {}", slice_size_on_dim); log_trace(tt::LogOp, "worker_index {}", worker_index); } - auto worker_slice_start_offset = /*fracture_dim == 0 ? TensorSlice::ords_t{0, worker_index * worker_slice_shape.y} :*/ TensorSlice::ords_t{worker_index * worker_slice_shape.x, 0}; + auto worker_slice_start_offset = + /*fracture_dim == 0 ? TensorSlice::ords_t{0, worker_index * worker_slice_shape.y} :*/ TensorSlice::ords_t{ + worker_index * worker_slice_shape.x, 0}; + + auto generate_slice = [forward_direction, + incr, + &slices, + &tensor_shape, + &slice_shape, + &worker_slice_shape, + tensor_slice_offset, + &worker_slice_start_offset, + fracture_dim, + dim_start_offset, + slice_size_on_dim](std::int64_t i) { + auto tensor_slice_offset_adjusted = tensor_slice_offset; + if (fracture_dim == 0) { + tensor_slice_offset_adjusted.y = slice_size_on_dim * i; + } else { + tensor_slice_offset_adjusted.x = slice_size_on_dim * i; + } + TT_ASSERT(tensor_shape.x > 0, "Invalid tensor shape. x = 0 but it must be > 0"); + TT_ASSERT(tensor_shape.y > 0, "Invalid tensor shape. y = 0 but it must be > 0"); + TT_ASSERT(slice_shape.x > 0, "Invalid tensor slice shape. x = 0 but it must be > 0"); + TT_ASSERT(slice_shape.y > 0, "Invalid tensor slice shape. x = 0 but it must be > 0"); + TT_ASSERT( + tensor_slice_offset_adjusted.x < tensor_shape.x, + "Invalid tensor slice offset. x = {} but it must be < tensor shape x={}. slice_offset: (y={},x={}), " + "tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", + tensor_slice_offset_adjusted.x, + tensor_shape.x, + tensor_slice_offset_adjusted.y, + tensor_slice_offset_adjusted.x, + tensor_shape.y, + tensor_shape.x, + slice_size_on_dim, + i); + TT_ASSERT( + tensor_slice_offset_adjusted.y < tensor_shape.y, + "Invalid tensor slice offset. y = {} but it must be < tensor shape y={}. slice_offset: (y={},x={}), " + "tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", + tensor_slice_offset_adjusted.y, + tensor_shape.y, + tensor_slice_offset_adjusted.y, + tensor_slice_offset_adjusted.x, + tensor_shape.y, + tensor_shape.x, + slice_size_on_dim, + i); + TT_ASSERT(worker_slice_shape.x > 0, "Invalid worker slice shape. x = 0 but it must be > 0"); + TT_ASSERT(worker_slice_shape.y > 0, "Invalid worker slice shape. y = 0 but it must be > 0"); + + auto const& tensor_slice = TensorSlice( + tensor_shape, + slice_shape, + tensor_slice_offset_adjusted, + worker_slice_shape, + worker_slice_start_offset, + fracture_dim); + if (forward_direction) { + log_trace( + tt::LogOp, + "generate_slice ({}):\n\ttensor_shape: (y={},x={})\n\ttensor_slice_shape: " + "(y={},x={})\n\ttensor_slice_offset_adjusted: (y={},x={})\n\tslice_start_shape: (y={},x={})\n\tworker " + "relative slice_start_offset: (y={},x={})\n\tfracture_dim: {}\n\tdim_start_offset: " + "{}\n\tslice_size_on_dim: {}\n", + i, + tensor_slice.tensor_shape.y, + tensor_slice.tensor_shape.x, + tensor_slice.tensor_slice_shape.y, + tensor_slice.tensor_slice_shape.x, + tensor_slice.tensor_slice_offset.y, + tensor_slice.tensor_slice_offset.x, + tensor_slice.worker_slice_shape.y, + tensor_slice.worker_slice_shape.x, + tensor_slice.worker_slice_offset.y, + tensor_slice.worker_slice_offset.x, + fracture_dim, + dim_start_offset, + slice_size_on_dim); + } - auto generate_slice = [forward_direction,incr, &slices, &tensor_shape, &slice_shape, &worker_slice_shape, tensor_slice_offset, &worker_slice_start_offset, fracture_dim, dim_start_offset, slice_size_on_dim](std::int64_t i){ + slices.push_back(tensor_slice); + }; + + for (int i = start_slice_index; i != end_slice_index_exclusive; i += incr) { + generate_slice(i); + } + + return slices; +} + +/* + * @brief: Given a tensor shape, evenly break it into pieces along a given dimension and generate the slices + * accordingly. This can be fed into a CCL Send command generator + */ +std::vector generate_slice_sequence_on_dim_v2( + TensorSlice::ords_t tensor_shape, + TensorSlice::ords_t worker_slice_shape, + TensorSlice::ords_t worker_slice_offset, + std::size_t fracture_dim, + std::size_t num_slices, + std::int64_t start_slice_index, + std::int64_t end_slice_index_exclusive, + std::size_t worker_index) { + static_assert( + std::is_same_v, + "generate_slice_sequence_on_dim_v2 not yet implemented for type not of tt_xy_pair"); + // We don't support 4D shapes in the CCL kernels yet, which are needed for proper reduction/concatenation in some + // cases so for now we subtract the outer dims from the fracture_dim since we only support 2D at the moment. + if (fracture_dim == 3) { + fracture_dim -= 2; + } else { + // dims are + fracture_dim = 0; + } + + TT_ASSERT(worker_slice_shape.y == 1); + + std::vector slices; + auto dim_size = fracture_dim == 1 ? tensor_shape.x : tensor_shape.y; + TT_ASSERT(dim_size % num_slices == 0); + auto slice_size_on_dim = dim_size / num_slices; + auto slice_shape = fracture_dim == 0 ? tt_xy_pair{tensor_shape.x, slice_size_on_dim} + : tt_xy_pair{slice_size_on_dim, tensor_shape.y}; + + auto dim_start_offset = start_slice_index * slice_size_on_dim; + TensorSlice::ords_t tensor_slice_offset = + fracture_dim == 0 ? tt_xy_pair{0, dim_start_offset} : tt_xy_pair{dim_start_offset, 0}; + + bool forward_direction = start_slice_index > end_slice_index_exclusive; // only for debug + auto incr = start_slice_index < end_slice_index_exclusive ? 1 : -1; + if (forward_direction) { + log_trace(tt::LogOp, "slice_size_on_dim {}", slice_size_on_dim); + log_trace(tt::LogOp, "worker_index {}", worker_index); + } + + auto worker_slice_start_offset = worker_slice_offset; + + auto generate_slice = [forward_direction, + incr, + &slices, + &tensor_shape, + &slice_shape, + &worker_slice_shape, + tensor_slice_offset, + &worker_slice_start_offset, + fracture_dim, + dim_start_offset, + slice_size_on_dim](std::int64_t i) { auto tensor_slice_offset_adjusted = tensor_slice_offset; if (fracture_dim == 0) { tensor_slice_offset_adjusted.y = slice_size_on_dim * i; @@ -824,30 +1020,61 @@ std::vector generate_slice_sequence_on_dim( TT_ASSERT(tensor_shape.y > 0, "Invalid tensor shape. y = 0 but it must be > 0"); TT_ASSERT(slice_shape.x > 0, "Invalid tensor slice shape. x = 0 but it must be > 0"); TT_ASSERT(slice_shape.y > 0, "Invalid tensor slice shape. x = 0 but it must be > 0"); - TT_ASSERT(tensor_slice_offset_adjusted.x < tensor_shape.x, "Invalid tensor slice offset. x = {} but it must be < tensor shape x={}. slice_offset: (y={},x={}), tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", tensor_slice_offset_adjusted.x, tensor_shape.x, tensor_slice_offset_adjusted.y, tensor_slice_offset_adjusted.x, tensor_shape.y, tensor_shape.x, slice_size_on_dim, i); - TT_ASSERT(tensor_slice_offset_adjusted.y < tensor_shape.y, "Invalid tensor slice offset. y = {} but it must be < tensor shape y={}. slice_offset: (y={},x={}), tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", tensor_slice_offset_adjusted.y, tensor_shape.y, tensor_slice_offset_adjusted.y, tensor_slice_offset_adjusted.x, tensor_shape.y, tensor_shape.x, slice_size_on_dim, i); + TT_ASSERT( + tensor_slice_offset_adjusted.x < tensor_shape.x, + "Invalid tensor slice offset. x = {} but it must be < tensor shape x={}. slice_offset: (y={},x={}), " + "tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", + tensor_slice_offset_adjusted.x, + tensor_shape.x, + tensor_slice_offset_adjusted.y, + tensor_slice_offset_adjusted.x, + tensor_shape.y, + tensor_shape.x, + slice_size_on_dim, + i); + TT_ASSERT( + tensor_slice_offset_adjusted.y < tensor_shape.y, + "Invalid tensor slice offset. y = {} but it must be < tensor shape y={}. slice_offset: (y={},x={}), " + "tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", + tensor_slice_offset_adjusted.y, + tensor_shape.y, + tensor_slice_offset_adjusted.y, + tensor_slice_offset_adjusted.x, + tensor_shape.y, + tensor_shape.x, + slice_size_on_dim, + i); TT_ASSERT(worker_slice_shape.x > 0, "Invalid worker slice shape. x = 0 but it must be > 0"); TT_ASSERT(worker_slice_shape.y > 0, "Invalid worker slice shape. y = 0 but it must be > 0"); - auto const& tensor_slice = TensorSlice(tensor_shape, slice_shape, tensor_slice_offset_adjusted, worker_slice_shape, worker_slice_start_offset, fracture_dim); + auto const& tensor_slice = TensorSlice( + tensor_shape, + slice_shape, + tensor_slice_offset_adjusted, + worker_slice_shape, + worker_slice_start_offset, + fracture_dim); if (forward_direction) { - log_trace( - tt::LogOp, - "generate_slice ({}):\n\ttensor_shape: (y={},x={})\n\ttensor_slice_shape: (y={},x={})\n\ttensor_slice_offset_adjusted: (y={},x={})\n\tslice_start_shape: (y={},x={})\n\tworker relative slice_start_offset: (y={},x={})\n\tfracture_dim: {}\n\tdim_start_offset: {}\n\tslice_size_on_dim: {}\n", - i, - tensor_slice.tensor_shape.y, - tensor_slice.tensor_shape.x, - tensor_slice.tensor_slice_shape.y, - tensor_slice.tensor_slice_shape.x, - tensor_slice.tensor_slice_offset.y, - tensor_slice.tensor_slice_offset.x, - tensor_slice.worker_slice_shape.y, - tensor_slice.worker_slice_shape.x, - tensor_slice.worker_slice_offset.y, - tensor_slice.worker_slice_offset.x, - fracture_dim, - dim_start_offset, - slice_size_on_dim); + log_trace( + tt::LogOp, + "generate_slice ({}):\n\ttensor_shape: (y={},x={})\n\ttensor_slice_shape: " + "(y={},x={})\n\ttensor_slice_offset_adjusted: (y={},x={})\n\tslice_start_shape: (y={},x={})\n\tworker " + "relative slice_start_offset: (y={},x={})\n\tfracture_dim: {}\n\tdim_start_offset: " + "{}\n\tslice_size_on_dim: {}\n", + i, + tensor_slice.tensor_shape.y, + tensor_slice.tensor_shape.x, + tensor_slice.tensor_slice_shape.y, + tensor_slice.tensor_slice_shape.x, + tensor_slice.tensor_slice_offset.y, + tensor_slice.tensor_slice_offset.x, + tensor_slice.worker_slice_shape.y, + tensor_slice.worker_slice_shape.x, + tensor_slice.worker_slice_offset.y, + tensor_slice.worker_slice_offset.x, + fracture_dim, + dim_start_offset, + slice_size_on_dim); } slices.push_back(tensor_slice); @@ -860,5 +1087,302 @@ std::vector generate_slice_sequence_on_dim( return slices; } +GenericWrappedTensorSlicer::GenericWrappedTensorSlicer( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages) { + this->initialize( + input_tensor, + output_tensor, + slice_dim, + partition_index, + partition_size, + total_num_workers, + max_slice_size_in_bytes, + half_cb_n_pages); +} + +tt_xy_pair GenericWrappedTensorSlicer::calculate_tensor_slice_shape( + const Tensor& input_tensor, int slice_dim, uint32_t partition_size) { + const uint32_t num_tiles_x = input_tensor.get_legacy_shape()[-1] / tt::constants::TILE_WIDTH; + uint32_t num_tiles_y = (input_tensor.get_legacy_shape()[-2] / tt::constants::TILE_HEIGHT); + for (std::size_t i = 0; + input_tensor.get_legacy_shape().rank() > 2 && i < input_tensor.get_legacy_shape().rank() - 2; + i++) { + num_tiles_y *= input_tensor.get_legacy_shape()[i]; + } + TT_ASSERT(num_tiles_x >= partition_size); + tt_xy_pair tensor_slice_shape; + tensor_slice_shape.x = slice_dim == 3 ? (num_tiles_x / partition_size) : num_tiles_x; + tensor_slice_shape.y = slice_dim != 3 ? num_tiles_y / partition_size : num_tiles_y; + return tensor_slice_shape; +} + +void GenericWrappedTensorSlicer::initialize( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages) { + // Configure layout parameters + this->row_major = (input_tensor.get_layout() == Layout::ROW_MAJOR); + this->input_page_size = input_tensor.buffer()->page_size(); + this->partition_index = partition_index; + this->partition_size = partition_size; + + // Assume everything in Tile layout for now, row major not supported yet + TT_FATAL(!this->row_major, "Row major not supported yet"); + + this->tensor_slice_shape = calculate_tensor_slice_shape(input_tensor, slice_dim, partition_size); + + // Calculate worker slice shapes (tile layout) + this->worker_slice_shapes = create_worker_slice_shapes_for_tile_layout( + input_tensor.get_legacy_shape(), + this->tensor_slice_shape, + total_num_workers, + max_slice_size_in_bytes / this->input_page_size, + half_cb_n_pages); + + // Flattened tensor shape (tile layout) + this->flattened_tensor_shape = tt_xy_pair{ + input_tensor.get_legacy_shape()[3] / tt::constants::TILE_WIDTH, + (input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * input_tensor.get_legacy_shape()[2]) / + tt::constants::TILE_HEIGHT}; + + this->worker_slice_offsets = compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); +} + +ccl::InterleavedTensorWorkerSlice GenericWrappedTensorSlicer::get_worker_slice(std::size_t global_worker_index) { + assert(global_worker_index < this->worker_slice_shapes.size()); + assert(global_worker_index < this->worker_slice_offsets.size()); + return ccl::InterleavedTensorWorkerSlice( + this->flattened_tensor_shape, + this->tensor_slice_shape, + this->worker_slice_shapes[global_worker_index], + this->worker_slice_offsets[global_worker_index], + true // wrapped + ); +} + +std::vector GenericWrappedTensorSlicer::compute_worker_slice_offsets( + std::vector const& worker_slice_shapes, tt_xy_pair const& tensor_slice_shape) { + return compute_worker_slice_offsets_for_wrapped_tensor_slicer(worker_slice_shapes, tensor_slice_shape); +} + +std::vector GenericWrappedTensorSlicer::create_worker_slice_shapes_for_tile_layout( + tt::tt_metal::LegacyShape const& tensor_shape, + tt_xy_pair const& tensor_slice_shape_in_tiles, + uint32_t num_workers, + uint32_t max_slice_size_in_pages, + uint32_t half_cb_n_pages) { + log_trace(tt::LogOp, "\tmax_slice_size_in_pages={}", max_slice_size_in_pages); + TT_ASSERT(max_slice_size_in_pages > 0); + std::vector worker_slice_shapes; + worker_slice_shapes.reserve(num_workers); + const uint32_t total_num_tiles = tensor_slice_shape_in_tiles.x * tensor_slice_shape_in_tiles.y; + if (num_workers > total_num_tiles) { + log_warning( + tt::LogOp, + "Reduce Scatter more workers instantiated than is work to be done. Some workers will be idle and do " + "nothing"); + for (uint32_t w = 0; w < total_num_tiles; ++w) { + worker_slice_shapes.emplace_back(1, 1); + } + for (uint32_t w = total_num_tiles; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(0, 0); + } + return worker_slice_shapes; + } + + std::size_t max_slice_size_in_tiles = max_slice_size_in_pages; + + // Assign slices by assuming that the input tensor is flattened into a 1D Shape + std::size_t optim_worker_slice_len_tiles = std::ceil( + static_cast(total_num_tiles) / + num_workers); // Ceil so that the remainder worker will have a smaller slice + + log_trace(tt::LogOp, "---- GenericWrappedTensorSlicer::create_worker_slice_shapes_for_tile_layout ---- "); + log_trace(tt::LogOp, "total_num_tiles: {}", total_num_tiles); + log_trace(tt::LogOp, "num_workers: {}", num_workers); + log_trace(tt::LogOp, "optim_worker_slice_len_tiles: {}", optim_worker_slice_len_tiles); + + if (max_slice_size_in_tiles < optim_worker_slice_len_tiles) { // Each worker will have a full slice + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(max_slice_size_in_tiles, 1); + } + } else { // Each worker will only have one slice + uint32_t remainder_worker_len_tiles = total_num_tiles % optim_worker_slice_len_tiles; + + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(optim_worker_slice_len_tiles, 1); + } + // If there is a remainder worker, we need to adjust the last worker's slice shape to be smaller + if (remainder_worker_len_tiles > 0) { + worker_slice_shapes.back() = tt_xy_pair{remainder_worker_len_tiles, 1}; + } + } + + log_trace(tt::LogOp, "--------------------------------"); + + return worker_slice_shapes; +} + + + + +GenericWrappedTensorSlicerV2::GenericWrappedTensorSlicerV2( + const Tensor& input_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers) +{ + this->initialize(input_tensor, slice_dim, partition_index, partition_size, total_num_workers); +} + +Shape4D GenericWrappedTensorSlicerV2::calculate_tensor_slice_shape( + Shape4D const& input_shape, + int slice_dim, + uint32_t partition_size) { + + // Calculate the size of the slice along the given dimension + uint32_t dim_size = input_shape[slice_dim]; + uint32_t slice_size = dim_size / partition_size; + + // Start with full shape + Shape4D slice_shape(input_shape[0], input_shape[1], input_shape[2], input_shape[3]); + + TT_FATAL(slice_dim >= 0 && slice_dim < 4, "Invalid slice dimension. Must be between 0 and 3 but got {}. This should have been normalized to fit within the range", slice_dim); + slice_shape[slice_dim] = slice_size; + + return slice_shape; +} + +Shape4D GenericWrappedTensorSlicerV2::calculate_tensor_slice_offset( + Shape4D const& input_shape, + int slice_dim, + uint32_t partition_index) { + + Shape4D offset(0, 0, 0, 0); + + // Calculate the size of the slice along the given dimension + uint32_t dim_size = input_shape[slice_dim]; + uint32_t slice_size = dim_size / partition_size; + + TT_FATAL(slice_dim >= 0 && slice_dim < 4, "Invalid slice dimension. Must be between 0 and 3 but got {}. This should have been normalized to fit within the range", slice_dim); + offset[slice_dim] = partition_index * slice_size; + + return offset; +} + +void GenericWrappedTensorSlicerV2::initialize( + const Tensor& input_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers) +{ + // Configure layout parameters + this->row_major = (input_tensor.get_layout() == Layout::ROW_MAJOR); + this->input_page_size = input_tensor.buffer()->page_size(); + this->partition_index = partition_index; + this->partition_size = partition_size; + + // Assume everything in Tile layout for now, row major not supported yet + TT_FATAL(!this->row_major, "Row major not supported yet"); + + // Record the input tensor shape + auto input_shape = input_tensor.get_legacy_shape(); + this->tensor_shape = Shape4D(input_shape[0], input_shape[1], input_shape[2]/tt::constants::TILE_HEIGHT, input_shape[3]/tt::constants::TILE_WIDTH); + + // Calculate tensor slice shape + this->tensor_slice_shape = calculate_tensor_slice_shape(this->tensor_shape, slice_dim, partition_size); + + // Calculate tensor slice offset + this->tensor_slice_offset = calculate_tensor_slice_offset(this->tensor_shape, slice_dim, partition_index); + + // Calculate worker slice shapes in terms of flattened tiles + this->worker_slice_shapes = create_worker_slice_shapes_for_tile_layout(this->tensor_slice_shape, total_num_workers); + + // Calculate worker slice offsets in terms of flattened tiles + this->worker_slice_offsets = compute_worker_slice_offsets(this->worker_slice_shapes); +} + +ttnn::ccl::v2::TensorSlice GenericWrappedTensorSlicerV2::get_worker_slice_v2(std::size_t global_worker_index) { + assert(global_worker_index < this->worker_slice_shapes.size()); + assert(global_worker_index < this->worker_slice_offsets.size()); + return ttnn::ccl::v2::TensorSlice( + this->tensor_shape, // tensor_shape + this->tensor_slice_shape, // tensor_slice_shape + this->tensor_slice_offset, // tensor_slice_offset + this->worker_slice_shapes[global_worker_index], // worker_slice_shape + this->worker_slice_offsets[global_worker_index] // worker_slice_offset + ); +} + +/* Worker slices and offsets are 4D shapes but flattened to 1D in the last dimension*/ + +std::vector> GenericWrappedTensorSlicerV2::compute_worker_slice_offsets(std::vector> const& worker_slice_shapes) { + Shape4D offset(0, 0, 0, 0); + std::vector> worker_slice_offsets; + worker_slice_offsets.reserve(worker_slice_shapes.size()); + for (const auto& slice_shape : worker_slice_shapes) { + worker_slice_offsets.push_back(offset); + offset.x += slice_shape.x; + } + return worker_slice_offsets; +} + +std::vector> GenericWrappedTensorSlicerV2::create_worker_slice_shapes_for_tile_layout( + Shape4D const& tensor_slice_shape_in_tiles, + uint32_t num_workers) +{ + std::vector> worker_slice_shapes; + worker_slice_shapes.reserve(num_workers); + const uint32_t total_num_tiles = tensor_slice_shape_in_tiles.x * tensor_slice_shape_in_tiles.y * tensor_slice_shape_in_tiles.z * tensor_slice_shape_in_tiles.w; + if (num_workers > total_num_tiles) { + log_warning( + tt::LogOp, + "More workers instantiated than is work to be done. Some workers will be idle and do nothing"); + for (uint32_t w = 0; w < total_num_tiles; ++w) { + worker_slice_shapes.emplace_back(1,1,1,1); + } + for (uint32_t w = total_num_tiles; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(0,0,0,0); + } + return worker_slice_shapes; + } + + // Assign slices by assuming that the input tensor is flattened into a 1D Shape + std::size_t optim_worker_slice_len_tiles = std::ceil(static_cast(total_num_tiles) / num_workers); // Ceil so that the remainder worker will have a smaller slice + + log_trace(tt::LogOp, "---- GenericWrappedTensorSlicer::create_worker_slice_shapes_for_tile_layout ---- "); + log_trace(tt::LogOp, "total_num_tiles: {}", total_num_tiles); + log_trace(tt::LogOp, "num_workers: {}", num_workers); + log_trace(tt::LogOp, "optim_worker_slice_len_tiles: {}", optim_worker_slice_len_tiles); + + uint32_t remainder_worker_len_tiles = total_num_tiles % optim_worker_slice_len_tiles; + + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(Shape4D(1, 1, 1, optim_worker_slice_len_tiles)); + } + // If there is a remainder worker, we need to adjust the last worker's slice shape to be smaller + if (remainder_worker_len_tiles > 0) { + worker_slice_shapes.back() = Shape4D(1,1,1,remainder_worker_len_tiles); + } + + log_trace(tt::LogOp, "--------------------------------"); + + return worker_slice_shapes; +} + } // namespace ccl } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index f714f3e44ca..b3bd0e6c7a9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -14,10 +14,21 @@ #include "tt_metal/host_api.hpp" #include "tt_metal/impl/program/program.hpp" #include "ttnn/tensor/types.hpp" +#include "ttnn/operations/ccl/erisc_datamover_builder.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" namespace ttnn { namespace ccl { +struct SyncModeSpec { + uint32_t num_signals = 0; + CoreCoord core; + std::vector sem_ids; + std::vector wait_counts; + + void add_signal(uint32_t sem_id, uint32_t wait_count); +}; + class FabricEriscDatamoverBuilder; class EriscDatamoverBuilder; @@ -26,6 +37,32 @@ std::tuple, std::optional> get_dev const std::vector& devices, const ttnn::ccl::Topology& topology); + +class LineTopology { + public: + LineTopology( + size_t line_size, + size_t line_index); + + bool is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction direction) const; + bool is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction direction) const; + + bool is_at_end_of_line() const; + + size_t line_size() const; + + size_t line_index() const; + + size_t get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction direction) const; + + ttnn::ccl::Topology topology() const; + + private: + size_t _line_size; + size_t _line_index; +}; + + // Eventual home: ccl_topology_descriptors struct RingTopology { RingTopology( @@ -51,6 +88,17 @@ struct RingTopology { bool is_linear; }; +struct TensorPartition { + TensorPartition( + uint32_t partition_size, + uint32_t partition_index) + : partition_size(partition_size), + partition_index(partition_index) {} + + uint32_t partition_size; + uint32_t partition_index; +}; + class CclOpTensorConfig { public: static std::unique_ptr build_all_gather_tensor_config(Tensor const& tensor); @@ -206,6 +254,8 @@ struct LegacyCclTensorSlicer { }; + +inline namespace v1 { struct TensorSlice { using ords_t = tt_xy_pair; ords_t tensor_shape; @@ -215,6 +265,7 @@ struct TensorSlice { ords_t worker_slice_offset; std::size_t dim; }; +}; // Workers iterate over tensor slices in a sequence along a // single, specified dimension. Workers iterator over the tensor @@ -264,6 +315,16 @@ struct InterleavedTensorWorkerSlice { return worker_slice_shape.x * worker_slice_shape.y; } + void print() const { + log_trace(tt::LogOp, "----- printing worker slice -----"); + log_trace(tt::LogOp, "tensor_shape: ({},{})", tensor_shape.x, tensor_shape.y); + log_trace(tt::LogOp, "tensor_slice_shape: ({},{})", tensor_slice_shape.x, tensor_slice_shape.y); + log_trace(tt::LogOp, "worker_slice_shape: ({},{})", worker_slice_shape.x, worker_slice_shape.y); + log_trace(tt::LogOp, "worker_slice_offset: ({},{})", worker_slice_offset.x, worker_slice_offset.y); + log_trace(tt::LogOp, "worker_slice_is_wrapped: {}", worker_slice_is_wrapped); + log_trace(tt::LogOp, "worker_slice_num_pages: {}", get_worker_slice_num_pages()); + } + tt_xy_pair tensor_shape; tt_xy_pair tensor_slice_shape; tt_xy_pair worker_slice_shape; @@ -502,5 +563,115 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, EriscDataMoverTerminationMode termination_mode); + +std::vector generate_slice_sequence_on_dim_v2( + TensorSlice::ords_t tensor_shape, + TensorSlice::ords_t worker_slice_shape, + TensorSlice::ords_t worker_slice_offset, + std::size_t fracture_dim, + std::size_t num_slices, + std::int64_t start_slice_index, + std::int64_t end_slice_index_exclusive, + std::size_t worker_index +); + +class GenericWrappedTensorSlicer { +public: + GenericWrappedTensorSlicer( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages); + + ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index); + + ttnn::ccl::v2::TensorSlice get_worker_slice_v2(std::size_t global_worker_index); + + // method to compute offsets in a wrapped layout + std::vector compute_worker_slice_offsets( + const std::vector& worker_slice_shapes, + tt_xy_pair const& tensor_slice_shape); + + // method to create worker slice shapes in a tile layout + std::vector create_worker_slice_shapes_for_tile_layout( + const tt::tt_metal::LegacyShape& tensor_shape, + tt_xy_pair const& tensor_slice_shape_in_tiles, + uint32_t num_workers, + uint32_t max_slice_size_in_pages, + uint32_t half_cb_n_pages); + +private: + void initialize( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages); + + tt_xy_pair calculate_tensor_slice_shape(const Tensor& input_tensor, int slice_dim, uint32_t partition_size); + Shape4D calculate_tensor_slice_offset(const Tensor& input_tensor, int slice_dim, uint32_t partition_index); + + // Class member variables + tt_xy_pair flattened_tensor_shape; + tt_xy_pair tensor_slice_shape; + Shape4D tensor_slice_offset; + std::vector worker_slice_shapes; + std::vector worker_slice_offsets; + uint32_t input_page_size; + bool row_major; + uint32_t partition_index; + uint32_t partition_size; +}; + + +class GenericWrappedTensorSlicerV2 { +public: + GenericWrappedTensorSlicerV2( + const Tensor& input_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers); + + ttnn::ccl::v2::TensorSlice get_worker_slice_v2(std::size_t global_worker_index); + + // method to compute offsets in a wrapped layout + std::vector> compute_worker_slice_offsets(std::vector> const& worker_slice_shapes); + + // method to create worker slice shapes in a tile layout + std::vector> create_worker_slice_shapes_for_tile_layout( + Shape4D const& tensor_slice_shape_in_tiles, + uint32_t num_workers); + +private: + void initialize( + const Tensor& input_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers); + + Shape4D calculate_tensor_slice_shape(Shape4D const& tensor_shape, int slice_dim, uint32_t partition_size); + Shape4D calculate_tensor_slice_offset(Shape4D const& tensor_shape, int slice_dim, uint32_t partition_index); + + // Class member variables + Shape4D tensor_shape; + Shape4D tensor_slice_shape; + Shape4D tensor_slice_offset; + std::vector> worker_slice_shapes; + std::vector> worker_slice_offsets; + uint32_t input_page_size; + bool row_major; + uint32_t partition_index; + uint32_t partition_size; +}; + } // namespace ccl } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp new file mode 100644 index 00000000000..cf8c6d8f800 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp" + +#include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp" +#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp" +#include "ttnn/operations/ccl/barrier/barrier_pybind.hpp" + +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" + +namespace ttnn::operations::ccl { + +void py_bind_common(pybind11::module& module) { + py::enum_(module, "Topology") + .value("Ring", ttnn::ccl::Topology::Ring) + .value("Linear", ttnn::ccl::Topology::Linear); + + module.def("initialize_edm_fabric", &ttnn::ccl::initialize_edm_fabric, py::arg("mesh_device"), py::kw_only()); + + module.def("teardown_edm_fabric", &ttnn::ccl::teardown_edm_fabric, py::arg("mesh_device"), py::kw_only()); +} + +void py_module(py::module& module) { + ccl::py_bind_common(module); + ccl::py_bind_all_gather(module); + ccl::py_bind_reduce_scatter(module); + ccl::py_bind_barrier(module); +} + +} // namespace ttnn::operations::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp new file mode 100644 index 00000000000..37b14b60bf1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_pybind.hpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace py = pybind11; + +namespace ttnn::operations::ccl { + +void py_module(py::module& module); + +} // namespace ttnn::operations::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.cpp new file mode 100644 index 00000000000..1882a7383df --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.cpp @@ -0,0 +1,181 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 +/// + +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" + +#include "tt_metal/common/assert.hpp" + +#include "ttnn/cpp/ttnn/tensor/tensor.hpp" + +#include +#include +#include + +namespace ttnn::ccl::cmd::builder { + +std::vector generate_tensor_slices( + const size_t num_slices, const Tensor& tensor, size_t split_dim) { + return compute_page_aligned_slices(num_slices, tensor, split_dim); +} + +ttnn::ccl::v2::TensorSlice convert_to_whole_tensor_slice(const Tensor& tensor) { + return compute_page_aligned_slices(1, tensor, 0).at(0); +} + +// TENSOR MANIP +// Pairs of slice size and slice offset +std::vector> compute_evenly_split_sizes(size_t size, size_t num_slices) { + const int64_t num_larger_slices_total = size % num_slices; + const bool evenly_divisible = num_larger_slices_total == 0; + const int64_t smaller_slice_size = size / num_slices; + const int64_t larger_slice_size = smaller_slice_size + !evenly_divisible; + + auto compute_slice_dim_size = + [larger_slice_size, smaller_slice_size, num_larger_slices_total](int64_t slice_index) { + bool is_larger_slice = slice_index < num_larger_slices_total; + return is_larger_slice ? larger_slice_size : smaller_slice_size; + }; + + auto compute_slice_offset = [num_larger_slices_total, larger_slice_size, smaller_slice_size](int64_t slice_index) { + int64_t num_larger_slices = std::min(slice_index, num_larger_slices_total); + int64_t num_smaller_slices = std::min(slice_index - num_larger_slices, 0L); + return num_larger_slices * larger_slice_size + (slice_index - num_larger_slices) * smaller_slice_size; + }; + + auto compute_slice_size_and_offset = [compute_slice_dim_size, + compute_slice_offset](size_t slice_index) -> std::pair { + return {compute_slice_dim_size(slice_index), compute_slice_offset(slice_index)}; + }; + auto result = std::vector>{}; + result.reserve(num_slices); + for (size_t i = 0; i < num_slices; i++) { + result.push_back(compute_slice_size_and_offset(i)); + } + return std::vector>(result.begin(), result.end()); +} + +// // Outer vector = per worker command stream, inner vector = commands +std::vector> split_tensor_slices_across_workers_page_aligned( + size_t num_workers, std::vector const& tensor_slices) { + TT_FATAL(tensor_slices.size() > 0, "Number of slices must be greater than 0"); + // not split up across workers yet + + auto worker_slices_streams = std::vector>(num_workers); + std::ranges::for_each( + worker_slices_streams, [&tensor_slices](auto& worker_slices) { worker_slices.reserve(tensor_slices.size()); }); + + for (auto const& tensor_slice : tensor_slices) { + auto const worker_slices = split_tensor_slice_across_workers_wrapped_page_aligned(tensor_slice, num_workers); + TT_FATAL( + worker_slices.size() == num_workers, + "Expected {} worker slices for tensor slice but got {}", + num_workers, + worker_slices.size()); + for (size_t i = 0; i < num_workers; i++) { + worker_slices_streams[i].push_back(worker_slices[i]); + } + } + for (size_t i = 0; i < num_workers; i++) { + TT_FATAL( + worker_slices_streams[i].size() == tensor_slices.size(), + "Mismatch in tensor slices. Expected {} but got {}", + tensor_slices.size(), + worker_slices_streams[i].size()); + } + + return worker_slices_streams; +}; + +Shape4D from_tensor_shape(ttnn::Shape const& shape) { + constexpr size_t max_rank = 4; + TT_FATAL( + shape.size() <= max_rank, + "Reduce scatter device code only supports tensors up to rank 4. Current tensor rank is {}. The host code " + "calling the program factory must reduce the dimensionality", + shape.size()); + + Shape4D shape4d = {1, 1, 1, 1}; + size_t output_index = max_rank - 1; + for (int i = shape.size() - 1; i >= 0; --i) { + shape4d[output_index] = shape[i]; + output_index--; + } + return shape4d; +} + +static ttnn::ccl::Shape4D shape_to_shape_in_tiles(ttnn::Shape const& shape) { + auto logical_shape = shape.logical_shape(); + logical_shape[-2] /= tt::constants::TILE_HEIGHT; + logical_shape[-1] /= tt::constants::TILE_WIDTH; + TT_FATAL(logical_shape.size() == 4, "Expected 4D shape but got {}", logical_shape.size()); + ttnn::ccl::Shape4D shape_in_tiles = { + logical_shape[0], logical_shape[1], logical_shape[2], logical_shape[3]}; + return shape_in_tiles; +} + +std::vector split_tensor_slice_across_workers_wrapped_page_aligned( + ttnn::ccl::v2::TensorSlice const& tensor_slice, size_t num_workers) { + const size_t num_pages = tensor_slice.tensor_slice_shape.volume(); + + auto to_cmd_tensor = [&tensor_slice](std::pair size_offset) { + auto worker_slice = tensor_slice; + worker_slice.worker_slice_shape = {1, 1, 1, size_offset.first}; + worker_slice.worker_slice_offset = {0, 0, 0, size_offset.second}; + return worker_slice; + }; + + const auto worker_slices_view = + compute_evenly_split_sizes(num_pages, num_workers) | + std::views::transform([to_cmd_tensor](auto size_offset) { return to_cmd_tensor(size_offset); }); + + std::vector worker_slices; + worker_slices.reserve(num_workers); + std::ranges::copy(worker_slices_view, std::back_inserter(worker_slices)); + TT_FATAL( + worker_slices.size() == num_workers, "Expected {} worker slices but got {}", num_workers, worker_slices.size()); + return worker_slices; +} + +// Assumed that the tensor_slice shape is in terms of pages, not elements +std::vector compute_page_aligned_slices( + size_t const num_slices, const Tensor& input_tensor, size_t split_dim) { + TT_FATAL(num_slices > 0, "Number of slices must be greater than 0"); + std::vector tensor_slices; + + auto const input_tensor_shape_in_tiles = shape_to_shape_in_tiles(input_tensor.get_shape()); + tensor_slices.reserve(num_slices); + + // split the input tensor, by shape, into pieces + ttnn::ccl::v2::TensorSlice reference_tensor = { + input_tensor_shape_in_tiles, + input_tensor_shape_in_tiles, + {0, 0, 0, 0}, + input_tensor_shape_in_tiles, + {0, 0, 0, 0}}; + auto to_cmd_tensor = [&reference_tensor, split_dim](std::pair size_offset) { + auto cmd_tensor = reference_tensor; + cmd_tensor.tensor_slice_shape[split_dim] = size_offset.first; + cmd_tensor.tensor_slice_offset[split_dim] = size_offset.second; + return cmd_tensor; + }; + + const auto tensor_slices_view = + compute_evenly_split_sizes(input_tensor_shape_in_tiles[split_dim], num_slices) | + std::views::transform([to_cmd_tensor](auto size_offset) { return to_cmd_tensor(size_offset); }); + + std::ranges::copy(tensor_slices_view, std::back_inserter(tensor_slices)); + TT_FATAL( + tensor_slices.size() == num_slices, "Expected {} tensor slices but got {}", num_slices, tensor_slices.size()); + + return tensor_slices; +} + +std::vector> generate_worker_tensor_slices( + const size_t num_slices, const Tensor& tensor, const size_t num_workers, size_t split_dim) { + auto tensor_slices = compute_page_aligned_slices(num_slices, tensor, split_dim); + return split_tensor_slices_across_workers_page_aligned(num_workers, tensor_slices); +} + +}; // namespace ttnn::ccl::cmd::builder diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp new file mode 100644 index 00000000000..1df0611252c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 +/// + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" + +#include +// #include + +namespace tt::tt_metal { +class Tensor; +}; + +namespace ttnn::ccl::cmd::builder { + +std::vector generate_tensor_slices( + const size_t num_slices, const tt::tt_metal::Tensor& tensor, size_t split_dim); + +ttnn::ccl::v2::TensorSlice convert_to_whole_tensor_slice(const tt::tt_metal::Tensor& tensor); + +std::vector compute_page_aligned_slices( + size_t const num_slices, const tt::tt_metal::Tensor& input_tensor, size_t split_dim); + +// Pairs of slice size and slice offset +std::vector> compute_evenly_split_sizes(size_t size, size_t num_slices); + +// // Outer vector = per worker command stream, inner vector = commands +std::vector> split_tensor_slices_across_workers_page_aligned( + size_t num_workers, std::vector const& tensor_slices); + +std::vector split_tensor_slice_across_workers_wrapped_page_aligned( + ttnn::ccl::v2::TensorSlice const& tensor_slice, size_t num_workers); + +std::vector> generate_worker_tensor_slices( + const size_t num_slices, const tt::tt_metal::Tensor& tensor, const size_t num_workers, size_t split_dim); + +}; // namespace ttnn::ccl::cmd::builder diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp new file mode 100644 index 00000000000..e09ddf81d93 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp @@ -0,0 +1,1570 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "hostdevcommon/kernel_structs.h" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/ccl/erisc_datamover_builder.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "tt_metal/host_api.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" +#include "tt_metal/tt_stl/overloaded.hpp" + +#include +#include + +namespace ttnn::ccl::worker_detail { + +CCLWorkerArgBuilder::CCLWorkerArgBuilder( + Device const* device, + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::TensorPartition const& input_tensor_partition, + ttnn::ccl::TensorPartition const& output_tensor_partition, + std::size_t operating_dim) : + device(device), + op_config(op_config), + input_tensor_partition(input_tensor_partition), + output_tensor_partition(output_tensor_partition), + operating_dim(operating_dim) {} + +Shape4D to_4d_shape(Shape4D const& shape) { return shape; } +Shape4D to_4d_offset(Shape4D const& offset) { return offset; } +size_t get_volume(Shape4D const& shape) { return shape.volume(); } + +Shape4D to_4d_shape(tt_xy_pair const& shape) { return Shape4D(1, 1, shape.y, shape.x); } +Shape4D to_4d_offset(tt_xy_pair const& offset) { return Shape4D(0, 0, offset.y, offset.x); } +size_t get_volume(tt_xy_pair const& shape) { return shape.x * shape.y; } + +template +struct tensor_slice_command_arg_field { + using type = std::nullptr_t; +}; +template <> +struct tensor_slice_command_arg_field { + static auto get_value(v2::TensorSlice const& s) { return s.tensor_shape; }; +}; +template <> +struct tensor_slice_command_arg_field { + static auto get_value(v2::TensorSlice const& s) { return s.tensor_slice_shape; }; +}; +template <> +struct tensor_slice_command_arg_field { + static auto get_value(v2::TensorSlice const& s) { return s.tensor_slice_offset; }; +}; +template <> +struct tensor_slice_command_arg_field { + static auto get_value(v2::TensorSlice const& s) { return s.worker_slice_offset; }; +}; +template <> +struct tensor_slice_command_arg_field { + static auto get_value(v2::TensorSlice const& s) { return get_volume(s.worker_slice_shape); }; +}; +template <> +struct tensor_slice_command_arg_field { + static auto get_value(v2::TensorSlice const& s) { + return ttnn::ccl::cmd::CclCommandTensor{ + to_4d_shape(s.tensor_shape), + to_4d_shape(s.tensor_slice_shape), + to_4d_offset(s.tensor_slice_offset), + to_4d_offset(s.worker_slice_offset), + get_volume(s.worker_slice_shape)}; + }; +}; + +template +void add_ccl_command_arg_to_runtime_args(v2::TensorSlice const& tensor_slice, std::vector& rt_args_out) { + rt_args_out.push_back(static_cast(arg_code)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_shape field", num_words_for_args); + rt_args_out.resize(rt_args_out.size() + num_words_for_args); + + ttnn::ccl::cmd::CclCommandArg::pack_to( + &rt_args_out[rt_args_out.size() - num_words_for_args], + tensor_slice_command_arg_field::get_value(tensor_slice)); + + for (std::size_t j = rt_args_out.size() - num_words_for_args; j < rt_args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", rt_args_out[j]); + } +} +template <> +void add_ccl_command_arg_to_runtime_args( + v2::TensorSlice const& tensor_slice, std::vector& rt_args_out) { + rt_args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE)); + auto num_words_for_args = 1; + log_trace(tt::LogOp, "Emitting {} args for tensor_shape field", num_words_for_args); + rt_args_out.resize(rt_args_out.size() + num_words_for_args); + + ttnn::ccl::cmd::CclCommandArg::pack_to( + &rt_args_out[rt_args_out.size() - num_words_for_args], + tensor_slice_command_arg_field::get_value( + tensor_slice)); + + for (std::size_t j = rt_args_out.size() - num_words_for_args; j < rt_args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", rt_args_out[j]); + } +} + +template +void generate_ccl_slice_sequence_commands_impl( + std::vector const& slices, + ttnn::ccl::cmd::CclCommandCode command_type, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + for (std::size_t i = 0; i < slices.size(); i++) { + auto const& slice = slices[i]; + // Copy the header + if (i == 0) { + const std::size_t args_index_old = args_out.size(); + // push back Command Header + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32( + ttnn::ccl::cmd::CclCommandHeader{command_type, dest_args, 1}))); + + // push back arg 0 header + args_out.push_back( + static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES)); + auto const& ccl_command_tensor = ttnn::ccl::cmd::CclCommandTensor{ + to_4d_shape(slice.tensor_shape), + to_4d_shape(slice.tensor_slice_shape), + to_4d_offset(slice.tensor_slice_offset), + to_4d_offset(slice.worker_slice_offset), + get_volume(slice.worker_slice_shape)}; + const auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES>::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for full tensor slice command", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + // push_back arg 0 payload + ttnn::ccl::cmd::CclCommandArg:: + pack_to(&args_out[args_out.size() - num_words_for_args], ccl_command_tensor); + const std::size_t args_index_new = args_out.size(); + + TT_ASSERT(i < slices.size(), "Internal Error"); + std::stringstream ss; + ss << "ccl_send command " << std::to_string(i) << " has " << args_index_new - args_index_old << " args:\n"; + for (std::size_t j = args_index_old; j < args_index_new; j++) { + ss << "\targ " << j << ":" << args_out[j] << "\n"; + } + log_trace(tt::LogOp, "{}", ss.str()); + // We can reused cached values for the first slice + } else { + auto const& last_slice = slices[i - 1]; + const std::size_t args_index_old = args_out.size(); + auto header_index = args_out.size(); + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32( + ttnn::ccl::cmd::CclCommandHeader{command_type, dest_args, 1}))); + std::size_t num_args = 0; + + // tensor shape + if (last_slice.tensor_shape != slice.tensor_shape) { + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES>::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_shape field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], to_4d_shape(slice.tensor_shape)); + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + + num_args++; + } + + // tensor slice shape + if (last_slice.tensor_slice_shape != slice.tensor_slice_shape) { + args_out.push_back( + static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES>::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_slice_shape field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg:: + pack_to(&args_out[args_out.size() - num_words_for_args], to_4d_shape(slice.tensor_slice_shape)); + for (std::size_t i = args_out.size() - num_words_for_args; i < args_out.size(); i++) { + log_trace(tt::LogOp, "\t{}", args_out[i]); + } + + num_args++; + } + + // tensor slice offset + if (last_slice.tensor_slice_offset != slice.tensor_slice_offset) { + args_out.push_back( + static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES>::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_slice_offset field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg:: + pack_to(&args_out[args_out.size() - num_words_for_args], to_4d_offset(slice.tensor_slice_offset)); + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + + num_args++; + } + + // worker slice offset + if (last_slice.worker_slice_offset != slice.worker_slice_offset) { + args_out.push_back(static_cast( + ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES>::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for worker_slice_offset field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES>:: + pack_to(&args_out[args_out.size() - num_words_for_args], to_4d_offset(slice.worker_slice_offset)); + + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + num_args++; + } + + // worker_pages_per_slice + if (last_slice.worker_slice_shape != slice.worker_slice_shape) { + args_out.push_back( + static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg< + ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE>::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for worker_pages_per_slice field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], get_volume(slice.worker_slice_shape)); + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + + num_args++; + } + + args_out[header_index] = static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32( + ttnn::ccl::cmd::CclCommandHeader{command_type, dest_args, 1})); + + std::size_t args_index_new = args_out.size(); + std::stringstream ss; + ss << "ccl_send command " << i << " has " << args_index_new - args_index_old << " args:\n"; + for (std::size_t j = args_index_old; j < args_index_new; j++) { + ss << "\targ " << j << ":" << args_out[j] << "\n"; + } + log_trace(tt::LogOp, "{}", ss.str()); + } + } +} + +/* + * Number of CCL command arguments generated - note that this does not necessarily match + * the number of runtime args generated. + */ +size_t generate_ccl_tensor_slice_command_args( + std::optional const& last_tensor_slice, + v2::TensorSlice const& current_tensor_slice, + std::vector& args_out) { + // Copy the header + std::size_t num_command_args_added = 0; + auto const args_index_old = args_out.size(); + if (!last_tensor_slice.has_value()) { + const std::size_t args_index_old = args_out.size(); + // push back Command Header + // push back arg 0 header + log_trace(tt::LogOp, "Generating full tensor spec command args"); + add_ccl_command_arg_to_runtime_args( + current_tensor_slice, args_out); + const size_t args_index_new = args_out.size(); + // We can reused cached values for the first slice + num_command_args_added++; + } else { + auto const& last_slice = last_tensor_slice.value(); + const std::size_t args_index_old = args_out.size(); + auto header_index = args_out.size(); + + // tensor shape + if (last_slice.tensor_shape != current_tensor_slice.tensor_shape) { + add_ccl_command_arg_to_runtime_args( + current_tensor_slice, args_out); + num_command_args_added++; + } + + // tensor slice shape + if (last_slice.tensor_slice_shape != current_tensor_slice.tensor_slice_shape) { + add_ccl_command_arg_to_runtime_args( + current_tensor_slice, args_out); + num_command_args_added++; + } + + // tensor slice offset + if (last_slice.tensor_slice_offset != current_tensor_slice.tensor_slice_offset) { + add_ccl_command_arg_to_runtime_args( + current_tensor_slice, args_out); + num_command_args_added++; + } + + // worker slice offset + if (last_slice.worker_slice_offset != current_tensor_slice.worker_slice_offset) { + add_ccl_command_arg_to_runtime_args< + ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES>( + current_tensor_slice, args_out); + num_command_args_added++; + } + + // worker_pages_per_slice + if (last_slice.worker_slice_shape != current_tensor_slice.worker_slice_shape) { + add_ccl_command_arg_to_runtime_args( + current_tensor_slice, args_out); + num_command_args_added++; + } + } + + log_trace( + tt::LogOp, "\t{} rt_args added, {} cmd args added", args_out.size() - args_index_old, num_command_args_added); + + return num_command_args_added; +} + +// TODO: commonize with all uncached arg types (e.g. this can be commonized with atomic inc arg generation) +size_t generate_ccl_wait_value_command_args( + ttnn::ccl::cmd::CclCommandWaitValue const& wait_value_args, std::vector& args_out) { + auto const arg_code = ttnn::ccl::cmd::CclCommandArgCode::SET_TARGET_VALUE; + ttnn::ccl::cmd::CclCommandArgHeader hdr; + hdr.code = arg_code; + hdr.inline_value0 = static_cast(true); + hdr.inline_value1 = wait_value_args.target_value; + args_out.push_back(hdr.to_uint32()); + log_trace( + tt::LogOp, + "Emitting header only for for wait_value field. header.code={}, .inline_val0={}, .inline_val1={}", + static_cast(hdr.code), + hdr.inline_value0, + hdr.inline_value1); + + return 1; +} + +size_t generate_ccl_raw_inline_write_command_args( + ttnn::ccl::cmd::CclCommandInlineReadWrite const& inline_rw_args, std::vector& args_out) { + auto const arg_code = ttnn::ccl::cmd::CclCommandArgCode::SET_TARGET_VALUE; + ttnn::ccl::cmd::CclCommandArgHeader hdr; + hdr.code = arg_code; + hdr.inline_value0 = static_cast(true); + hdr.inline_value1 = inline_rw_args.value; + args_out.push_back(hdr.to_uint32()); + log_trace( + tt::LogOp, + "Emitting header only for for inline write field. header.code={}, .inline_val0={}, .inline_val1={}", + static_cast(hdr.code), + hdr.inline_value0, + hdr.inline_value1); + return 1; +} + +static size_t generate_ccl_atomic_inc_command_args( + ttnn::ccl::cmd::CclCommandAtomicInc const& atomic_inc_args, std::vector& args_out) { + auto const arg_code = ttnn::ccl::cmd::CclCommandArgCode::SET_ATOMIC_INC_VALUE; + ttnn::ccl::cmd::CclCommandArgHeader hdr; + hdr.code = arg_code; + hdr.inline_value0 = static_cast(true); + hdr.inline_value1 = atomic_inc_args.value; + TT_FATAL( + atomic_inc_args.value < std::numeric_limits::max(), + "Atomic increment value is too large: {}", + atomic_inc_args.value); + args_out.push_back(hdr.to_uint32()); + + log_trace( + tt::LogOp, + "Emitting header only for for atomic_inc field. header.code={}, .inline_val0={}, .inline_val1={}", + static_cast(hdr.code), + hdr.inline_value0, + hdr.inline_value1); + + return 1; +} + +/* + * Returns the number of ccl command args added + */ +static size_t generate_ccl_address_info_command_args( + std::optional> const& + last_addr_type, + std::pair const& current_addr_type_args, + ttnn::ccl::cmd::SRC_DEST_TYPE src_dest_type, + std::vector& args_out) { + auto requires_args_to_be_generated = [](auto const& last_addr_type, auto const& current_addr_type_args) { + bool different_type_or_args = !last_addr_type.has_value(); + different_type_or_args = + different_type_or_args || (last_addr_type.value().first != current_addr_type_args.first); + different_type_or_args = + different_type_or_args || (last_addr_type.value().second.index() != current_addr_type_args.second.index()); + if (different_type_or_args) { + return true; + } + if (std::holds_alternative(current_addr_type_args.second)) { + auto const& last_semaphore_id = + std::get(last_addr_type.value().second); + auto const& current_semaphore_id = + std::get(current_addr_type_args.second); + return last_semaphore_id.semaphore_id != current_semaphore_id.semaphore_id; + } + if (std::holds_alternative(current_addr_type_args.second)) { + auto const& last_circular_buffer_id = + std::get(last_addr_type.value().second); + auto const& current_circular_buffer_id = + std::get(current_addr_type_args.second); + return last_circular_buffer_id.circular_buffer_id != current_circular_buffer_id.circular_buffer_id; + } + if (std::holds_alternative(current_addr_type_args.second)) { + auto const& last_absolute_address = + std::get(last_addr_type.value().second); + auto const& current_absolute_address = + std::get(current_addr_type_args.second); + return last_absolute_address.absolute_address != current_absolute_address.absolute_address; + } + if (std::holds_alternative(current_addr_type_args.second)) { + auto const& last_relative_address = + std::get(last_addr_type.value().second); + auto const& current_relative_address = + std::get(current_addr_type_args.second); + return last_relative_address.relative_address != current_relative_address.relative_address; + } + if (std::holds_alternative(current_addr_type_args.second)) { + return false; + } + return true; + }; + + size_t num_ccl_command_args_added = 0; + if (requires_args_to_be_generated(last_addr_type, current_addr_type_args)) { + const size_t header_index = args_out.size(); + args_out.push_back(0); + num_ccl_command_args_added++; + ttnn::ccl::cmd::CclCommandArgHeader header; + header.code = ttnn::ccl::cmd::CclCommandArgCode::SET_ADDRESS_INFO; + if (std::holds_alternative(current_addr_type_args.second)) { + log_trace(tt::LogOp, "Emitting {} args for absolute_address field", 2); + header.inline_value0 = src_dest_type; + header.inline_value1 = static_cast(ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS); + + auto const& absolute_address = + std::get(current_addr_type_args.second); + args_out.push_back(absolute_address.absolute_address); + } else if (std::holds_alternative( + current_addr_type_args.second)) { + log_trace(tt::LogOp, "Emitting {} args for relative_address field at index {}", 2, header_index); + header.inline_value0 = src_dest_type; + header.inline_value1 = static_cast(ttnn::ccl::cmd::CclCommandAddrType::RELATIVE_ADDRESS); + + auto const& relative_address = + std::get(current_addr_type_args.second); + args_out.push_back(relative_address.relative_address); + } else if (std::holds_alternative(current_addr_type_args.second)) { + log_trace(tt::LogOp, "Emitting {} args for semaphore_id field at index {}", 1, header_index); + header.inline_value0 = src_dest_type; + header.inline_value1 = static_cast(ttnn::ccl::cmd::CclCommandAddrType::SEMAPHORE_ID); + + auto const& semaphore_id = + std::get(current_addr_type_args.second); + header.inline_value2 = semaphore_id.semaphore_id; + } else if (std::holds_alternative( + current_addr_type_args.second)) { + log_trace(tt::LogOp, "Emitting {} args for circular_buffer_id field at index {}", 1, header_index); + header.inline_value0 = src_dest_type; + header.inline_value1 = static_cast(ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID); + + auto const& circular_buffer_id = + std::get(current_addr_type_args.second); + header.inline_value2 = circular_buffer_id.circular_buffer_id; + } else if (std::holds_alternative(current_addr_type_args.second)) { + log_trace(tt::LogOp, "Emitting {} args for NONE addr field at index {}", 1, header_index); + header.inline_value0 = src_dest_type; + header.inline_value1 = static_cast(ttnn::ccl::cmd::CclCommandAddrType::NONE); + // do nothing + } else { + TT_THROW("Unsupported address type: {}", static_cast(current_addr_type_args.first)); + } + log_trace( + tt::LogOp, + "\theader.code={}, .inline_val0={}, .inline_val1={}, .inline_val2={}", + static_cast(header.code), + header.inline_value0, + header.inline_value1, + header.inline_value2); + args_out[header_index] = header.to_uint32(); + } + + return num_ccl_command_args_added; +} + +size_t generate_ccl_core_descriptor_info_command_args( + std::optional< + std::pair> const& + last_core_descriptor, + std::pair const& + current_core_descriptor, + std::vector& args_out) { + size_t num_ccl_command_args_added = 0; + bool requires_update_to_args = + !last_core_descriptor.has_value() || (last_core_descriptor.value().first != current_core_descriptor.first); + requires_update_to_args = requires_update_to_args || + (last_core_descriptor.value().second.index() != current_core_descriptor.second.index()); + if (!requires_update_to_args) { + if (std::holds_alternative( + current_core_descriptor.second)) { + requires_update_to_args = false; + } else if (std::holds_alternative( + current_core_descriptor.second)) { + requires_update_to_args = true; + } else if (std::holds_alternative( + current_core_descriptor.second)) { + auto const& last_noc_xy = + std::get(last_core_descriptor.value().second); + auto const& current_noc_xy = + std::get(current_core_descriptor.second); + requires_update_to_args = (last_noc_xy.x != current_noc_xy.x) || (last_noc_xy.y != current_noc_xy.y); + } else if (std::holds_alternative( + current_core_descriptor.second)) { + auto const& last_rectangle = + std::get(last_core_descriptor.value().second); + auto const& current_rectangle = + std::get(current_core_descriptor.second); + requires_update_to_args = (last_rectangle.noc0_start_x != current_rectangle.noc0_start_x) || + (last_rectangle.noc0_start_y != current_rectangle.noc0_start_y) || + (last_rectangle.noc0_end_x != current_rectangle.noc0_end_x) || + (last_rectangle.noc0_end_y != current_rectangle.noc0_end_y); + } + } + if (requires_update_to_args) { + const size_t header_index = args_out.size(); + log_trace(tt::LogOp, "Emitting {} args for core_descriptor field at index {}", 1, header_index); + args_out.push_back(0); + ttnn::ccl::cmd::CclCommandArgHeader hdr; + hdr.code = ttnn::ccl::cmd::CclCommandArgCode::SET_CORE_DESCRIPTOR_INFO; + hdr.inline_value0 = static_cast(current_core_descriptor.first); + if (std::holds_alternative(current_core_descriptor.second)) { + auto const& noc_xy = + std::get(current_core_descriptor.second); + hdr.inline_value1 = noc_xy.x; + hdr.inline_value2 = noc_xy.y; + } else if (std::holds_alternative( + current_core_descriptor.second)) { + auto const& rectangle = + std::get(current_core_descriptor.second); + args_out.push_back(rectangle.to_uint32()); + } + log_trace( + tt::LogOp, + "\theader.code={}, .inline_val0={}, .inline_val1={}, .inline_val2={}", + static_cast(hdr.code), + hdr.inline_value0, + hdr.inline_value1, + hdr.inline_value2); + args_out[header_index] = hdr.to_uint32(); + num_ccl_command_args_added++; + } + return num_ccl_command_args_added; +} + +void validate_ccl_command_dest_args(ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + bool valid = std::holds_alternative(dest_args) || + std::holds_alternative(dest_args) || + std::holds_alternative(dest_args); + if (!valid) { + TT_THROW( + "Unsupported CCL command dest args. Expected one of UnicastCommandDestArgs, MulticastCommandDestArgs, or " + "LocalOnlyCommandDestArgs"); + } +} +void validate_ccl_command_dest_type(ttnn::ccl::cmd::CclCommandDestType dest_type) { + bool valid = dest_type == ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST || + dest_type == ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST || + dest_type == ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY; + if (!valid) { + TT_THROW("Unsupported CCL command dest type: {}", static_cast(dest_type)); + } +} + +void validate_command(ttnn::ccl::cmd::CclHostLowLevelWorkerCommand const& command) { + validate_ccl_command_dest_type(command.fabric_transfer_type); + validate_ccl_command_dest_args(command.fabric_transfer_args); +} + +void generate_ccl_command_stream_to_kernel_args( + std::vector const& ccl_command_stream, + std::vector& rt_args_out) { + std::optional last_tensor_slice = std::nullopt; + std::optional> + last_src_addr_type = std::nullopt; + std::optional> + last_dest_addr_type = std::nullopt; + std::optional> + last_core_descriptor = std::nullopt; + + log_trace(tt::LogOp, "Generating CCL command stream to kernel args, starting at index {}", rt_args_out.size()); + + for (size_t i = 0; i < ccl_command_stream.size(); i++) { + log_trace(tt::LogOp, "New command starting at arg idx: {}", rt_args_out.size()); + auto const& command = ccl_command_stream[i]; + validate_command(command); + + // Set aside the placeholder rt arg for the command header + const size_t command_header_rt_arg_index = rt_args_out.size(); + static_assert(sizeof(ttnn::ccl::cmd::CclCommandHeader) == sizeof(uint32_t)); + const size_t old_rt_args_start_index = rt_args_out.size(); + rt_args_out.push_back(0); + + // populate the body (ccl command args)of the command + size_t num_ccl_command_args_added = 0; + switch (command.command_code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: { + auto const& current_tensor_slice = + std::get(command.command_args); + num_ccl_command_args_added += + generate_ccl_tensor_slice_command_args(last_tensor_slice, current_tensor_slice, rt_args_out); + last_tensor_slice = current_tensor_slice; + } break; + + case ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: + num_ccl_command_args_added += generate_ccl_raw_inline_write_command_args( + std::get(command.command_args), rt_args_out); + break; + + case ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC: + num_ccl_command_args_added += generate_ccl_atomic_inc_command_args( + std::get(command.command_args), rt_args_out); + break; + case ttnn::ccl::cmd::CclCommandCode::WAIT_VALUE: + num_ccl_command_args_added += generate_ccl_wait_value_command_args( + std::get(command.command_args), rt_args_out); + break; + + case ttnn::ccl::cmd::CclCommandCode::STREAM_EDM_TO_TENSOR: + TT_THROW( + "CCL command STREAM_EDM_TO_TENSOR is not useable, supported, or intended to be supported in CCL " + "v2. This command is deprecated."); + break; + TT_THROW( + "CCL command STREAM_TENSOR_TO_EDM is not useable, supported, or intended to be supported in CCL " + "v2. This command is deprecated."); + break; + + default: + TT_THROW("Unsupported CCL command code: {}. Support missing", static_cast(command.command_code)); + break; + } + + // populate the src_addr_type + num_ccl_command_args_added += generate_ccl_address_info_command_args( + last_src_addr_type, + {command.source_addr_type, command.source_addr_args}, + ttnn::ccl::cmd::SRC_DEST_TYPE::SRC, + rt_args_out); + last_src_addr_type = {command.source_addr_type, command.source_addr_args}; + + // populate the dest_addr_type + num_ccl_command_args_added += generate_ccl_address_info_command_args( + last_dest_addr_type, + {command.dest_addr_type, command.dest_addr_args}, + ttnn::ccl::cmd::SRC_DEST_TYPE::DEST, + rt_args_out); + last_dest_addr_type = {command.dest_addr_type, command.dest_addr_args}; + // populate the core_desc_type + num_ccl_command_args_added += generate_ccl_core_descriptor_info_command_args( + last_core_descriptor, {command.core_desc_type, command.core_desc_args}, rt_args_out); + last_core_descriptor = {command.core_desc_type, command.core_desc_args}; + + // populate the fabric_transfer_type + // Handled by header + log_trace( + tt::LogOp, + "Emitting command_header at index {}. code={}. fabric_transfer_type={}", + command_header_rt_arg_index, + command.command_code, + command.fabric_transfer_type); + TT_FATAL(command.command_code != ttnn::ccl::cmd::CclCommandCode::INVALID, "Invalid command code"); + rt_args_out[command_header_rt_arg_index] = + static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32(ttnn::ccl::cmd::CclCommandHeader{ + command.command_code, + command.fabric_transfer_args, + num_ccl_command_args_added, + })); + TT_FATAL( + ttnn::ccl::cmd::CclCommandHeader::from_uint32(rt_args_out[command_header_rt_arg_index]).code != + ttnn::ccl::cmd::CclCommandCode::INVALID, + "Invalid command code"); + + const size_t new_rt_args_start_index = rt_args_out.size(); + std::stringstream ss; + ss << "ccl_send command " << i << " has " << new_rt_args_start_index - old_rt_args_start_index + << " args starting at arg index: " << old_rt_args_start_index << "\n"; + for (std::size_t j = old_rt_args_start_index; j < new_rt_args_start_index; j++) { + ss << "\targ " << j << ":" << rt_args_out[j] << "\n"; + } + log_trace(tt::LogOp, "{}", ss.str()); + } +} + +void generate_ccl_slice_sequence_commands( + std::vector const& slices, + ttnn::ccl::cmd::CclCommandCode command_type, + std::vector& args_out) { + generate_ccl_slice_sequence_commands_impl( + slices, command_type, args_out, ttnn::ccl::cmd::LocalOnlyCommandDestArgs{}); +} +void generate_ccl_slice_sequence_commands( + std::vector const& slices, + ttnn::ccl::cmd::CclCommandCode command_type, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + generate_ccl_slice_sequence_commands_impl(slices, command_type, args_out, dest_args); +} + +void emit_ccl_send_slice_sequence_commands(std::vector const& slices, std::vector& args_out) { + generate_ccl_slice_sequence_commands(slices, ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM, args_out); +} +void generate_ccl_read_to_cb_slice_sequence_commands( + std::vector const& slices, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + generate_ccl_slice_sequence_commands( + slices, ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB, args_out, dest_args); +} +void generate_ccl_cb_to_tensor_slice_sequence_commands( + std::vector const& slices, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) { + generate_ccl_slice_sequence_commands( + slices, ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR, args_out, dest_args); +} + +KernelHandle generate_multi_command_stream_kernel_ct_args( + Program& program, + std::vector const& cb_indices, // TODO: move to RT arg + std::vector const& tensors, + CoreRangeSet const& worker_core_range, + DataMovementConfig datamovement_kernel_config, + const size_t num_command_streams, + std::optional my_chip_id) { + TT_FATAL(cb_indices.size() == tensors.size(), "Number of CB indices must match number of tensors"); + TT_FATAL( + num_command_streams > 0 && num_command_streams <= 2, + "Invalid number of command streams: {}. Must be 1 or 2", + num_command_streams); + + log_trace(tt::LogOp, "Generating multi command stream kernel CT args"); + + std::ranges::for_each(tensors, [](auto const& t) { + TT_FATAL(t != nullptr, "Null tensor passed to generate_multi_command_stream_kernel_ct_args"); + }); + if (tensors.size() > 0 && tensors[0]->is_sharded()) { + datamovement_kernel_config.defines["TENSOR0_SHARDED_MEM_LAYOUT"] = "1"; + } + if (tensors.size() > 1 && tensors[1]->is_sharded()) { + datamovement_kernel_config.defines["TENSOR1_SHARDED_MEM_LAYOUT"] = "1"; + } + if (num_command_streams == 1) { + // single input so we need to disable the second one + datamovement_kernel_config.defines["SINGLE_INPUT_MODE"] = "1"; + } + if (tensors.size() == 2) { + datamovement_kernel_config.defines["TWO_TENSOR"] = "1"; + } else if (tensors.size() == 1) { + datamovement_kernel_config.defines["SINGLE_TENSOR"] = "1"; + } else { + datamovement_kernel_config.defines["NO_TENSOR_MODE"] = "1"; + } + if (datamovement_kernel_config.defines.size() > 0) { + log_trace(tt::LogOp, "Command Kernel Defines:"); + for (auto const& [k, v] : datamovement_kernel_config.defines) { + log_trace(tt::LogOp, "\t{}: {}", k, v); + } + } + + for (auto i : cb_indices) { + TT_FATAL( + i != tt::CB::c_in7 && i != tt::CB::c_in6, + "Command processor kernel reserves cb in7 for use but user specified CBs included it. Please choose " + "another CB besides c_in6 and c_in7."); + } + + // Set aside a buffer we can use for storing packet headers in (particularly for atomic incs) + const auto reserved_packet_header_CB_index = + datamovement_kernel_config.processor == DataMovementProcessor::RISCV_0 ? tt::CB::c_in6 : tt::CB::c_in7; + static constexpr auto num_packet_headers_storable = 8; + static constexpr auto packet_header_size_bytes = sizeof(tt::fabric::PacketHeader); + tt::tt_metal::CircularBufferConfig cb_config = + tt::tt_metal::CircularBufferConfig( + num_packet_headers_storable * packet_header_size_bytes * 2, + {{reserved_packet_header_CB_index, tt::DataFormat::RawUInt32}}) + .set_page_size(reserved_packet_header_CB_index, packet_header_size_bytes); + log_trace( + tt::LogOp, + "Setting up reserved packet header CB for {} processor at CB index {} of size {} and page size {}. Core range: " + "{}", + datamovement_kernel_config.processor, + reserved_packet_header_CB_index, + num_packet_headers_storable * packet_header_size_bytes, + packet_header_size_bytes, + worker_core_range); + auto reserved_packet_header_CB_handle = CreateCircularBuffer(program, worker_core_range, cb_config); + + { // CT ARGS + std::vector ct_args = {my_chip_id.value_or(0xFFFF), reserved_packet_header_CB_index}; + for (size_t i = 0; i < tensors.size(); i++) { + std::ranges::copy( + std::array{ + static_cast( + tensors[i]->buffer()->buffer_layout()), // TODO: refactor out to generate_tensor_ct_args + static_cast(tensors[i]->buffer()->buffer_type()), + static_cast(tensors[i]->layout()), + static_cast(cb_indices[i])}, + std::back_inserter(ct_args)); + } + for (size_t i = 0; i < tensors.size(); i++) { + std::ranges::copy( + ttnn::ccl::emit_address_generator_compile_time_args(*tensors[i]), std::back_inserter(ct_args)); + } + + datamovement_kernel_config.compile_args = ct_args; + log_trace(tt::LogOp, "\tSenderReader Kernel Defines"); + for (auto const& [k, v] : datamovement_kernel_config.defines) { + log_trace(tt::LogOp, "\t\t{}: {}", k, v); + } + log_trace(tt::LogOp, "\tSenderReader CT Args"); + for (size_t i = 0; i < ct_args.size(); i++) { + auto const& arg = ct_args[i]; + log_trace(tt::LogOp, "\t\t{}: {}", i, arg); + } + } + auto sender_worker_reader_kernel = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp", + worker_core_range, + datamovement_kernel_config); + + return sender_worker_reader_kernel; +} + +static void log_command_stream(ttnn::ccl::cmd::CclHostLowLevelCommandSequence const& commands, size_t tab_level = 0) { + using namespace ttnn::ccl; + using namespace ttnn::ccl::cmd; + size_t index = 0; + for (auto const& c : commands) { + index++; + std::stringstream tabs_ss; + for (size_t i = 0; i < tab_level; i++) { + tabs_ss << "\t"; + } + + auto get_addr_args_str = [](std::stringstream& ss, CclCommandAddrArgs const& args) { + std::visit( + tt::stl::overloaded{ + [&ss](CclCommandAddrRelativeAddress const& a) { + ss << fmt::format("(relative_address:{})", a.relative_address); + }, + [&ss](CclCommandAddrAbsoluteAddress const& a) { + ss << fmt::format("(absolute_address:{})", a.absolute_address); + }, + [&ss](CclCommandAddrSemaphoreId const& a) { + ss << fmt::format("(semaphore_id:{})", a.semaphore_id); + }, + [&ss](CclCommandAddrCircularBufferId const& a) { + ss << fmt::format("(circular_buffer_id:{})", a.circular_buffer_id); + }, + [&ss](CclCommandAddrNone const& a) { ss << "none"; }}, + args); + }; + auto get_cmd_args_str = [](std::stringstream& ss, CclCommandArgs const& args) { + std::visit( + tt::stl::overloaded{ + [&ss](CclCommandStreamTensorSlice const& a) { + ss << fmt::format( + "(shape: (w:{},z:{},y:{},x:{}), slice_shape: (w:{},z:{},y:{},x:{}), slice_offset: " + "(w:{},z:{},y:{},x:{}), worker_slice_shape: (w:{},z:{},y:{},x:{}), worker_slice_offset: " + "(w:{},z:{},y:{},x:{}))", + a.tensor_shape.w, + a.tensor_shape.z, + a.tensor_shape.y, + a.tensor_shape.x, + a.tensor_slice_shape.w, + a.tensor_slice_shape.z, + a.tensor_slice_shape.y, + a.tensor_slice_shape.x, + a.tensor_slice_offset.w, + a.tensor_slice_offset.z, + a.tensor_slice_offset.y, + a.tensor_slice_offset.x, + a.worker_slice_shape.w, + a.worker_slice_shape.z, + a.worker_slice_shape.y, + a.worker_slice_shape.x, + a.worker_slice_offset.w, + a.worker_slice_offset.z, + a.worker_slice_offset.y, + a.worker_slice_offset.x); + }, + [&ss](CclCommandAtomicInc const& a) { + ss << fmt::format("(val:{}, wrap: {})", a.value, a.wrap_value); + }, + [&ss](CclCommandWaitValue const& a) { ss << fmt::format("(wait_value: {})", a.target_value); }, + [&ss](CclCommandInlineReadWrite const& a) { ss << fmt::format("(value: {})", a.value); }, + [&ss](CclCommandReadWrite const& a) { ss << fmt::format("(size_bytes: {})", a.size_bytes); }, + [&ss](auto const&&) { ss << "ERROR"; }}, + args); + }; + + auto get_core_desc_args_str = [](std::stringstream& ss, CclCommandCoreDescriptorArgs const& args) { + std::visit( + tt::stl::overloaded{ + [&ss](CclCommandCoreDescriptorTypeAddrgen const& a) { ss << fmt::format("(addrgen)"); }, + [&ss](CclCommandCoreDescriptorTypeLocal const& a) { ss << fmt::format("(local_core)"); }, + [&ss](CclCommandCoreDescriptorTypeNocXY const& a) { ss << fmt::format("(x:{}, y:{})", a.x, a.y); }, + [&ss](CclCommandCoreDescriptorTypeMcast const& a) { + ss << fmt::format( + "(noc0_start_x:{}, noc0_start_y:{}, noc0_end_x:{}, noc0_end_y:{})", + a.noc0_start_x, + a.noc0_start_y, + a.noc0_end_x, + a.noc0_end_y); + }, + }, + args); + }; + + auto get_fabric_transfer_args_str = [](std::stringstream& ss, CclCommandDestArgs const& args) { + std::visit( + tt::stl::overloaded{ + [&ss](UnicastCommandDestArgs const& a) { + ss << fmt::format( + "(distance_in_hops:{}, is_forward_direction:{})", + a.distance_in_hops, + a.is_forward_direction); + }, + [&ss](MulticastCommandDestArgs const& a) { + ss << fmt::format( + "(num_targets_forward_direction:{}, num_targets_backward_direction:{})", + a.num_targets_forward_direction, + a.num_targets_backward_direction); + }, + [&ss](LocalOnlyCommandDestArgs const& a) { ss << fmt::format("(None)"); }, + }, + args); + }; + + std::stringstream cmd_attrs_ss; + std::stringstream src_attrs_ss; + std::stringstream dest_attrs_ss; + std::stringstream core_attrs_ss; + std::stringstream fabric_attrs_ss; + get_addr_args_str(src_attrs_ss, c.source_addr_args); + get_addr_args_str(dest_attrs_ss, c.dest_addr_args); + get_core_desc_args_str(core_attrs_ss, c.core_desc_args); + get_fabric_transfer_args_str(fabric_attrs_ss, c.fabric_transfer_args); + get_cmd_args_str(cmd_attrs_ss, c.command_args); + + log_trace( + tt::LogOp, + "{}{}. SRC({})[{}] -> CMD({})[{}] -> DST({})[{}]; CORE({})[{}]; FABRIC({})[{}]", + tabs_ss.str(), + index, + c.source_addr_type, + src_attrs_ss.str(), + c.command_code, + cmd_attrs_ss.str(), + c.dest_addr_type, + dest_attrs_ss.str(), + c.core_desc_type, + core_attrs_ss.str(), + c.fabric_transfer_type, + fabric_attrs_ss.str()); + } +} + +void generate_multi_input_command_stream_kernel_rt_args( + Program& program, + KernelHandle kernel_id, + std::vector const& tensors, + std::vector const& page_sizes, + Device* device, + uint32_t num_pages_per_edm_buffer, // TODO: get from fabric + CoreRangeSet const& worker_core_range, + ttnn::ccl::cmd::CclHostLowLevelCommandSequence const& ccl_command_stream0, + std::optional const& ccl_command_stream1, + std::optional const& forward_fabric_connections, + std::optional const& backward_fabric_connections, + std::optional> const& tensor_device_override) { + // TODO: see if we can pull the kernel defines to understand if we built the kernel in single command stream mode + log_trace( + tt::LogOp, + "Generating multi command stream kernel RT args for kernel {} on core(s): {}", + kernel_id, + worker_core_range); + log_trace(tt::LogOp, "Command stream 0:"); + log_command_stream(ccl_command_stream0, 1); + if (ccl_command_stream1) { + log_trace(tt::LogOp, "Command stream 1:"); + log_command_stream(ccl_command_stream1.value(), 1); + } + + std::vector*> command_streams = { + &ccl_command_stream0}; + if (ccl_command_stream1.has_value()) { + command_streams.push_back(&ccl_command_stream1.value()); + } + + // RT ARGS + const size_t num_command_streams = command_streams.size(); + TT_FATAL( + tensors.size() <= num_command_streams, + "Current CCL Command Processor kernel only supports a 1-to-1 mapping between command streams and tensors. " + "Switching between tensors within a command stream is future work"); + TT_FATAL(page_sizes.size() == tensors.size(), "Number of page sizes must match with the number of tensors"); + auto command_stream_start_arg_indices = std::vector(num_command_streams, 0); + std::vector rt_args; + rt_args.reserve(100); + for (size_t i = 0; i < tensors.size(); i++) { + if (tensors[i]) { + rt_args.push_back(tensors[i]->buffer()->address()); + } else { + // take up the rt arg with filler value in case user built a kernel across a core range + // set with multiple command streams/tensors, but this particular core doesn't actualy need/use + // both tensors/command streams + rt_args.push_back(0xdeaddead); + } + } + for (size_t i = 0; i < num_command_streams; i++) { + rt_args.push_back(command_streams[i]->size()); // in0_read_command_slices + command_stream_start_arg_indices[i] = rt_args.size(); + rt_args.push_back(0); // in0_command_start_offset + } + rt_args.push_back(num_pages_per_edm_buffer); + TT_FATAL(tensors.size() == page_sizes.size(), "Number of pages must match with the number of tensors"); + for (size_t i = 0; i < tensors.size(); i++) { + if (tensors[i]) { + rt_args.push_back(page_sizes[i]); // in0 + } else { + rt_args.push_back(0xdeaddead); + } + } + + for (Tensor const* t : tensors) { + if (t) { + if (tensor_device_override.has_value() and + tensor_device_override.value().find(t) != tensor_device_override.value().end()) { + std::ranges::copy( + ttnn::ccl::emit_address_generator_runtime_args(tensor_device_override->at(t), *t), + std::back_inserter(rt_args)); + } else { + std::ranges::copy( + ttnn::ccl::emit_address_generator_runtime_args(t->buffer()->device(), *t), + std::back_inserter(rt_args)); + } + } + // else: Interleaved addrgen passes no additional args - we specify interleaved addrgen as the default + } + + rt_args.push_back(forward_fabric_connections.has_value()); + if (forward_fabric_connections.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + append_worker_to_fabric_edm_sender_rt_args( + forward_fabric_connections.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_buffer_index_semaphore_id, + rt_args); + } + rt_args.push_back(backward_fabric_connections.has_value()); + if (backward_fabric_connections.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + append_worker_to_fabric_edm_sender_rt_args( + backward_fabric_connections.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_buffer_index_semaphore_id, + rt_args); + } + + for (size_t i = 0; i < num_command_streams; i++) { + // Update the command stream start arg index argument to point to here (i.e. where + // this command stream's commands will start) + rt_args[command_stream_start_arg_indices[i]] = rt_args.size(); + generate_ccl_command_stream_to_kernel_args((*command_streams[i]), rt_args); + } + + log_trace(tt::LogOp, "\tMulti-input command processor RT Args"); + for (size_t i = 0; i < rt_args.size(); i++) { + auto const& arg = rt_args[i]; + log_trace(tt::LogOp, "\t\t{}: {}", i, arg); + } + tt::tt_metal::SetRuntimeArgs(program, kernel_id, worker_core_range, rt_args); +} + +void generate_multi_command_stream_kernel_rt_args( + Program& program, + KernelHandle kernel_id, + std::vector const& cb_ids, + std::vector const& tensors, + Device* device, + uint32_t page_size, // TODO: get from tensors + CoreRangeSet const& worker_core_range, + uint32_t num_pages_per_edm_buffer, // TODO: get from fabric + std::vector> const& command_tensor_slices, + ttnn::ccl::cmd::CclCommandCode command_type, // TODAY REQURED TO BE SAME - FUTURE - wrapped with above + std::optional const& forward_fabric_connections, + std::optional const& backward_fabric_connections, + std::optional> const& edm_termination_infos, + std::vector const& dest_args) { + for (size_t i = 0; i < tensors.size(); i++) { + TT_FATAL(tensors[i] != nullptr, "Tensor at index {} is nullptr", i); + } + + TT_FATAL(tensors.size() > 0 && tensors.size() <= 2, "Size mismatch between tensors and cb_ids"); + TT_FATAL(cb_ids.size() == tensors.size(), "Size mismatch between cb_ids and tensors"); + TT_FATAL(command_tensor_slices.size() == tensors.size(), "Size mismatch between command_tensor_slices and tensors"); + TT_FATAL(dest_args.size() == tensors.size(), "Size mismatch between dest_args and tensors"); + // RT ARGS + const size_t num_command_streams = tensors.size(); + auto command_stream_start_arg_indices = std::vector(num_command_streams, 0); + std::vector rt_args; + rt_args.reserve(100); + for (size_t i = 0; i < num_command_streams; i++) { + rt_args.push_back(tensors[i]->buffer()->address()); + } + for (size_t i = 0; i < num_command_streams; i++) { + rt_args.push_back(command_tensor_slices[i].size()); // input_tensor_0_read_command_slices + command_stream_start_arg_indices[i] = rt_args.size(); + rt_args.push_back(0); // in0_command_start_offset + } + rt_args.push_back(num_pages_per_edm_buffer); + for (size_t i = 0; i < num_command_streams; i++) { + rt_args.push_back(page_size); // in0 + } + + for (size_t i = 0; i < num_command_streams; i++) { + std::ranges::copy( + ttnn::ccl::emit_address_generator_runtime_args(device, *tensors[i]), std::back_inserter(rt_args)); + } + + // TODO: Handle teardown signalling + rt_args.push_back(forward_fabric_connections.has_value()); + if (forward_fabric_connections.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + append_worker_to_fabric_edm_sender_rt_args( + forward_fabric_connections.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_buffer_index_semaphore_id, + rt_args); + } + rt_args.push_back(backward_fabric_connections.has_value()); + if (backward_fabric_connections.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, worker_core_range, 0); + append_worker_to_fabric_edm_sender_rt_args( + backward_fabric_connections.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_buffer_index_semaphore_id, + rt_args); + } + size_t fabric_teardown_arg_idx = 0; + if (edm_termination_infos.has_value()) { + fabric_teardown_arg_idx = rt_args.size(); + rt_args.push_back(0); + } + + switch (command_type) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: + for (size_t i = 0; i < num_command_streams; i++) { + rt_args[command_stream_start_arg_indices[i]] = rt_args.size(); + ttnn::ccl::worker_detail::generate_ccl_read_to_cb_slice_sequence_commands( + command_tensor_slices[i], rt_args, dest_args[i]); + } + break; + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: + for (size_t i = 0; i < num_command_streams; i++) { + rt_args[command_stream_start_arg_indices[i]] = rt_args.size(); + ttnn::ccl::worker_detail::generate_ccl_cb_to_tensor_slice_sequence_commands( + command_tensor_slices[i], rt_args, dest_args[i]); + } + break; + + case ttnn::ccl::cmd::CclCommandCode::STREAM_EDM_TO_TENSOR: + case ttnn::ccl::cmd::CclCommandCode::INVALID: + default: TT_ASSERT(false); + }; + + if (edm_termination_infos.has_value()) { + rt_args[fabric_teardown_arg_idx] = rt_args.size(); + get_runtime_args_for_edm_termination_infos(edm_termination_infos.value(), rt_args); + } + + log_trace(tt::LogOp, "\tMulti-input command processor RT Args"); + for (size_t i = 0; i < rt_args.size(); i++) { + auto const& arg = rt_args[i]; + log_trace(tt::LogOp, "\t\t{}: {}", i, arg); + } + tt::tt_metal::SetRuntimeArgs(program, kernel_id, worker_core_range, rt_args); +} + +ttnn::ccl::cmd::CclHostLowLevelCommandSequence build_ccl_cmd_proc_teardown_commands( + tt::tt_metal::Program& program, + Device* device, + Device* forward_device, + size_t line_size, + size_t line_index, + std::vector const& edm_termination_infos, + ccl::SyncModeSpec const& sync_details, + ccl::EdmLineFabricOpInterface& fabric_interface) { + TT_FATAL(sync_details.num_signals == 1, "Only one signal is supported for CCL command processor teardown"); + TT_FATAL(sync_details.sem_ids.size() == 1, "Only one signal is supported for CCL command processor teardown"); + TT_FATAL(sync_details.wait_counts.size() == 1, "Only one signal is supported for CCL command processor teardown"); + + auto local_wait_sem_id = sync_details.sem_ids.at(0); + auto remote_sem_id = sync_details.sem_ids.at(0); + + ttnn::ccl::cmd::CclHostLowLevelCommandSequence teardown_cmd_stream = { + // + 1 because we need to wait for our left/backward neighbour to tell us it's safe to teardown (because they + // are + // done tearing down - we teardown from first to last) + cmd::uops::local_semaphore_wait(local_wait_sem_id, sync_details.wait_counts.at(0) + (line_index != 0)), + }; + + // If there is a forward connection, notify that neighbour that they can teardown + if (forward_device != nullptr) { + auto remote_worker_noc0_core = forward_device->worker_core_from_logical_core(sync_details.core); + teardown_cmd_stream.push_back(cmd::uops::fabric_unicast_semaphore_inc( + remote_sem_id, + ttnn::ccl::cmd::CclCommandAtomicInc{1}, + remote_worker_noc0_core.x, + remote_worker_noc0_core.y, + ttnn::ccl::cmd::UnicastCommandDestArgs{1, true})); + } + + // Finally teardown our local chip's fabric endpoint(s) + if (edm_termination_infos.size() > 0) { + log_trace(tt::LogOp, "{} termination infos", edm_termination_infos.size()); + } + for (auto& info : edm_termination_infos) { + if (info.distance == 0) { + log_trace( + tt::LogOp, + "Adding local chip fabric teardown command for termination address {},", + info.termination_addr); + teardown_cmd_stream.push_back(cmd::uops::local_chip_noc_absolute_address_semaphore_inc( + info.edm_noc_x, info.edm_noc_y, info.termination_addr, 1)); + } else { + log_trace( + tt::LogOp, + "Adding remote chip fabric teardown command for termination address {} of distance {}", + info.termination_addr, + info.distance); + teardown_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_unicast_absolute_address_semaphore_inc( + ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{info.termination_addr}, + ttnn::ccl::cmd::CclCommandAtomicInc{1}, + info.edm_noc_x, + info.edm_noc_y, + ttnn::ccl::cmd::UnicastCommandDestArgs{info.distance, true})); + } + } + + return teardown_cmd_stream; +} + +void build_sync_kernels( + Device* device, + tt::tt_metal::Program& program, + ccl::SyncModeSpec const& sync_details, + bool terminate_fabric, + ccl::EdmLineFabricOpInterface& fabric_interface) { + auto const sync_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_wait_completion.cpp", + sync_details.core, + tt::tt_metal::ReaderDataMovementConfig({sync_details.num_signals, terminate_fabric})); + + std::vector rt_args; + rt_args.reserve(sync_details.num_signals * 2); + for (size_t i = 0; i < sync_details.num_signals; ++i) { + rt_args.push_back(sync_details.sem_ids[i]); + rt_args.push_back(sync_details.wait_counts[i]); + } + + if (terminate_fabric) { + auto termination_infos = fabric_interface.generate_local_chip_fabric_termination_infos(device); + rt_args.push_back(termination_infos.size()); + for (auto& info : termination_infos) { + if (info.distance != 0) { + continue; + } + rt_args.push_back(info.termination_addr); + rt_args.push_back(info.edm_noc_x); + rt_args.push_back(info.edm_noc_y); + } + } + + tt::tt_metal::SetRuntimeArgs(program, sync_kernel_id, sync_details.core, rt_args); +} + +std::vector CCLWorkerArgBuilder::generate_sender_reader_kernel_rt_args( + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index) const { + const std::size_t num_commands_expected = this->input_tensor_partition.partition_size; + + auto const& tensor_shape = worker_slice.tensor_shape; + auto const& tensor_slice_shape = worker_slice.tensor_slice_shape; + + auto num_slices = input_tensor_partition.partition_size; + auto start_slice_index = input_tensor_partition.partition_index; + std::int64_t end_slice_index_exclusive = input_tensor_partition.partition_index + 1; + + log_trace(tt::LogOp, "ccl_send_writer start_slice_index = {}", start_slice_index); + log_trace(tt::LogOp, "ccl_send_writer end_slice_index_exclusive = {}", end_slice_index_exclusive); + + // Add the command args + auto const& slices = generate_slice_sequence_on_dim_v2( + tensor_shape, + worker_slice.worker_slice_shape, + worker_slice.worker_slice_offset, + operating_dim, + num_slices, + start_slice_index, + end_slice_index_exclusive, + worker_slice_index); + TT_ASSERT(num_commands_expected == slices.size()); + + // If we are on device zero, we send n-1 chunks in ascending order + auto& input_tensor = this->op_config.get_input_tensor(0); + TT_ASSERT(input_tensor.get_legacy_shape().size() == 4, "Only 4D tensors are supported for ccl"); + ttnn::ccl::Shape4D input_tensor_shape = { + input_tensor.get_legacy_shape()[0], + input_tensor.get_legacy_shape()[1], + input_tensor.get_legacy_shape()[2], + input_tensor.get_legacy_shape()[3]}; + + std::vector args = { + static_cast(input_tensor.buffer()->address()), + static_cast(slices.size()), + num_pages_per_packet, + this->op_config.get_page_size()}; + std::size_t logged_arg_idx = 0; + log_trace(tt::LogOp, "ccl_send_reader arg[{}]: buffer_address = {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_reader arg[{}]: num_commands = {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_reader arg[{}]: pages_per_packet {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_reader arg[{}]: page_size {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, input_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + for (auto const& arg : addr_gen_rt_args) { + log_trace(tt::LogOp, "ccl_send_reader arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + } + + log_trace(tt::LogOp, "ccl_send_reader Generating {} ccl send commands", slices.size()); + emit_ccl_send_slice_sequence_commands(slices, args); + + log_trace(tt::LogOp, "ccl_send_reader Sender Worker has {} RT Args: {}", args.size(), args); + + return args; +} + +std::vector CCLWorkerArgBuilder::generate_sender_writer_kernel_rt_args( + std::optional const& forward_fabric_connection, + const size_t sender_worker_forward_flow_control_semaphore_id, + const size_t sender_worker_forward_buffer_index_semaphore_id, + std::optional const& backward_fabric_connection, + const size_t sender_worker_backward_flow_control_semaphore_id, + const size_t sender_worker_backward_buffer_index_semaphore_id, + const size_t forward_direction_distance_to_end_of_line, + const size_t backward_direction_distance_to_end_of_line, + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index, + std::optional sync_details) const { + const std::size_t num_commands_expected = this->output_tensor_partition.partition_size - 1; + + auto const& tensor_shape = worker_slice.tensor_shape; + auto const& tensor_slice_shape = worker_slice.tensor_slice_shape; + + auto num_slices = output_tensor_partition.partition_size; + auto start_slice_index = output_tensor_partition.partition_index; + std::int64_t end_slice_index_exclusive = output_tensor_partition.partition_index + 1; + + log_trace(tt::LogOp, "ccl_send_writer start_slice_index = {}", start_slice_index); + log_trace(tt::LogOp, "ccl_send_writer end_slice_index_exclusive = {}", end_slice_index_exclusive); + + // Add the command args + auto const& slices = generate_slice_sequence_on_dim_v2( + tensor_shape, + worker_slice.worker_slice_shape, + worker_slice.worker_slice_offset, + operating_dim, + num_slices, + start_slice_index, + end_slice_index_exclusive, + worker_slice_index); + TT_ASSERT(num_commands_expected == slices.size()); + + // If we are on device zero, we send n-1 chunks in ascending order + auto& output_tensor = this->op_config.get_output_tensor(0); + TT_ASSERT(output_tensor.get_legacy_shape().size() == 4, "Only 4D tensors are supported for ccl"); + ttnn::ccl::Shape4D output_tensor_shape = { + output_tensor.get_legacy_shape()[0], + output_tensor.get_legacy_shape()[1], + output_tensor.get_legacy_shape()[2], + output_tensor.get_legacy_shape()[3]}; + + std::vector args = { + static_cast(output_tensor.buffer()->address()), + static_cast(slices.size()), + num_pages_per_packet, + this->op_config.get_page_size(), + forward_direction_distance_to_end_of_line, + backward_direction_distance_to_end_of_line}; + std::size_t logged_arg_idx = 0; + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: buffer_address = {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: num_commands = {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: pages_per_packet {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: page_size {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + args.push_back(forward_fabric_connection.has_value() ? 1 : 0); + if (forward_fabric_connection.has_value()) { + TT_FATAL( + forward_direction_distance_to_end_of_line > 0, + "Forward direction distance to end of line must be greater than 0"); + log_trace(tt::LogOp, "ccl_send_writer has forward fabric connection"); + log_trace(tt::LogOp, "\tedm_noc_x: {}", forward_fabric_connection.value().edm_noc_x); + log_trace(tt::LogOp, "\tedm_noc_y: {}", forward_fabric_connection.value().edm_noc_y); + log_trace(tt::LogOp, "\tedm_buffer_base_addr: {}", forward_fabric_connection.value().edm_buffer_base_addr); + log_trace(tt::LogOp, "\tnum_buffers_per_channel: {}", forward_fabric_connection.value().num_buffers_per_channel); + log_trace(tt::LogOp, "\tedm_l1_sem_addr: {}", forward_fabric_connection.value().edm_l1_sem_addr); + log_trace( + tt::LogOp, + "\tedm_connection_handshake_addr: {}", + forward_fabric_connection.value().edm_connection_handshake_addr); + log_trace( + tt::LogOp, + "\tedm_worker_location_info_addr: {}", + forward_fabric_connection.value().edm_worker_location_info_addr); + log_trace(tt::LogOp, "\tbuffer_size_bytes: {}", forward_fabric_connection.value().buffer_size_bytes); + log_trace( + tt::LogOp, "\tbuffer_index_semaphore_id: {}", forward_fabric_connection.value().buffer_index_semaphore_id); + ttnn::ccl::append_worker_to_fabric_edm_sender_rt_args( + forward_fabric_connection.value(), + sender_worker_forward_flow_control_semaphore_id, + sender_worker_forward_buffer_index_semaphore_id, + args); + logged_arg_idx = ttnn::ccl::log_worker_to_fabric_edm_sender_rt_args(args, logged_arg_idx); + } + args.push_back(backward_fabric_connection.has_value() ? 1 : 0); + if (backward_fabric_connection.has_value()) { + TT_FATAL( + backward_direction_distance_to_end_of_line > 0, + "Backward direction distance to end of line must be greater than 0"); + log_trace(tt::LogOp, "ccl_send_writer has backward fabric connection"); + log_trace(tt::LogOp, "\tedm_noc_x: {}", backward_fabric_connection.value().edm_noc_x); + log_trace(tt::LogOp, "\tedm_noc_y: {}", backward_fabric_connection.value().edm_noc_y); + log_trace(tt::LogOp, "\tedm_buffer_base_addr: {}", backward_fabric_connection.value().edm_buffer_base_addr); + log_trace( + tt::LogOp, "\tnum_buffers_per_channel: {}", backward_fabric_connection.value().num_buffers_per_channel); + log_trace(tt::LogOp, "\tedm_l1_sem_addr: {}", backward_fabric_connection.value().edm_l1_sem_addr); + log_trace( + tt::LogOp, + "\tedm_connection_handshake_addr: {}", + backward_fabric_connection.value().edm_connection_handshake_addr); + log_trace( + tt::LogOp, + "\tedm_worker_location_info_addr: {}", + backward_fabric_connection.value().edm_worker_location_info_addr); + log_trace(tt::LogOp, "\tbuffer_size_bytes: {}", backward_fabric_connection.value().buffer_size_bytes); + log_trace( + tt::LogOp, "\tbuffer_index_semaphore_id: {}", backward_fabric_connection.value().buffer_index_semaphore_id); + ttnn::ccl::append_worker_to_fabric_edm_sender_rt_args( + backward_fabric_connection.value(), + sender_worker_backward_flow_control_semaphore_id, + sender_worker_backward_buffer_index_semaphore_id, + args); + logged_arg_idx = ttnn::ccl::log_worker_to_fabric_edm_sender_rt_args(args, logged_arg_idx); + } + + args.push_back(sync_details.has_value() ? 1 : 0); + if (sync_details.has_value()) { + args.push_back(sync_details.value().num_signals); + for (size_t i = 0; i < sync_details.value().num_signals; ++i) { + auto const noc_coord = + this->device->virtual_core_from_logical_core(sync_details.value().core, CoreType::WORKER); + log_trace( + tt::LogOp, + "ccl_send_writer on device {} adding sync signal dest to (y={},x={},id={})", + this->device->id(), + noc_coord.y, + noc_coord.x, + sync_details.value().sem_ids[i]); + args.push_back(sync_details.value().sem_ids[i]); + args.push_back(noc_coord.x); + args.push_back(noc_coord.y); + } + } + + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, output_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + for (auto const& arg : addr_gen_rt_args) { + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]); + logged_arg_idx++; + } + + log_trace(tt::LogOp, "ccl_send_writer Generating {} ccl send commands", slices.size()); + emit_ccl_send_slice_sequence_commands(slices, args); + + log_trace(tt::LogOp, "ccl_send_writer Sender Worker has {} RT Args: {}", args.size(), args); + + return args; +} + +std::vector CCLWorkerArgBuilder::generate_sender_reader_kernel_ct_args() const { + auto const& input_tensor = this->op_config.get_input_tensor(0); + std::vector args = { + static_cast(input_tensor.memory_config().memory_layout), // tensor memory layout + static_cast(input_tensor.buffer()->buffer_type()), // buffer type + static_cast(input_tensor.layout()), // page layout + static_cast(tt::CB::c_in0) // cb_id + }; + + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_compile_time_args(input_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + + return args; +} + +std::vector CCLWorkerArgBuilder::generate_sender_writer_kernel_ct_args() const { + auto const& output_tensor = this->op_config.get_output_tensor(0); + std::vector args = { + static_cast(output_tensor.memory_config().memory_layout), // tensor memory layout + static_cast(output_tensor.buffer()->buffer_type()), // buffer type + static_cast(output_tensor.layout()), // page layout + static_cast(tt::CB::c_in0) // cb_id + }; + + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_compile_time_args(output_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + + return args; +} + +} // namespace ttnn::ccl::worker_detail diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp new file mode 100644 index 00000000000..79699816337 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp @@ -0,0 +1,172 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" + +#include +#include +#include + +namespace tt::tt_metal { +inline namespace v0 { + +// Forward declarations +class Device; + +} // namespace v0 +} // namespace tt::tt_metal + +namespace ttnn::ccl { +class WorkerEdmInterfaceArgs; +class SenderWorkerAdapterSpec; + +namespace worker_detail { + +Shape4D to_4d_shape(Shape4D const& shape); +Shape4D to_4d_offset(Shape4D const& offset); +size_t get_volume(Shape4D const& shape); + +Shape4D to_4d_shape(tt_xy_pair const& shape); +Shape4D to_4d_offset(tt_xy_pair const& offset); +size_t get_volume(tt_xy_pair const& shape); + +void generate_ccl_slice_sequence_commands( + std::vector const& slices, + ttnn::ccl::cmd::CclCommandCode command_type, + std::vector& args_out); +void generate_ccl_slice_sequence_commands( + std::vector const& slices, + ttnn::ccl::cmd::CclCommandCode command_type, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args); +void emit_ccl_send_slice_sequence_commands(std::vector const& slices, std::vector& args_out); +void generate_ccl_read_to_cb_slice_sequence_commands( + std::vector const& slices, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args); +void generate_ccl_cb_to_tensor_slice_sequence_commands( + std::vector const& slices, + std::vector& args_out, + ttnn::ccl::cmd::CclCommandDestArgs const& dest_args); +void generate_ccl_command_stream_to_kernel_args( + std::vector const& ccl_command_stream, + std::vector& args_out); + +// TODO: eventually take a fabric handle +void generate_multi_input_command_stream_kernel_rt_args( + Program& program, + KernelHandle kernel_id, + std::vector const& tensors, + std::vector const& page_sizes, + Device* device, + uint32_t num_pages_per_edm_buffer, // TODO: get from fabric + CoreRangeSet const& worker_core_range, + std::vector const& ccl_command_stream0, + std::optional> const& ccl_command_stream1, + std::optional const& forward_fabric_connections, + std::optional const& backward_fabric_connections, + std::optional> const& tensor_device_override = std::nullopt); +// Helper functions for building command processing datamovement kernels +// TODO: Bundle into command bundle per command stream to cut down +// on args and improve usability +void generate_multi_command_stream_kernel_rt_args( + Program& program, + KernelHandle kernel_id, + std::vector const& cb_ids, + std::vector const& tensors, + Device* device, + uint32_t page_size, // TODO: get from tensors + CoreRangeSet const& worker_core_range, + uint32_t num_pages_per_edm_buffer, // TODO: get from fabric + std::vector> const& command_tensor_slices, + ttnn::ccl::cmd::CclCommandCode command_type, // TODAY REQURED TO BE SAME - FUTURE - wrapped with above + std::optional const& forward_fabric_connections, + std::optional const& backward_fabric_connections, + std::optional> const& edm_termination_infos, + std::vector const& dest_args); +KernelHandle generate_multi_command_stream_kernel_ct_args( + Program& program, + std::vector const& cb_indices, + std::vector const& tensors, + CoreRangeSet const& worker_core_range, + DataMovementConfig datamovement_kernel_config, + const size_t num_command_streams = 2, + std::optional my_chip_id = std::nullopt); + +// Maybe not the right place for this - re-evaluate +// Generates the kernel that allows async-tensor-mode CCLs to run in synchronous mode such that +// they will wait for all outstanding writes to complete before completing the CCL on any given chip +// to avoid races because, generally speaking, async mode for CCLs requires the consumer ops to support +// async tensors. +// + +// Async tensor mode doesn't require that the producer of a tensor wait for the tensor to be fully populated +// before terminating; instead that responsibility is left to the consumer. This can be advantageous because it +// a) Allows dispatch overheads to be partly or fully hidden +// b) Allows producer and consumer ops to more natively overlap execution +void build_sync_kernels( + Device* device, + tt::tt_metal::Program& program, + ccl::SyncModeSpec const& sync_details, + bool terminate_fabric, + ccl::EdmLineFabricOpInterface& fabric_interface); +ttnn::ccl::cmd::CclHostLowLevelCommandSequence build_ccl_cmd_proc_teardown_commands( + tt::tt_metal::Program& program, + Device* device, + Device* forward_device, + size_t line_size, + size_t line_index, + std::vector const& edm_termination_infos, + ccl::SyncModeSpec const& sync_details, + ccl::EdmLineFabricOpInterface& fabric_interface); + +struct CCLWorkerArgBuilder { + CCLWorkerArgBuilder( + tt::tt_metal::Device const* device, + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::TensorPartition const& input_tensor_partition, + ttnn::ccl::TensorPartition const& output_tensor_partition, + std::size_t operating_dim); + + std::vector generate_sender_reader_kernel_rt_args( + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index) const; + + std::vector generate_sender_writer_kernel_rt_args( + std::optional const& forward_fabric_connection, + const size_t sender_worker_forward_flow_control_semaphore_id, + const size_t sender_worker_forward_buffer_index_semaphore_id, + std::optional const& backward_fabric_connection, + const size_t sender_worker_backward_flow_control_semaphore_id, + const size_t sender_worker_backward_buffer_index_semaphore_id, + const size_t forward_direction_distance_to_end_of_line, + const size_t backward_direction_distance_to_end_of_line, + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index, + std::optional sync_details) const; + + std::vector generate_sender_reader_kernel_ct_args() const; + + std::vector generate_sender_writer_kernel_ct_args() const; + + tt::tt_metal::Device const* device; + ttnn::ccl::TensorPartition const input_tensor_partition; + ttnn::ccl::TensorPartition const output_tensor_partition; + ttnn::ccl::CCLOpConfig const op_config; + std::size_t operating_dim; + bool src_is_dram; + bool dst_is_dram; +}; + +} // namespace worker_detail +} // namespace ttnn::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp new file mode 100644 index 00000000000..80d165687fc --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp @@ -0,0 +1,203 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "debug/dprint.h" +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp" + +#include + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr TensorMemoryLayout tensor_layout = static_cast(get_compile_time_arg_val(0)); +constexpr BufferType buffer_type = static_cast(get_compile_time_arg_val(1)); +constexpr Layout page_layout = static_cast(get_compile_time_arg_val(2)); +constexpr uint32_t cb_id = get_compile_time_arg_val(3); + +#ifdef SHARDED_MEM_LAYOUT +static constexpr bool is_sharded_mode = true; +static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(4); +static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(5); +static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(6); +static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(7); +static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(8); +static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(9); +static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(10) != 0; +#else +static constexpr bool is_sharded_mode = false; +static constexpr uint32_t input_tensor_shard_grid_height = 0; +static constexpr uint32_t input_tensor_shard_grid_width = 0; +static constexpr uint32_t input_tensor_shard_grid_start_y_logical = 0; +static constexpr uint32_t input_tensor_shard_grid_start_x_logical = 0; +static constexpr uint32_t input_tensor_shard_pages_per_shard_y = 0; +static constexpr uint32_t input_tensor_shard_pages_per_shard_x = 0; +static constexpr bool input_tensor_shard_grid_transposed = false; +#endif + +template < + tt::tt_metal::TensorMemoryLayout tensor_layout, + tt::tt_metal::BufferType buffer_type, + tt::tt_metal::Layout page_layout> +auto build_source_address_generator( + std::size_t& arg_idx, address_t tensor_address, std::size_t page_size, uint32_t cb_id_in0) -> + typename source_tensor_addrgen::type { + constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout); + constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + constexpr bool is_tile_page_layout = page_layout == tt::tt_metal::Layout::TILE; + constexpr bool is_row_major_layout = page_layout == tt::tt_metal::Layout::ROW_MAJOR; + static_assert( + is_sharded || is_interleaved, + "Only sharded and interleaved tensor layouts are supported but the unified address generator. A tensor layout " + "not matching TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::HEIGHT_SHARDED, " + "TensorMemoryLayout::BLOCK_SHARDED, or TensorMemoryLayout::INTERLEAVED was specified."); + + using addrgen_type = typename source_tensor_addrgen::type; + + if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + if constexpr (is_row_major_layout) { + return addrgen_type{.bank_base_address = tensor_address, .page_size = page_size}; + } else { + return addrgen_type{ + .bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in0)}; + } + } else if constexpr ( + tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { + size_t input_shard_grid_nrows = get_arg_val(arg_idx++); + const auto* const input_shard_grid_row_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_nrows; + size_t input_shard_grid_ncols = get_arg_val(arg_idx++); + const auto* const input_shard_grid_col_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_ncols; + + return tt::tt_metal::address_generators::build_sharded_addr_gen( + tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup( + input_shard_grid_nrows, input_shard_grid_row_map, input_shard_grid_ncols, input_shard_grid_col_map), + typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter::type( + input_tensor_shard_pages_per_shard_y, + input_tensor_shard_pages_per_shard_x, + input_tensor_shard_grid_height, + input_tensor_shard_grid_width, + input_tensor_shard_grid_start_y_logical, + input_tensor_shard_grid_start_x_logical, + input_tensor_shard_grid_transposed), + page_size, + tensor_address); + } else { + ASSERT(false); + } +} + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + std::size_t arg_idx = 0; + + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + // Load the input tensor spec + address_t tensor_address = get_arg_val(arg_idx++); + address_t num_commands = get_arg_val(arg_idx++); + + // Assuming whole page transmissions (which is the only mode we support at the moment) + // -> however, wanted to call it out here to make it clear that we need to pull this + // out when we start enabling other modes + const uint32_t packet_size_in_pages = get_arg_val(arg_idx++); + const uint32_t payload_page_size = get_arg_val(arg_idx++); + auto tensor_addrgen = build_source_address_generator( + arg_idx, tensor_address, payload_page_size, tt::CB::c_in0); + + ttnn::ccl::cmd::CclCommandTensor command_tensor; + + // Don't use CBs because there appears to be a bug if we have the same producer/consumer core to a given CB + // Instead, open up the CB and use it as a raw scratch space6 + +#ifdef DEBUG_PRINT_ENABLED + DPRINT << "ccl_send_reader has " << (uint32_t)num_commands << " commands" << ENDL(); +#endif + + for (std::size_t i = 0; i < num_commands; ++i) { + // Generalized would be to get the command header info and then dispatch accordingly - if the command type is + // singular + // + std::size_t old_arg_idx = arg_idx; + ttnn::ccl::cmd::update_command_tensor(arg_idx, command_tensor); + std::size_t new_arg_idx = arg_idx; + + { + print_tensor_command(i, command_tensor); + ASSERT(command_tensor.worker_pages_per_slice > 0); + + // CURRENTLY ONLY SUPPORTS WRAPPED TENSOR ITERATION COMMANDS + // Implemented really inefficiently for now - in the future we can do more efficient packing and also change + // the tensor read API to require the information in a more efficient way (less intermediate calculations) + shape_t valid_worker_slice_shape = + build_wrapped_row_tensor_slice(command_tensor.worker_pages_per_slice); // Parametrizable by ct arg + + shape_t const& global_offset = + command_tensor.tensor_slice_offset + command_tensor.worker_start_offset_in_slice; + + uint32_t curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset); + + uint32_t offset_into_worker_slice = 0; + bool last_page_of_worker = false; + for (uint32_t p = 0; p < command_tensor.worker_pages_per_slice; p += packet_size_in_pages) { + cb_reserve_back(cb_id, packet_size_in_pages); + const uint32_t local_l1_scratch_buffer_address = + get_write_ptr(cb_id) + sizeof(tt::fabric::PacketHeader); + + uint32_t n_pages = std::min(packet_size_in_pages, command_tensor.worker_pages_per_slice - p); + ASSERT(command_tensor.worker_start_offset_in_slice.w == 0); + ASSERT(command_tensor.worker_start_offset_in_slice.z == 0); + ASSERT(valid_worker_slice_shape.w == 1); + ASSERT(valid_worker_slice_shape.z == 1); + ASSERT(command_tensor.tensor_shape.w == 1); + ASSERT(command_tensor.tensor_shape.z == 1); + ASSERT(command_tensor.tensor_slice_shape.w == 1); + ASSERT(command_tensor.tensor_slice_shape.z == 1); + + DPRINT << "iter " << p << " curr_tile_id: " << curr_tile_id << ENDL(); + DPRINT << "local_l1_scratch_buffer_address: " << local_l1_scratch_buffer_address << ENDL(); + + read_wrapped_chunk_from_output_tensor_to_address( + curr_tile_id, + offset_into_worker_slice, + ttnn::ccl::coord_t( + command_tensor.worker_start_offset_in_slice.x, + command_tensor.worker_start_offset_in_slice.y), // Offset into tensor slice + ttnn::ccl::coord_t(valid_worker_slice_shape.x, valid_worker_slice_shape.y), + // In tiles for tile layout + ttnn::ccl::coord_t(command_tensor.tensor_shape.x, command_tensor.tensor_shape.y), + ttnn::ccl::coord_t(command_tensor.tensor_slice_shape.x, command_tensor.tensor_slice_shape.y), + local_l1_scratch_buffer_address, + tensor_addrgen, + n_pages, + payload_page_size, + last_page_of_worker); + + cb_push_back(cb_id, packet_size_in_pages); + } + } + } + //////////////////////////////////////////////////////////////////////////////////// + + DPRINT << "ccl_send_reader done\n"; +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp new file mode 100644 index 00000000000..ca6e26a33e0 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp @@ -0,0 +1,1012 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args +// that don't require macros to function + +#include "dataflow_api.h" +#include "impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" +#include +#include + +using arg_idx_t = uint16_t; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +struct no_addrgen {}; +static constexpr size_t num_packet_headers_storable = 8; +constexpr uint16_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t reserved_packet_header_cb_id = get_compile_time_arg_val(1); +#ifdef NO_TENSOR_MODE +constexpr TensorMemoryLayout tensor0_layout = TensorMemoryLayout::INTERLEAVED; +constexpr BufferType buffer0_type = BufferType::DRAM; +constexpr Layout tensor0_page_layout = Layout::TILE; +constexpr uint32_t cb0_id = tt::CB::c_in0; +constexpr TensorMemoryLayout tensor1_layout = TensorMemoryLayout::INTERLEAVED; +constexpr BufferType buffer1_type = BufferType::DRAM; +constexpr Layout tensor1_page_layout = Layout::TILE; +constexpr uint32_t cb1_id = tt::CB::c_in1; +#else +constexpr TensorMemoryLayout tensor0_layout = static_cast(get_compile_time_arg_val(2)); +constexpr BufferType buffer0_type = static_cast(get_compile_time_arg_val(3)); +constexpr Layout tensor0_page_layout = static_cast(get_compile_time_arg_val(4)); +constexpr uint32_t cb0_id = get_compile_time_arg_val(5); +#ifndef SINGLE_TENSOR +constexpr TensorMemoryLayout tensor1_layout = static_cast(get_compile_time_arg_val(6)); +constexpr BufferType buffer1_type = static_cast(get_compile_time_arg_val(7)); +constexpr Layout tensor1_page_layout = static_cast(get_compile_time_arg_val(8)); +constexpr uint32_t cb1_id = get_compile_time_arg_val(9); +#endif +#endif + +// TODO: COMMONIZE WITH THE ONE IN `ccl_send_writer.cpp` +FORCE_INLINE std::pair get_noc_address_components(uint64_t noc_addr) { + const size_t bank_addr = noc_addr & 0xFFFFFFFF; + const size_t noc_x = (noc_addr >> NOC_ADDR_LOCAL_BITS) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + const size_t noc_y = + (noc_addr >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + return {WorkerXY(noc_x, noc_y), bank_addr}; +} + +struct sharded_addrgen_fields { + bool is_sharded = false; + uint8_t tensor_shard_grid_height = 0; + uint8_t tensor_shard_grid_width = 0; + uint8_t tensor_shard_grid_start_y_logical = 0; + uint8_t tensor_shard_grid_start_x_logical = 0; + uint32_t tensor_shard_pages_per_shard_y = 0; + uint32_t tensor_shard_pages_per_shard_x = 0; + bool tensor_shard_grid_transposed = 0; +}; + +#ifdef TENSOR0_SHARDED_MEM_LAYOUT +#ifdef SINGLE_TENSOR +// SINGLE INPUT MODE - SHARDED +constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = { + true, + get_compile_time_arg_val(6), + get_compile_time_arg_val(7), + get_compile_time_arg_val(8), + get_compile_time_arg_val(9), + get_compile_time_arg_val(10), + get_compile_time_arg_val(11), + get_compile_time_arg_val(12) != 0}; +#else +// TWO INPUT MODE +constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = { + true, + get_compile_time_arg_val(10), + get_compile_time_arg_val(11), + get_compile_time_arg_val(12), + get_compile_time_arg_val(13), + get_compile_time_arg_val(14), + get_compile_time_arg_val(15), + get_compile_time_arg_val(16) != 0}; + +#endif +static_assert(in0_sharded_addrgen_fields.tensor_shard_grid_height > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_height\" was resolved to 0 but it must not be 0."); +static_assert(in0_sharded_addrgen_fields.tensor_shard_grid_width > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_width\" was resolved to 0 but it must not be 0."); +static_assert(in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but it must not be 0."); +static_assert(in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but it must not be 0."); +#else +constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = {false, 0, 0, 0, 0, 0, 0, 0}; +#endif + +#ifndef SINGLE_TENSOR +#if defined(TENSOR1_SHARDED_MEM_LAYOUT) +#if defined(TENSOR0_SHARDED_MEM_LAYOUT) +constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = { + true, + get_compile_time_arg_val(17), + get_compile_time_arg_val(18), + get_compile_time_arg_val(19), + get_compile_time_arg_val(20), + get_compile_time_arg_val(21), + get_compile_time_arg_val(22), + get_compile_time_arg_val(23) != 0}; +#else +// Then we are only consuming ct args for second operand and we resume from operation 8 +constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = { + true, + get_compile_time_arg_val(10), + get_compile_time_arg_val(11), + get_compile_time_arg_val(12), + get_compile_time_arg_val(13), + get_compile_time_arg_val(14), + get_compile_time_arg_val(15), + get_compile_time_arg_val(16) != 0}; +#endif + +static_assert(in1_sharded_addrgen_fields.tensor_shard_grid_height > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_height\" was resolved to 0 but it must not be 0."); +static_assert(in1_sharded_addrgen_fields.tensor_shard_grid_width > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_width\" was resolved to 0 but it must not be 0."); +static_assert(in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but it must not be 0."); +static_assert(in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but it must not be 0."); +#else +constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#endif + + +// NOTE: This will eventually be updated with an official API +static constexpr size_t VIRTUAL_COORDS_START_X = 16; +static constexpr size_t VIRTUAL_COORDS_START_Y = 16; +FORCE_INLINE bool is_using_noc_coords(uint16_t noc_x, uint16_t noc_y) { + return noc_x < VIRTUAL_COORDS_START_X && noc_y < VIRTUAL_COORDS_START_Y; +} + +FORCE_INLINE uint64_t safe_get_noc_addr(uint8_t dest_noc_x, uint8_t dest_noc_y, uint32_t dest_bank_addr) { + bool using_noc_coords = is_using_noc_coords(dest_noc_x, dest_noc_y); + uint8_t noc_x = dest_noc_x; + uint8_t noc_y = dest_noc_y; + if (using_noc_coords) { + noc_x = NOC_X_PHYS_COORD(dest_noc_x); + noc_y = NOC_Y_PHYS_COORD(dest_noc_y); + } + return get_noc_addr(noc_x, noc_y, dest_bank_addr); +} + + +template < + tt::tt_metal::TensorMemoryLayout tensor_layout, + tt::tt_metal::BufferType buffer_type, + tt::tt_metal::Layout page_layout> +FORCE_INLINE auto build_source_address_generator( + std::size_t& arg_idx, + address_t tensor_address, + std::size_t page_size, + sharded_addrgen_fields const& tensor_sharded_addrgen_fields, + uint32_t cb_id_in) -> typename source_tensor_addrgen::type { + constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout); + constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + constexpr bool is_tile_page_layout = page_layout == tt::tt_metal::Layout::TILE; + constexpr bool is_row_major_layout = page_layout == tt::tt_metal::Layout::ROW_MAJOR; + static_assert( + is_sharded || is_interleaved, + "Only sharded and interleaved tensor layouts are supported but the unified address generator. A tensor layout " + "not matching TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::HEIGHT_SHARDED, " + "TensorMemoryLayout::BLOCK_SHARDED, or TensorMemoryLayout::INTERLEAVED was specified."); + + using addrgen_type = typename source_tensor_addrgen::type; + + if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + if constexpr (is_row_major_layout) { + return addrgen_type{.bank_base_address = tensor_address, .page_size = page_size}; + } else { + return addrgen_type{ + .bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in)}; + } + } else if constexpr ( + tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { + // We don't use these args at the moment but we keep them here for now to avoid a rewrite in the very + // near future where we'll want to support custom shard grid. + uint8_t input_shard_grid_nrows = get_arg_val(arg_idx++); + const auto* const input_shard_grid_row_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_nrows; + uint8_t input_shard_grid_ncols = get_arg_val(arg_idx++); + const auto* const input_shard_grid_col_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_ncols; + + return tt::tt_metal::address_generators::build_sharded_addr_gen( + tt::tt_metal::address_generators::VirtualCoordWormholeWorkerToNocLookup(), + typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter::type( + tensor_sharded_addrgen_fields.tensor_shard_pages_per_shard_y, + tensor_sharded_addrgen_fields.tensor_shard_pages_per_shard_x, + tensor_sharded_addrgen_fields.tensor_shard_grid_height, + tensor_sharded_addrgen_fields.tensor_shard_grid_width, + tensor_sharded_addrgen_fields.tensor_shard_grid_start_y_logical, + tensor_sharded_addrgen_fields.tensor_shard_grid_start_x_logical, + tensor_sharded_addrgen_fields.tensor_shard_grid_transposed), + page_size, + tensor_address); + } else { + ASSERT(false); + } +} + +// TODO: rename to tensor IO command context +struct wrapped_worker_slice_read_context { + uint32_t curr_tile_id = 0; + uint32_t offset_into_worker_slice = 0; +}; +struct inline_value_context { + uint32_t value = 0; + uint32_t wrap = 0; +}; +struct remote_sem_change_context { + uint32_t value = 0; +}; +using remote_sem_wait_context = remote_sem_change_context; +using remote_atomic_inc_context = remote_sem_change_context; + +union cmd_specific_context { + wrapped_worker_slice_read_context wrapped_worker_slice_read_ctx; + + // sem wait and atomic inc + inline_value_context inline_value_ctx; + cmd_specific_context() {} +}; + +class FabricConnectionManager final { +public: + // return if there is/should be a connection - doesn't return whether or not the connection + // is actually live + bool is_logically_connected() const { return has_forward_connection() || has_backward_connection(); } + + // make the connection live + void open() { + if (has_forward_connection()) { + forward_fabric_sender.open(); + } + if (has_backward_connection()) { + backward_fabric_sender.open(); + } + } + bool has_forward_connection() const { return connection_flags & FORWARD_CONNECTION_FLAG_MASK; } + bool has_backward_connection() const { return connection_flags & BACKWARD_CONNECTION_FLAG_MASK; } + void close() { + if (has_forward_connection()) { + forward_fabric_sender.close(); + } + if (has_backward_connection()) { + backward_fabric_sender.close(); + } + } + + static FabricConnectionManager build_from_args(std::size_t& arg_idx) { + FabricConnectionManager connection_manager; + connection_manager.connection_flags = static_cast(get_arg_val(arg_idx++) != 0) + << FORWARD_CONNECTION_FLAG_OFFSET; + if (connection_manager.has_forward_connection()) { + connection_manager.forward_fabric_sender = + tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); + } + connection_manager.connection_flags |= static_cast(get_arg_val(arg_idx++) != 0) + << BACKWARD_CONNECTION_FLAG_OFFSET; + if (connection_manager.has_backward_connection()) { + connection_manager.backward_fabric_sender = + tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx); + } + return connection_manager; + } + + tt::fabric::WorkerToFabricEdmSender& get_forward_connection() { + ASSERT(has_forward_connection()); + return forward_fabric_sender; + } + tt::fabric::WorkerToFabricEdmSender& get_backward_connection() { + ASSERT(has_backward_connection()); + return backward_fabric_sender; + } + +private: + static constexpr uint8_t FORWARD_CONNECTION_FLAG_MASK = 0x01; + static constexpr uint8_t BACKWARD_CONNECTION_FLAG_MASK = 0x02; + static constexpr uint8_t FORWARD_CONNECTION_FLAG_OFFSET = 0x0; + static constexpr uint8_t BACKWARD_CONNECTION_FLAG_OFFSET = 0x1; + tt::fabric::WorkerToFabricEdmSender forward_fabric_sender; + tt::fabric::WorkerToFabricEdmSender backward_fabric_sender; + uint8_t connection_flags; +}; + +struct address_info_t { + size_t address = 0; +}; + +struct core_descriptor_info_t { + union { + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY noc_unicast; + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast noc_multicast; + } core_desc_args; +}; +template +struct command_context_t; + +template +void update_ccl_command( + arg_idx_t& arg_idx, command_context_t& cmd_ctx, ttnn::ccl::cmd::CclCommandHeader const& cmd_header); + +template +struct command_context_t final { + FORCE_INLINE command_context_t( + FabricConnectionManager& fabric_connection, + Addrgen& addrgen, + uint16_t num_commands, + arg_idx_t start_arg_idx, + uint8_t cb_id, + uint16_t page_size, + uint8_t packet_size_in_pages, + size_t packet_header_buffer_addr, + uint8_t stream_id) : + fabric_connection(fabric_connection), + tensor_addrgen(addrgen), + cmd_specific_ctx(), + packet_header_buffer_addr(packet_header_buffer_addr), + num_commands(num_commands), + arg_idx(start_arg_idx), + command_idx(0), + page_size(page_size), + cb_id(cb_id), + packet_size_in_pages(packet_size_in_pages), + stream_id(stream_id) { + ASSERT(num_commands == 0 || arg_idx > 4); + } + FabricConnectionManager& fabric_connection; + ttnn::ccl::cmd::CclCommandTensor command_tensor; + ttnn::ccl::cmd::CclCommandHeader current_cmd_header; + // TODO: optimize packing + address_info_t src_addr_info; + address_info_t dest_addr_info; + core_descriptor_info_t core_desc_info; + Addrgen& tensor_addrgen; + cmd_specific_context cmd_specific_ctx; + size_t packet_header_buffer_addr = 0; + + uint16_t num_commands = 0; + arg_idx_t arg_idx = 0; + uint16_t command_idx = 0; + + uint16_t page_size = 0; + ttnn::ccl::cmd::CclCommandAddrType src_addr_type = ttnn::ccl::cmd::CclCommandAddrType::NONE; + ttnn::ccl::cmd::CclCommandAddrType dest_addr_type = ttnn::ccl::cmd::CclCommandAddrType::NONE; + ttnn::ccl::cmd::CclCommandCoreDescriptorType core_desc_type = ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN; + uint8_t cb_id = 0; + uint8_t packet_size_in_pages = 0; + uint8_t stream_id; + + bool populated = false; + + bool command_requires_fabric() const { + return current_cmd_header.dest_type != ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY; + } + + FORCE_INLINE bool is_complete() const { return command_idx >= num_commands; } + + FORCE_INLINE void complete_current_command() { + command_idx++; + populated = false; + } + + FORCE_INLINE bool current_command_active() const { return populated; } + + FORCE_INLINE void fetch_next_command() { + populated = true; + + this->current_cmd_header = ttnn::ccl::cmd::CclCommandHeader::from_uint32(get_arg_val(arg_idx++)); +#ifdef DEBUG_PRINT_ENABLED + DPRINT << "CMD (code=" << (uint32_t)current_cmd_header.code + << ", args=" << (uint32_t)current_cmd_header.arg_count << ", idx=" << (uint32_t)(arg_idx - 1) << "\n"; +#endif + update_ccl_command(arg_idx, *this, current_cmd_header); + switch (current_cmd_header.code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM: { +#ifndef NO_TENSOR_MODE + shape_t const worker_start_offset_global = v2::worker_wrapped_offset_to_coord( + command_tensor.tensor_slice_shape, command_tensor.worker_start_offset_in_slice); + shape_t const global_offset = command_tensor.tensor_slice_offset + worker_start_offset_global; + + size_t const curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset); + cmd_specific_ctx.wrapped_worker_slice_read_ctx = wrapped_worker_slice_read_context{curr_tile_id}; +#endif + } break; + case ttnn::ccl::cmd::CclCommandCode::WAIT_VALUE: + case ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC: + case ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: break; + default: ASSERT(false); + } + } +}; + +template +void update_ccl_command( + arg_idx_t& arg_idx, command_context_t& cmd_ctx, ttnn::ccl::cmd::CclCommandHeader const& cmd_header) { + using namespace ttnn::ccl::cmd; + + arg_idx_t arg_idx_old = arg_idx; + for (arg_idx_t i = 0; i < cmd_header.arg_count; i++) { + // Note that we choose to reinterpret our pointers as volatile so that in the future we can add streaming + // of additional commands from some backing memory (e.g. dram or L1), potentially by another core, without + // having to track down this code and add volatile casts later (which would be a potentially tricky bug to + // root cause). + const CclCommandArgHeader command_arg_header = + CclCommandArgHeader::from_uint32(get_arg_val(arg_idx++)); + const CclCommandArgCode command_arg_code = command_arg_header.code; + auto& cmd_tensor = cmd_ctx.command_tensor; + switch (command_arg_code) { + case CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES: + CclCommandArg::unpack( + reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_shape); + arg_idx += CclCommandArg::size_in_words(); + break; + case CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES: + CclCommandArg::unpack( + reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_slice_shape); + arg_idx += CclCommandArg::size_in_words(); + break; + case CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES: + CclCommandArg::unpack( + reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_slice_offset); + arg_idx += CclCommandArg::size_in_words(); + break; + case CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES: + CclCommandArg::unpack( + reinterpret_cast(get_arg_addr(arg_idx)), + cmd_tensor.worker_start_offset_in_slice); + arg_idx += CclCommandArg::size_in_words(); + break; + case CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE: + CclCommandArg::unpack( + reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.worker_pages_per_slice); + arg_idx += CclCommandArg::size_in_words(); + break; + case CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES: + CclCommandArg::unpack( + reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor); + arg_idx += CclCommandArg::size_in_words(); + break; + + case CclCommandArgCode::SET_TARGET_VALUE: + case CclCommandArgCode::SET_ATOMIC_INC_VALUE: { + bool val_inline = static_cast(command_arg_header.inline_value0); + ASSERT(val_inline); + cmd_ctx.cmd_specific_ctx.inline_value_ctx = inline_value_context{}; + cmd_ctx.cmd_specific_ctx.inline_value_ctx.value = command_arg_header.inline_value1; + } break; + + case CclCommandArgCode::SET_ADDRESS_INFO: { + const auto src_dest_type = static_cast(command_arg_header.inline_value0); + const auto addr_type = + static_cast(command_arg_header.inline_value1); + auto& addr_info = src_dest_type == SRC_DEST_TYPE::SRC ? cmd_ctx.src_addr_info : cmd_ctx.dest_addr_info; + auto& cmd_ctx_addr_type = + src_dest_type == SRC_DEST_TYPE::SRC ? cmd_ctx.src_addr_type : cmd_ctx.dest_addr_type; + cmd_ctx_addr_type = addr_type; + switch (addr_type) { + case ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID: + cmd_ctx.cb_id = command_arg_header.inline_value2; + break; + case ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS: + case ttnn::ccl::cmd::CclCommandAddrType::RELATIVE_ADDRESS: + addr_info.address = get_arg_val(arg_idx++); + break; + case ttnn::ccl::cmd::CclCommandAddrType::SEMAPHORE_ID: + addr_info.address = get_semaphore(command_arg_header.inline_value2); + break; + case ttnn::ccl::cmd::CclCommandAddrType::NONE: break; + default: ASSERT(false); break; + }; + } break; + + case CclCommandArgCode::SET_CORE_DESCRIPTOR_INFO: { + cmd_ctx.core_desc_type = + static_cast(command_arg_header.inline_value0); + switch (cmd_ctx.core_desc_type) { + case ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN: + case ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL: break; + case ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY: + cmd_ctx.core_desc_info.core_desc_args.noc_unicast = + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY{ + command_arg_header.inline_value1, command_arg_header.inline_value2}; + break; + case ttnn::ccl::cmd::CclCommandCoreDescriptorType::RECTANGLE: + cmd_ctx.core_desc_info.core_desc_args.noc_multicast = + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast::from_uint32( + get_arg_val(arg_idx++)); + break; + default: ASSERT(false); + }; + + } break; + + default: { + ASSERT(false); + } + }; + } +} + +template +FORCE_INLINE void try_advance_inline_write_or_atomic_inc(command_context_t& cmd_ctx) { + const size_t value = cmd_ctx.cmd_specific_ctx.inline_value_ctx.value; + const size_t dest_bank_addr = cmd_ctx.dest_addr_info.address; + bool is_remote_atomic_inc_over_fabric = cmd_ctx.command_requires_fabric(); + + // noc mcast atomic inc not supported yet + ASSERT( + cmd_ctx.core_desc_type == ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY || + cmd_ctx.core_desc_type == ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL); + const uint8_t dest_noc0_x = cmd_ctx.core_desc_type == ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL + ? my_x[0] + : cmd_ctx.core_desc_info.core_desc_args.noc_unicast.x; + const uint8_t dest_noc0_y = cmd_ctx.core_desc_type == ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL + ? my_y[0] + : cmd_ctx.core_desc_info.core_desc_args.noc_unicast.y; + + if (is_remote_atomic_inc_over_fabric) { + ASSERT(cmd_ctx.core_desc_type == ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY); + // For now, we won't skip if we are waiting for space from fabric + // since we assume the other command stream will need to wait anyways + bool can_write_into_fabric = true; + if (!can_write_into_fabric) { + return; + } + + ASSERT(cmd_ctx.packet_header_buffer_addr != 0); + auto* pkt_hdr = reinterpret_cast(cmd_ctx.packet_header_buffer_addr); + if (cmd_ctx.current_cmd_header.code == ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC) { + pkt_hdr->to_atomic_inc(); + } else { + pkt_hdr->to_write(); + } + #ifdef DEBUG_PRINT_ENABLED + pkt_hdr->reserved2 = my_chip_id; + #endif + pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ + dest_bank_addr, + static_cast(value), + 32, + static_cast(dest_noc0_x), + static_cast(dest_noc0_y)}); + + switch (cmd_ctx.current_cmd_header.dest_type) { + case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: { + pkt_hdr->to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{ + cmd_ctx.current_cmd_header.get_unicast_dest_args().distance_in_hops}); + + auto& fabric_connection = cmd_ctx.current_cmd_header.get_unicast_dest_args().is_forward_direction + ? cmd_ctx.fabric_connection.get_forward_connection() + : cmd_ctx.fabric_connection.get_backward_connection(); + fabric_connection.wait_for_empty_write_slot(); + fabric_connection.send_payload_flush_blocking_from_address( + cmd_ctx.packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + } break; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: { + const auto& mcast_args = cmd_ctx.current_cmd_header.get_multicast_dest_args(); + if (cmd_ctx.fabric_connection.has_forward_connection()) { + cmd_ctx.fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ + 1, static_cast(mcast_args.num_targets_forward_direction)}); + cmd_ctx.fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + cmd_ctx.packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + } + + // Write the mcast packet (backward) + if (cmd_ctx.fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ + 1, static_cast(mcast_args.num_targets_backward_direction)}); + cmd_ctx.fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + cmd_ctx.fabric_connection.get_backward_connection().send_payload_non_blocking_from_address( + cmd_ctx.packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + } + + uint64_t dest_noc_addr = safe_get_noc_addr(dest_noc0_x, dest_noc0_y, dest_bank_addr); + if (cmd_ctx.current_cmd_header.code == ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC) { + noc_semaphore_inc(dest_noc_addr, value); + } else if (cmd_ctx.current_cmd_header.code == ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES) { + noc_inline_dw_write(dest_noc_addr, value); + } + + } break; + + default: { + ASSERT(false); + } break; + }; + + } else { + const uint64_t dest_noc_addr = get_noc_addr(dest_noc0_x, dest_noc0_y, dest_bank_addr); + if (cmd_ctx.current_cmd_header.code == ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC) { + noc_semaphore_inc(dest_noc_addr, value); + } else if (cmd_ctx.current_cmd_header.code == ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES) { + noc_inline_dw_write(dest_noc_addr, value); + } + } +} + +#ifndef NO_TENSOR_MODE +template +FORCE_INLINE void try_advance_read_tensor_to_cb(command_context_t& cmd_ctx) { + if (!cb_pages_reservable_at_back(cmd_ctx.cb_id, cmd_ctx.packet_size_in_pages)) { + return; + } + + DPRINT << "tensor -> CB: " << (uint32_t)cmd_ctx.cb_id << "\n"; + + wrapped_worker_slice_read_context& cmd_specific_ctx = cmd_ctx.cmd_specific_ctx.wrapped_worker_slice_read_ctx; + const uint16_t max_pages_readable = std::min( + cmd_ctx.packet_size_in_pages, + cmd_ctx.command_tensor.worker_pages_per_slice - cmd_specific_ctx.offset_into_worker_slice); + + uint16_t contig_pages_advanced = 1; + cb_reserve_back(cmd_ctx.cb_id, cmd_ctx.packet_size_in_pages); + const uint32_t l1_write_addr_base = get_write_ptr(cmd_ctx.cb_id); + uint32_t l1_write_addr = l1_write_addr_base; + + for (uint16_t i = 0; i < max_pages_readable; i += contig_pages_advanced) { + DPRINT << "t_id: " << (uint32_t)cmd_specific_ctx.curr_tile_id << "\n"; + auto const [noc_addr, contig_pages_] = get_noc_addr_and_contiguous_pages( + cmd_specific_ctx.curr_tile_id, + cmd_specific_ctx.offset_into_worker_slice, + cmd_ctx.command_tensor.worker_start_offset_in_slice, + cmd_ctx.tensor_addrgen, + cmd_ctx.command_tensor.tensor_slice_shape); + + { + contig_pages_advanced = std::min(max_pages_readable, contig_pages_); + contig_pages_advanced = std::min(cmd_ctx.packet_size_in_pages - i, contig_pages_); + ASSERT(contig_pages_advanced > 0); + ASSERT(contig_pages_advanced <= cmd_ctx.packet_size_in_pages); + noc_async_read(noc_addr, l1_write_addr, cmd_ctx.page_size * contig_pages_advanced); + } + l1_write_addr += cmd_ctx.page_size * contig_pages_advanced; + + + bool done_worker_slice = ttnn::ccl::v2::advance_worker_global_page( + cmd_specific_ctx.curr_tile_id, // Updated internally + cmd_specific_ctx.offset_into_worker_slice, + cmd_ctx.command_tensor.worker_start_offset_in_slice, + cmd_ctx.command_tensor.worker_pages_per_slice, + cmd_ctx.command_tensor.tensor_slice_shape, + cmd_ctx.command_tensor.tensor_slice_offset, + cmd_ctx.command_tensor.tensor_shape, + contig_pages_advanced); + } + + noc_async_read_barrier(); + + cb_push_back(cmd_ctx.cb_id, cmd_ctx.packet_size_in_pages); +} +#endif + +template +FORCE_INLINE void write_and_advance_local_read_address_for_fabric_write( + uint64_t noc0_dest_noc_addr, + command_context_t& cmd_ctx, + wrapped_worker_slice_read_context& cmd_specific_ctx, + size_t& l1_read_addr, + uint16_t contig_pages_advanced) { + // All fabric writes have noc0 coordinates specified in the header. Therefore, we need to regenerate the noc + // address noc_index coordinates + const size_t payload_size_bytes = static_cast(contig_pages_advanced) * cmd_ctx.page_size; + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + const size_t payload_l1_address = l1_read_addr; + auto pkt_hdr = reinterpret_cast(cmd_ctx.packet_header_buffer_addr); + #ifdef DEBUG_PRINT_ENABLED + pkt_hdr->reserved2 = my_chip_id; + #endif + size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); + pkt_hdr->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); + switch (cmd_ctx.current_cmd_header.dest_type) { + case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: { + const auto& unicast_args = cmd_ctx.current_cmd_header.get_unicast_dest_args(); + auto& fabric_connection = unicast_args.is_forward_direction + ? cmd_ctx.fabric_connection.get_forward_connection() + : cmd_ctx.fabric_connection.get_backward_connection(); + + pkt_hdr->to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{unicast_args.distance_in_hops}); + fabric_connection.wait_for_empty_write_slot(); + fabric_connection.send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); + fabric_connection.send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } break; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: { + noc_async_write(payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + const auto& mcast_args = cmd_ctx.current_cmd_header.get_multicast_dest_args(); + if (cmd_ctx.fabric_connection.has_forward_connection()) { + pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ + 1, static_cast(mcast_args.num_targets_forward_direction)}); + cmd_ctx.fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + cmd_ctx.fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); + cmd_ctx.fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } + + // Write the mcast packet (backward) + if (cmd_ctx.fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{ + 1, static_cast(mcast_args.num_targets_backward_direction)}); + cmd_ctx.fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + cmd_ctx.fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address(l1_read_addr, payload_size_bytes); + cmd_ctx.fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } + } break; + default: { + ASSERT(false); + } break; + } + + // Don't advance (payload + header) because we want to make sure we keep sizeof(tt::fabric::PacketHeader) space + // that's safe to use, preceeding the next hypothetical packet in L1. + l1_read_addr += payload_size_bytes; +} + +template +FORCE_INLINE void write_payload_then_advance_read_address( + uint64_t noc0_dest_noc_addr, + command_context_t& cmd_ctx, + wrapped_worker_slice_read_context& cmd_specific_ctx, + size_t& l1_read_addr, + uint16_t contig_pages_advanced) { + static_assert( + ((sizeof(tt::fabric::PacketHeader) - 1) & sizeof(tt::fabric::PacketHeader)) == 0, + "sizeof(sizeof(tt::fabric::PacketHeader)) is not a power of two which violates the below assertion"); + switch (cmd_ctx.current_cmd_header.dest_type) { + case ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST: [[fallthrough]]; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST: + write_and_advance_local_read_address_for_fabric_write( + noc0_dest_noc_addr, cmd_ctx, cmd_specific_ctx, l1_read_addr, contig_pages_advanced); + break; + case ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY: { + auto const [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + // Conver to our local noc_index based address + noc_async_write(l1_read_addr, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), cmd_ctx.page_size * contig_pages_advanced); + l1_read_addr += cmd_ctx.page_size * contig_pages_advanced; + } break; + } +} + +#ifndef NO_TENSOR_MODE +// Both this and try_advance_read_tensor_to_cb are very similar - particularly with the main +// packetization loop. We should look into refactoring this to reduce code size (and duplication) +// even if it results in a mild perf hit, it's probably worth while until generality is achieved. +// At the very least, we can have templatization of the main function that can specialize +// based on command type so we can avoid the perf overhead of the branching that would otherwise +// be required. +template +FORCE_INLINE void try_advance_write_tensor_from_cb(command_context_t& cmd_ctx) { + if (!cb_pages_available_at_front(cmd_ctx.cb_id, cmd_ctx.packet_size_in_pages)) { + return; + } + DPRINT << "CB -> tensor: " << (uint32_t)cmd_ctx.stream_id << "\n"; + + wrapped_worker_slice_read_context& cmd_specific_ctx = cmd_ctx.cmd_specific_ctx.wrapped_worker_slice_read_ctx; + const uint16_t max_pages_writable = std::min( + cmd_ctx.packet_size_in_pages, + cmd_ctx.command_tensor.worker_pages_per_slice - cmd_specific_ctx.offset_into_worker_slice); + ASSERT(cmd_ctx.command_tensor.worker_pages_per_slice >= cmd_specific_ctx.offset_into_worker_slice); + + cb_wait_front(cmd_ctx.cb_id, cmd_ctx.packet_size_in_pages); + size_t l1_read_addr = get_read_ptr(cmd_ctx.cb_id); + + uint16_t contig_pages_advanced = 1; + for (uint16_t i = 0; i < max_pages_writable; i += contig_pages_advanced) { + // This needs to be cleaned up a little bit. + // There's a basic usability issue here in that when/if the write is sent over the fabric, + // then the fabric expects noc x/y coordinates to be provided as noc0 coordinates. + // However, if we're writing locally, then we need to actually write using `noc_index` based coordinates. + // This can lead to a discrepancy, so to stay consistent, we always generate noc0 based addresses here + // so we can reliably translate to `noc_index` based addresses writing locally, inside the write function + DPRINT << "t_id: " << (uint32_t)cmd_specific_ctx.curr_tile_id << "\n"; + auto const [noc0_dest_noc_addr, contig_pages_] = + get_noc_addr_and_contiguous_pages_for_fabric_write( + cmd_specific_ctx.curr_tile_id, + cmd_specific_ctx.offset_into_worker_slice, + cmd_ctx.command_tensor.worker_start_offset_in_slice, + cmd_ctx.tensor_addrgen, + cmd_ctx.command_tensor.tensor_slice_shape); + contig_pages_advanced = std::min(contig_pages_, max_pages_writable); + contig_pages_advanced = std::min(cmd_ctx.packet_size_in_pages - i, contig_pages_); + + write_payload_then_advance_read_address( + noc0_dest_noc_addr, cmd_ctx, cmd_specific_ctx, l1_read_addr, contig_pages_advanced); + + auto done_worker_slice = ttnn::ccl::v2::advance_worker_global_page( + cmd_specific_ctx.curr_tile_id, // Updated internally + cmd_specific_ctx.offset_into_worker_slice, + cmd_ctx.command_tensor.worker_start_offset_in_slice, + cmd_ctx.command_tensor.worker_pages_per_slice, + cmd_ctx.command_tensor.tensor_slice_shape, + cmd_ctx.command_tensor.tensor_slice_offset, + cmd_ctx.command_tensor.tensor_shape, + contig_pages_advanced); + } + noc_async_writes_flushed(); + + cb_pop_front(cmd_ctx.cb_id, cmd_ctx.packet_size_in_pages); +} +#endif + +template +FORCE_INLINE void try_advance(command_context_t& cmd_ctx) { + switch (cmd_ctx.current_cmd_header.code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM: // STREAM TENSOR TO CB +#ifndef NO_TENSOR_MODE + try_advance_read_tensor_to_cb(cmd_ctx); +#endif + break; + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: +#ifndef NO_TENSOR_MODE + try_advance_write_tensor_from_cb(cmd_ctx); +#endif + break; + + case ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC: [[fallthrough]]; + case ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: + try_advance_inline_write_or_atomic_inc(cmd_ctx); + break; + case ttnn::ccl::cmd::CclCommandCode::WAIT_VALUE: + // Nothing to actively do to advance - just needs to wait for completion + break; + default: ASSERT(false); break; + }; + + // Advance to next command index + switch (cmd_ctx.current_cmd_header.code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM: // STREAM TENSOR TO CB + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: + if (cmd_ctx.cmd_specific_ctx.wrapped_worker_slice_read_ctx.offset_into_worker_slice >= + cmd_ctx.command_tensor.worker_pages_per_slice) { + DPRINT << "t_stream cmd cmpl\n"; + cmd_ctx.complete_current_command(); + } + break; + + case ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC: [[fallthrough]]; + case ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: + DPRINT << "at_inc cmd cmpl\n"; + cmd_ctx.complete_current_command(); + break; + case ttnn::ccl::cmd::CclCommandCode::WAIT_VALUE: + // Technically we are implementating semaphore wait as WAIT_MIN. FUTURE work to make separate commands + if (*reinterpret_cast(cmd_ctx.src_addr_info.address) >= + cmd_ctx.cmd_specific_ctx.inline_value_ctx.value) { + DPRINT << "Completing waitval command\n"; + cmd_ctx.complete_current_command(); + } + break; + default: ASSERT(false); break; + }; +} + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; +#ifndef NO_TENSOR_MODE + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); +#ifndef SINGLE_TENSOR + address_t tensor_address1 = get_arg_val(arg_idx++); +#endif +#endif + uint8_t num_commands0 = get_arg_val(arg_idx++); + arg_idx_t command0_start_offset = get_arg_val(arg_idx++); + +#ifndef SINGLE_INPUT_MODE + uint8_t num_commands1 = get_arg_val(arg_idx++); + arg_idx_t command1_start_offset = get_arg_val(arg_idx++); +#endif + + // Assuming whole page transmissions (which is the only mode we support at the moment) + // -> however, wanted to call it out here to make it clear that we need to pull this + // out when we start enabling other modes + const uint16_t packet_size_in_pages = get_arg_val(arg_idx++); + uint16_t tensor0_page_size = +#ifndef NO_TENSOR_MODE + get_arg_val(arg_idx++); +#else + 0; +#endif + uint16_t tensor1_page_size = +#if !defined(NO_TENSOR_MODE) and !defined(SINGLE_TENSOR) + get_arg_val(arg_idx++); +#else + 0; +#endif + + auto tensor0_addrgen = +#ifndef NO_TENSOR_MODE + build_source_address_generator( + arg_idx, tensor_address0, tensor0_page_size, in0_sharded_addrgen_fields, cb0_id); +#else + no_addrgen{}; +#endif + +#if !defined(SINGLE_INPUT_MODE) + auto tensor1_addrgen = +#if !defined(NO_TENSOR_MODE) && !defined(SINGLE_TENSOR) + build_source_address_generator( + arg_idx, tensor_address1, tensor1_page_size, in1_sharded_addrgen_fields, cb1_id); +#else + no_addrgen{}; +#endif +#endif + + // TODO: move to common + auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); + + cb_reserve_back(reserved_packet_header_cb_id, num_packet_headers_storable); + auto packet_header_buffer_addr0 = get_write_ptr(reserved_packet_header_cb_id); + auto packet_header_buffer_addr1 = + packet_header_buffer_addr0 + (num_packet_headers_storable >> 2) * sizeof(tt::fabric::PacketHeader); + + auto operand_0_cmd_ctx = command_context_t( + fabric_connection, + tensor0_addrgen, + num_commands0, + command0_start_offset, + cb0_id, + tensor0_page_size, + packet_size_in_pages, + packet_header_buffer_addr0, + 0); + + // enabling either of the writes will cause the issue + static_assert(sizeof(command_context_t) <= 120, "command_context_t is too big"); + uint8_t stream_done_mask = +#ifndef SINGLE_INPUT_MODE + (static_cast(num_commands1 == 0) << 1) | +#endif + static_cast(num_commands0 == 0); +#ifndef SINGLE_INPUT_MODE + const uint8_t finish_value = 0x3; + static_assert(sizeof(command_context_t) <= 120, "command_context_t is too big"); + auto operand_1_cmd_ctx = command_context_t( + fabric_connection, + tensor1_addrgen, + num_commands1, + command1_start_offset, + cb1_id, + tensor1_page_size, + packet_size_in_pages, + packet_header_buffer_addr1, + 1); +#else + const uint8_t finish_value = 0x1; +#endif + + if (fabric_connection.is_logically_connected()) { + fabric_connection.open(); + } + while (stream_done_mask != finish_value) { + if ((stream_done_mask & 0x1) == 0) { + if (!operand_0_cmd_ctx.current_command_active()) { + DPRINT << "get_cmd0\n"; + operand_0_cmd_ctx.fetch_next_command(); + }; + try_advance(operand_0_cmd_ctx); + } + stream_done_mask |= static_cast(operand_0_cmd_ctx.is_complete()); +#ifndef SINGLE_INPUT_MODE + if ((stream_done_mask & 0x2) == 0) { + if (!operand_1_cmd_ctx.current_command_active()) { + DPRINT << "get_cmd1\n"; + operand_1_cmd_ctx.fetch_next_command(); + } + try_advance(operand_1_cmd_ctx); + } + stream_done_mask |= (static_cast(operand_1_cmd_ctx.is_complete()) << 1); +#endif + } + + if (fabric_connection.is_logically_connected()) { + fabric_connection.close(); + } + + noc_async_write_barrier(); + DPRINT << "DONE \n"; +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp new file mode 100644 index 00000000000..833ad9396f8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp @@ -0,0 +1,330 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" + +#include "debug/dprint.h" +#include + +//------------------------------------------------------------------------------ +// Section 1: Generic Utility Functions +//------------------------------------------------------------------------------ + +template +std::pair get_noc_addr_and_contiguous_pages( + uint32_t curr_page_idx, + const uint32_t offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const AddrGen& address_generator, + const ttnn::ccl::Shape4D& tensor_slice_shape, + uint8_t noc_id = noc_index) { + if constexpr (TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + static constexpr uint32_t offset = 0; + uint64_t dst_noc_addr = get_noc_addr(curr_page_idx, address_generator, offset, noc_id); + return {dst_noc_addr, 1}; + } else { + static_assert( + TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED || + TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + TENSOR_LAYOUT == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED); + if constexpr (MEM_LAYOUT == tt::tt_metal::Layout::ROW_MAJOR) { + ASSERT(false); // unimplemented + return {0, 0}; + } else { + static_assert(MEM_LAYOUT == tt::tt_metal::Layout::TILE); + // TODO: Make d.get_noc_addr work on host + device + auto const& [noc_yx, page_offset, contig_pages_] = + address_generator.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); + /* + * Shared with `read_wrapped_chunk_from_output_tensor` + */ + uint32_t flattened_offset_worker_slice = + ttnn::ccl::v2::flattened_index(tensor_slice_shape, offset_worker_slice); + uint32_t contig_until_edge_of_tensor_slice = + tensor_slice_shape.x - + ((flattened_offset_worker_slice + offset_into_worker_slice) % tensor_slice_shape.x); + + size_t contig_pages = std::min(contig_pages_, contig_until_edge_of_tensor_slice); + uint64_t dst_noc_addr = get_noc_addr( + static_cast(noc_yx.noc_x), + noc_yx.noc_y, + address_generator.bank_base_address + (page_offset * address_generator.page_size) + 0, + noc_id); + return {dst_noc_addr, contig_pages}; + } + } +} + +template +FORCE_INLINE std::pair get_noc_addr_and_contiguous_pages_for_fabric_write( + uint32_t curr_page_idx, + const uint32_t offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const AddrGen& address_generator, + const ttnn::ccl::Shape4D& tensor_slice_shape) { + return get_noc_addr_and_contiguous_pages( + curr_page_idx, offset_into_worker_slice, offset_worker_slice, address_generator, tensor_slice_shape, 0); +} + +std::pair get_noc_address_components(uint64_t noc_addr) { + const size_t bank_addr = noc_addr & 0xFFFFFFFF; + const size_t noc_x = (noc_addr >> NOC_ADDR_LOCAL_BITS) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + const size_t noc_y = + (noc_addr >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + return {WorkerXY(noc_x, noc_y), bank_addr}; +} + +//------------------------------------------------------------------------------ +// Section 2: Multicast Write with Fabric Functions +//------------------------------------------------------------------------------ + +void mcast_contig_pages_to_noc_address( + uint64_t noc_addr, + size_t l1_read_addr, + size_t contig_pages_advanced, + size_t payload_page_size, + bool has_forward_fabric_connection, + bool has_backward_fabric_connection, + tt::fabric::WorkerToFabricEdmSender& forward_fabric_sender, + tt::fabric::WorkerToFabricEdmSender& backward_fabric_sender, + size_t forward_direction_num_hops, + size_t backward_direction_num_hops) { + const size_t payload_size_bytes = contig_pages_advanced * payload_page_size; + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc_addr); + const size_t payload_l1_address = l1_read_addr + sizeof(tt::fabric::PacketHeader); + + // Local chip write + noc_async_write( + payload_l1_address, get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr, noc_index), payload_size_bytes); + size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); + + // Forward fabric connection + if (has_forward_fabric_connection) { + static_assert( + ((sizeof(tt::fabric::PacketHeader) - 1) & sizeof(tt::fabric::PacketHeader)) == 0, + "sizeof(sizeof(tt::fabric::PacketHeader)) is not a power of two which violates the below assertion"); + + auto& pkt_hdr = *reinterpret_cast(l1_read_addr); + pkt_hdr.to_write() + .to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(forward_direction_num_hops)}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, + packet_send_size_bytes, + static_cast(dest_noc_xy.x), + static_cast(dest_noc_xy.y)}); + forward_fabric_sender.wait_for_empty_write_slot(); + forward_fabric_sender.send_payload_flush_blocking_from_address(l1_read_addr, packet_send_size_bytes); + } + + // Backward fabric connection + if (has_backward_fabric_connection) { + auto& pkt_hdr = *reinterpret_cast(l1_read_addr); + pkt_hdr.to_write() + .to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(backward_direction_num_hops)}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, + packet_send_size_bytes, + static_cast(dest_noc_xy.x), + static_cast(dest_noc_xy.y)}); + backward_fabric_sender.wait_for_empty_write_slot(); + backward_fabric_sender.send_payload_non_blocking_from_address(l1_read_addr, packet_send_size_bytes); + } +} + +template +void mcast_payload_chunk_to_output_tensor_address( + uint32_t& curr_page_idx, + uint32_t& offset_into_worker_slice, + const shape_t& worker_slice_offset, + const shape_t& worker_slice_shape, + const ttnn::ccl::cmd::CclCommandTensor& command_tensor, + size_t l1_read_addr, + size_t n_pages, + size_t payload_page_size, + size_t l1_scratch_page_size, + bool has_forward_fabric_connection, + bool has_backward_fabric_connection, + tt::fabric::WorkerToFabricEdmSender& forward_fabric_sender, + tt::fabric::WorkerToFabricEdmSender& backward_fabric_sender, + size_t forward_direction_num_hops, + size_t backward_direction_num_hops, + const AddrGen& tensor_addrgen) { + size_t contig_pages_advanced = 1; + + for (size_t i = 0; i < n_pages; i += contig_pages_advanced) { + auto const [noc_addr, contig_pages] = + get_noc_addr_and_contiguous_pages_for_fabric_write( + curr_page_idx, + offset_into_worker_slice, + worker_slice_offset, + tensor_addrgen, + command_tensor.tensor_slice_shape); + + contig_pages_advanced = std::min(contig_pages, n_pages); + + mcast_contig_pages_to_noc_address( + noc_addr, + l1_read_addr, + contig_pages_advanced, + payload_page_size, + has_forward_fabric_connection, + has_backward_fabric_connection, + forward_fabric_sender, + backward_fabric_sender, + forward_direction_num_hops, + backward_direction_num_hops); + + bool last_page_of_worker = ttnn::ccl::v2::advance_worker_global_page( + curr_page_idx, + offset_into_worker_slice, + worker_slice_offset, + worker_slice_shape.volume(), + command_tensor.tensor_slice_shape, + command_tensor.tensor_shape, + contig_pages_advanced); + + noc_async_write_barrier(); + l1_read_addr += contig_pages_advanced * l1_scratch_page_size; + } +} + +//------------------------------------------------------------------------------ +// Section 3: Local Read into L1 Scratchpad Functions +//------------------------------------------------------------------------------ + +template +FORCE_INLINE void read_wrapped_chunk_from_output_tensor_to_address( + uint32_t& curr_page_idx, + uint32_t& offset_into_worker_slice, + const ttnn::ccl::coord_t& offset_worker_slice, + const ttnn::ccl::coord_t& worker_slice_shape, + + // In tiles for tile layout + const ttnn::ccl::coord_t& tensor_shape, + const ttnn::ccl::coord_t& tensor_slice_shape, + const uint32_t local_l1_scratch_buffer_address, + const AddrGen& s, + const uint32_t num_pages, + const uint32_t page_size, + bool& last_page_of_worker) { + // we expected caller to reset this and the last curr_page_idx when we set it true + uint32_t local_l1_read_addr = local_l1_scratch_buffer_address; + + int32_t contig_pages = 1; + for (uint32_t i = 0; i < num_pages; i += contig_pages) { + contig_pages = 1; +#ifdef ROW_MAJOR_LAYOUT +#ifdef INTERLEAVED_MEM_LAYOUT + uint64_t src_noc_addr = get_noc_addr(curr_page_idx, s); + noc_async_read(src_noc_addr, local_l1_read_addr, page_size); +#elif defined SHARDED_MEM_LAYOUT + ASSERT(false); // unimplemented +#endif +#elif defined TILED_LAYOUT +#ifdef INTERLEAVED_MEM_LAYOUT + noc_async_read_tile(curr_page_idx, s, local_l1_read_addr); + // common with `write_chunk_v2` +#elif defined SHARDED_MEM_LAYOUT + // TODO: Make d.get_noc_addr work on host + device + auto const& [noc_yx, page_offset, contig_pages_] = + s.get_page_location_with_contiguous_pages_in_row_in_bank(curr_page_idx); + /* + * num_pages - i: check if we are outside the number of pages remaining + * contig_pages_: check if we are outside the max number of contig pages we can read in a row in a bank + * contig_edge_of_tensor_slice: check if we are outside the edge of the tensor slice (in which case, we wrap + * around if we aren't at the end) + */ + uint32_t flattened_offset_worker_slice = offset_worker_slice.x + (offset_worker_slice.y * tensor_slice_shape.x); + uint32_t contig_edge_of_tensor_slice = + tensor_slice_shape.x - ((flattened_offset_worker_slice + offset_into_worker_slice) % tensor_slice_shape.x); + + contig_pages = std::min(num_pages - i, std::min(contig_pages_, contig_edge_of_tensor_slice)); + uint64_t src_noc_addr = get_noc_addr( + static_cast(noc_yx.noc_x), noc_yx.noc_y, s.bank_base_address + (page_offset * s.page_size) + 0); + noc_async_read(src_noc_addr, local_l1_read_addr, page_size * contig_pages); +#endif + + // Update the curr_page_idx based on how the worker chunks + tensor slice is laid out in global tensor + advance_worker_global_page_interleaved( + curr_page_idx, // Updated internally + offset_into_worker_slice, + offset_worker_slice, + worker_slice_shape, + tensor_slice_shape, + tensor_shape, + contig_pages, + last_page_of_worker); + +#endif + local_l1_read_addr += page_size * contig_pages; + } + noc_async_read_barrier(); +} + +//------------------------------------------------------------------------------ +// Section 4: Sync Signal Write Functions +//------------------------------------------------------------------------------ + +void mcast_sync_signal_to_addr( + size_t some_buffering_addr, + size_t& sync_details_arg_idx, + bool has_forward_fabric_connection, + bool has_backward_fabric_connection, + tt::fabric::WorkerToFabricEdmSender& forward_fabric_sender, + tt::fabric::WorkerToFabricEdmSender& backward_fabric_sender, + size_t forward_direction_num_hops, + size_t backward_direction_num_hops, + size_t num_sync_signals) { + auto send_sync_signal = [](size_t pkt_addr, + tt::fabric::WorkerToFabricEdmSender& fabric_connection, + size_t remote_sem_noc_x, + size_t remote_sem_noc_y, + size_t remote_sem_l1_addr, + size_t directional_num_hops) { + static_assert( + ((sizeof(tt::fabric::PacketHeader) - 1) & sizeof(tt::fabric::PacketHeader)) == 0, + "sizeof(sizeof(tt::fabric::PacketHeader)) is not a power of two which violates the below assertion"); + ASSERT((pkt_addr & (sizeof(tt::fabric::PacketHeader) - 1)) == 0); + + auto& pkt_hdr = *reinterpret_cast(pkt_addr); + pkt_hdr.to_atomic_inc() + .to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{1, static_cast(directional_num_hops)}) + .to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ + remote_sem_l1_addr, + 1, + 32, + static_cast(remote_sem_noc_x), + static_cast(remote_sem_noc_y)}); + fabric_connection.wait_for_empty_write_slot(); + fabric_connection.send_payload_flush_blocking_from_address( + pkt_addr, pkt_hdr.get_payload_size_including_header()); + }; + + for (size_t i = 0; i < num_sync_signals; ++i) { + auto dest_sem_addr = + get_arg_val(sync_details_arg_idx++); // hack, we pass in the address instead of the semaphore id + auto dest_noc_x = get_arg_val(sync_details_arg_idx++); + auto dest_noc_y = get_arg_val(sync_details_arg_idx++); + + if (has_forward_fabric_connection) { + const size_t pkt_addr = some_buffering_addr; + send_sync_signal( + pkt_addr, forward_fabric_sender, dest_noc_x, dest_noc_y, dest_sem_addr, forward_direction_num_hops); + } + if (has_backward_fabric_connection) { + const size_t pkt_addr = some_buffering_addr; + send_sync_signal( + pkt_addr, backward_fabric_sender, dest_noc_x, dest_noc_y, dest_sem_addr, backward_direction_num_hops); + } + + auto sem_inc_noc_addr = get_noc_addr(dest_noc_x, dest_noc_y, dest_sem_addr); + noc_semaphore_inc(sem_inc_noc_addr, 1); + } +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp new file mode 100644 index 00000000000..8993147ac35 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp @@ -0,0 +1,274 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" + +#include "debug/dprint.h" +#include + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr TensorMemoryLayout tensor_layout = static_cast(get_compile_time_arg_val(0)); +constexpr BufferType buffer_type = static_cast(get_compile_time_arg_val(1)); +constexpr Layout page_layout = static_cast(get_compile_time_arg_val(2)); +constexpr uint32_t cb_id = get_compile_time_arg_val(3); + +#ifdef SHARDED_MEM_LAYOUT +static constexpr bool is_sharded_mode = true; +static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(4); +static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(5); +static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(6); +static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(7); +static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(8); +static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(9); +static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(10) != 0; +#else +static constexpr bool is_sharded_mode = false; +static constexpr uint32_t input_tensor_shard_grid_height = 0; +static constexpr uint32_t input_tensor_shard_grid_width = 0; +static constexpr uint32_t input_tensor_shard_grid_start_y_logical = 0; +static constexpr uint32_t input_tensor_shard_grid_start_x_logical = 0; +static constexpr uint32_t input_tensor_shard_pages_per_shard_y = 0; +static constexpr uint32_t input_tensor_shard_pages_per_shard_x = 0; +static constexpr bool input_tensor_shard_grid_transposed = false; +#endif + +template < + tt::tt_metal::TensorMemoryLayout tensor_layout, + tt::tt_metal::BufferType buffer_type, + tt::tt_metal::Layout page_layout> +auto build_source_address_generator( + std::size_t& arg_idx, address_t tensor_address, std::size_t page_size, uint32_t cb_id_in0) -> + typename source_tensor_addrgen::type { + constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout); + constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + constexpr bool is_tile_page_layout = page_layout == tt::tt_metal::Layout::TILE; + constexpr bool is_row_major_layout = page_layout == tt::tt_metal::Layout::ROW_MAJOR; + static_assert( + is_sharded || is_interleaved, + "Only sharded and interleaved tensor layouts are supported but the unified address generator. A tensor layout " + "not matching TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::HEIGHT_SHARDED, " + "TensorMemoryLayout::BLOCK_SHARDED, or TensorMemoryLayout::INTERLEAVED was specified."); + + using addrgen_type = typename source_tensor_addrgen::type; + + if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + if constexpr (is_row_major_layout) { + return addrgen_type{.bank_base_address = tensor_address, .page_size = page_size}; + } else { + return addrgen_type{ + .bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in0)}; + } + } else if constexpr ( + tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { + size_t input_shard_grid_nrows = get_arg_val(arg_idx++); + const auto* const input_shard_grid_row_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_nrows; + size_t input_shard_grid_ncols = get_arg_val(arg_idx++); + const auto* const input_shard_grid_col_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_ncols; + + return tt::tt_metal::address_generators::build_sharded_addr_gen( + tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup( + input_shard_grid_nrows, input_shard_grid_row_map, input_shard_grid_ncols, input_shard_grid_col_map), + typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter::type( + input_tensor_shard_pages_per_shard_y, + input_tensor_shard_pages_per_shard_x, + input_tensor_shard_grid_height, + input_tensor_shard_grid_width, + input_tensor_shard_grid_start_y_logical, + input_tensor_shard_grid_start_x_logical, + input_tensor_shard_grid_transposed), + page_size, + tensor_address); + } else { + ASSERT(false); + } +} + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + std::size_t arg_idx = 0; + + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + // Load the input tensor spec + address_t dest_address = get_arg_val(arg_idx++); + address_t num_commands = get_arg_val(arg_idx++); + + // Assuming whole page transmissions (which is the only mode we support at the moment) + // -> however, wanted to call it out here to make it clear that we need to pull this + // out when we start enabling other modes + const size_t packet_size_in_pages = get_arg_val(arg_idx++); + const size_t payload_page_size = get_arg_val(arg_idx++); + const size_t l1_scratch_page_size = payload_page_size + sizeof(tt::fabric::PacketHeader); + const size_t forward_direction_num_hops = get_arg_val(arg_idx++); + const size_t backward_direction_num_hops = get_arg_val(arg_idx++); + const bool has_forward_fabric_connection = get_arg_val(arg_idx++) != 0; + auto forward_fabric_sender = + has_forward_fabric_connection + ? tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx) + : tt::fabric::WorkerToFabricEdmSender(); + const bool has_backward_fabric_connection = get_arg_val(arg_idx++) != 0; + auto backward_fabric_sender = + has_backward_fabric_connection + ? tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx) + : tt::fabric::WorkerToFabricEdmSender(); + + constexpr size_t num_args_per_sync_signal_sender = 3; + const bool must_send_sync_signals = get_arg_val(arg_idx++) != 0; + auto num_sync_signals = must_send_sync_signals ? get_arg_val(arg_idx++) : 0; + auto sync_details_arg_idx = arg_idx; + arg_idx += num_sync_signals * num_args_per_sync_signal_sender; + + auto tensor_addrgen = build_source_address_generator( + arg_idx, dest_address, payload_page_size, tt::CB::c_in0); + + if (has_forward_fabric_connection) { + DPRINT << "Opening forward fabric connection\n"; + forward_fabric_sender.open(); + DPRINT << "Forward fabric connection opened\n"; + } + if (has_backward_fabric_connection) { + DPRINT << "Opening backward fabric connection\n"; + backward_fabric_sender.open(); + DPRINT << "Backward fabric connection opened\n"; + } + + ttnn::ccl::cmd::CclCommandTensor command_tensor; + +#ifdef DEBUG_PRINT_ENABLED + DPRINT << "ccl_send_writer has " << (uint32_t)num_commands << " commands" << ENDL(); +#endif + + size_t some_buffering_addr = 0; + + for (std::size_t i = 0; i < num_commands; ++i) { + // Generalized would be to get the command header info and then dispatch accordingly - if the command type is + // singular + // + std::size_t old_arg_idx = arg_idx; + ttnn::ccl::cmd::update_command_tensor(arg_idx, command_tensor); + std::size_t new_arg_idx = arg_idx; + + { + // print_tensor_command(i, command_tensor); + ASSERT(command_tensor.worker_pages_per_slice > 0); + + // CURRENTLY ONLY SUPPORTS WRAPPED TENSOR ITERATION COMMANDS + // Implemented really inefficiently for now - in the future we can do more efficient packing and also change + // the tensor read API to require the information in a more efficient way (less intermediate calculations) + shape_t valid_worker_slice_shape = + build_wrapped_row_tensor_slice(command_tensor.worker_pages_per_slice); // Parametrizable by ct arg + + shape_t const& global_offset = + command_tensor.tensor_slice_offset + command_tensor.worker_start_offset_in_slice; + + uint32_t curr_page_idx = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset); + + uint32_t offset_into_worker_slice = 0; + DPRINT << "Outside loop\n"; + DPRINT << "worker_pages_per_slice: " << command_tensor.worker_pages_per_slice << ENDL(); + DPRINT << "payload_page_size: " << (uint32_t)payload_page_size << ENDL(); + // DPRINT << "packet_size_in_pages: " << packet_size_in_pages << ENDL(); + for (uint32_t p = 0; p < command_tensor.worker_pages_per_slice; p += packet_size_in_pages) { + DPRINT << "Packet loop\n"; + uint32_t n_pages = std::min(packet_size_in_pages, command_tensor.worker_pages_per_slice - p); + + ASSERT(command_tensor.worker_start_offset_in_slice.w == 0); + ASSERT(command_tensor.worker_start_offset_in_slice.z == 0); + ASSERT(valid_worker_slice_shape.w == 1); + ASSERT(valid_worker_slice_shape.z == 1); + ASSERT(command_tensor.tensor_shape.w == 1); + ASSERT(command_tensor.tensor_shape.z == 1); + ASSERT(command_tensor.tensor_slice_shape.w == 1); + ASSERT(command_tensor.tensor_slice_shape.z == 1); + + DPRINT << "iter " << p << " curr_tile_id: " << curr_page_idx << ENDL(); + + DPRINT << "cb_wait_front\n"; + cb_wait_front(cb_id, n_pages); + DPRINT << "cb_wait_front done\n"; + size_t l1_read_addr = get_read_ptr(cb_id); + some_buffering_addr = l1_read_addr; + + mcast_payload_chunk_to_output_tensor_address( + curr_page_idx, + offset_into_worker_slice, + command_tensor.worker_start_offset_in_slice, // worker_slice_offset + valid_worker_slice_shape, // worker_slice_shape + command_tensor, + l1_read_addr, + n_pages, + payload_page_size, + l1_scratch_page_size, + has_forward_fabric_connection, + has_backward_fabric_connection, + forward_fabric_sender, + backward_fabric_sender, + forward_direction_num_hops, + backward_direction_num_hops, + tensor_addrgen); + + DPRINT << "cb_pop_front\n"; + cb_pop_front(cb_id, n_pages); + DPRINT << "cb_pop_front done\n"; + } + DPRINT << "Packet loop done\n"; + } + DPRINT << "Outside loop done\n"; + } + DPRINT << "ccl_send_writer done main loop - enterring teardown\n"; + + if (must_send_sync_signals) { + DPRINT << "ccl_send_writer Sending payload completion sync signals\n"; + ASSERT(some_buffering_addr != 0); + some_buffering_addr = + (some_buffering_addr + (sizeof(tt::fabric::PacketHeader))) & ~(sizeof(tt::fabric::PacketHeader) - 1); + + mcast_sync_signal_to_addr( + some_buffering_addr, + sync_details_arg_idx, + has_forward_fabric_connection, + has_backward_fabric_connection, + forward_fabric_sender, + backward_fabric_sender, + forward_direction_num_hops, + backward_direction_num_hops, + num_sync_signals); + } + + DPRINT << "ccl_send_writer closing connections\n"; + if (has_forward_fabric_connection) { + forward_fabric_sender.close(); + } + if (has_backward_fabric_connection) { + backward_fabric_sender.close(); + } + //////////////////////////////////////////////////////////////////////////////////// + DPRINT << "ccl_send_writer done\n"; +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_wait_completion.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_wait_completion.cpp new file mode 100644 index 00000000000..aabcec3eca8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_wait_completion.cpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + + +#include "dataflow_api.h" + +#include +#include +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" + +#include "debug/dprint.h" + +void kernel_main() { + constexpr size_t num_signals_to_wait_for = get_compile_time_arg_val(0); + constexpr size_t send_termination_signals = get_compile_time_arg_val(1); + std::array sem_addrs; + std::array expected_sem_counts; + std::array current_sem_counts; + + + size_t arg_idx = 0; + for (size_t i = 0; i < num_signals_to_wait_for; ++i) { + sem_addrs[i] = reinterpret_cast(get_arg_val(arg_idx++)); // hack, we pass in the address instead of the semaphore id + DPRINT << "DRAIN WAITING ON SEMAPHORE ADDR " << (uint32_t)sem_addrs[i] << " on core (" << (uint32_t)my_y[0] << ", " << (uint32_t)my_x[0] << ")\n"; + expected_sem_counts[i] = get_arg_val(arg_idx++); + current_sem_counts[i] = 0; + } + + while (true) { + for (size_t i = 0; i < num_signals_to_wait_for; ++i) { + if (current_sem_counts[i] >= expected_sem_counts[i]) { + continue; + } + + if (current_sem_counts[i] != *sem_addrs[i]) { + DPRINT << "DRAIN GOT SEMINC @ " << (uint32_t)sem_addrs[i] << ". NOW= " << (uint32_t)*sem_addrs[i] << "\n"; + current_sem_counts[i] = *sem_addrs[i]; + } + } + + bool all_done = true; + for (size_t i = 0; i < num_signals_to_wait_for; ++i) { + if (current_sem_counts[i] < expected_sem_counts[i]) { + all_done = false; + break; + } + } + if (all_done) { + break; + } + } + + DPRINT << "DONE RECEIVING SEMINCS. SHUTTING DOWN FABRIC\n"; + + if (send_termination_signals) { + size_t num_termination_signals = get_arg_val(arg_idx++); + for (size_t i = 0; i < num_termination_signals; ++i) { + uint32_t termination_addr = get_arg_val(arg_idx++); + uint32_t noc_x = get_arg_val(arg_idx++); + uint32_t noc_y = get_arg_val(arg_idx++); + DPRINT << "SENDING TERMINATION SIGNAL TO " << (uint32_t)noc_x << " " << (uint32_t)noc_y << " " << (uint32_t)termination_addr << "\n"; + noc_semaphore_inc(get_noc_addr(noc_x, noc_y, termination_addr), 1); + } + } + DPRINT << "DRAIN DONE\n"; +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp new file mode 100644 index 00000000000..63596906d20 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" + +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" + +#include "dataflow_api.h" // for interleaved addrgen +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" + +using shape_t = ttnn::ccl::Shape4D; +using ttnn::ccl::coord_t; +using address_t = uint32_t; + +using tt::tt_metal::BufferType; +using tt::tt_metal::Layout; +using tt::tt_metal::TensorMemoryLayout; +using ttnn::ccl::Shape4D; + +#ifdef DEBUG_PRINT_ENABLED +#include "debug/dprint.h" + +void dprint(ttnn::ccl::cmd::CclCommandTensor const& command_tensor) { + DPRINT << "\ttensor_shape.w: " << (uint32_t)command_tensor.tensor_shape.w << "\n"; + DPRINT << "\ttensor_shape.z: " << (uint32_t)command_tensor.tensor_shape.z << "\n"; + DPRINT << "\ttensor_shape.y: " << (uint32_t)command_tensor.tensor_shape.y << "\n"; + DPRINT << "\ttensor_shape.x: " << (uint32_t)command_tensor.tensor_shape.x << "\n"; + DPRINT << "\ttensor_slice_shape.w: " << (uint32_t)command_tensor.tensor_slice_shape.w << "\n"; + DPRINT << "\ttensor_slice_shape.z: " << (uint32_t)command_tensor.tensor_slice_shape.z << "\n"; + DPRINT << "\ttensor_slice_shape.y: " << (uint32_t)command_tensor.tensor_slice_shape.y << "\n"; + DPRINT << "\ttensor_slice_shape.x: " << (uint32_t)command_tensor.tensor_slice_shape.x << "\n"; + DPRINT << "\ttensor_slice_offset.w: " << (uint32_t)command_tensor.tensor_slice_offset.w << "\n"; + DPRINT << "\ttensor_slice_offset.z: " << (uint32_t)command_tensor.tensor_slice_offset.z << "\n"; + DPRINT << "\ttensor_slice_offset.y: " << (uint32_t)command_tensor.tensor_slice_offset.y << "\n"; + DPRINT << "\ttensor_slice_offset.x: " << (uint32_t)command_tensor.tensor_slice_offset.x << "\n"; + DPRINT << "\tworker_start_offset_in_slice.w: " << (uint32_t)command_tensor.worker_start_offset_in_slice.w << "\n"; + DPRINT << "\tworker_start_offset_in_slice.z: " << (uint32_t)command_tensor.worker_start_offset_in_slice.z << "\n"; + DPRINT << "\tworker_start_offset_in_slice.y: " << (uint32_t)command_tensor.worker_start_offset_in_slice.y << "\n"; + DPRINT << "\tworker_start_offset_in_slice.x: " << (uint32_t)command_tensor.worker_start_offset_in_slice.x << "\n"; + DPRINT << "\tworker_pages_per_slice: " << (uint32_t)command_tensor.worker_pages_per_slice << "\n"; +} +#endif + +void print_tensor_command(uint32_t command_index, ttnn::ccl::cmd::CclCommandTensor const& command_tensor) { +#ifdef DEBUG_PRINT_ENABLED + DPRINT << "cmd[" << (uint32_t)command_index << "]:\n"; + dprint(command_tensor); +#endif +} + +/* + * Convert a flattened worker offset coord value (assumed 0,0,0, worker offset in pages into tensor slice) + * into a 4D coordinate value + */ +FORCE_INLINE shape_t worker_wrapped_offset_to_coord(shape_t const& slice_shape, shape_t const& worker_slice_offset) { + static_assert( + sizeof(coord_t) == 2 * sizeof(uint32_t), "worker_wrapped_offset_to_coord not updated to work with 4d shape"); + auto const y = worker_slice_offset.x / slice_shape.x; + return shape_t(0, 0, y, worker_slice_offset.x - (y * slice_shape.x)); +} + +FORCE_INLINE std::size_t get_flat_index_from_shape(const Shape4D& shape, const Shape4D& index) { + std::size_t offset = index.x; + std::size_t inner_volume = shape.x; + offset += index.y * inner_volume; + inner_volume *= shape.y; + offset += index.z * inner_volume; + inner_volume *= shape.z; + offset += index.w * inner_volume; + return offset; +} + +namespace v2 { +/* + * Convert a flattened worker offset coord value (assumed 0,0,0, worker offset in pages into tensor slice) + * into a 4D coordinate value + */ +FORCE_INLINE shape_t worker_wrapped_offset_to_coord(shape_t const& slice_shape, shape_t const& worker_slice_offset) { + static_assert( + sizeof(coord_t) == 2 * sizeof(uint32_t), "worker_wrapped_offset_to_coord not updated to work with 4d shape"); + auto const y = worker_slice_offset.x / slice_shape.x; + return shape_t(0, 0, y, worker_slice_offset.x - (y * slice_shape.x)); +} + +} // namespace v2 + +template +struct source_tensor_addrgen { + static constexpr char name[] = "Uninitialized"; +}; +template +struct source_tensor_addrgen { + static constexpr bool is_dram = buffer_type == tt::tt_metal::BufferType::DRAM; + static constexpr char name[] = "InterleavedAddrGen(default)"; + using type = InterleavedAddrGen; +}; +template +struct source_tensor_addrgen { + static constexpr bool is_dram = buffer_type == tt::tt_metal::BufferType::DRAM; + static constexpr char name[] = "InterleavedAddrGen(Tile)"; + using type = InterleavedAddrGenFast; +}; +template +struct source_tensor_addrgen { + static constexpr char name[] = "WidthSharded"; + using type = tt::tt_metal::address_generators::DefaultVirtualCoordWidthShardedAddressGenerator; +}; +template +struct source_tensor_addrgen { + static constexpr char name[] = "HeightSharded"; + using type = tt::tt_metal::address_generators::DefaultVirtualCoordHeightShardedAddressGenerator; +}; +template +struct source_tensor_addrgen { + static constexpr char name[] = "BlockSharded"; + using type = tt::tt_metal::address_generators::DefaultVirtualCoordBlockShardedAddressGenerator; +}; + +constexpr bool is_sharded_tensor_layout(tt::tt_metal::TensorMemoryLayout tensor_layout) { + return tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; +} + +// reader code +template +FORCE_INLINE constexpr Shape4D build_wrapped_row_tensor_slice(T n_pages) { + return Shape4D{1, 1, 1, n_pages}; +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp index 59c6c6f38e0..584e2d0dd0d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp @@ -41,6 +41,13 @@ struct Shape4D { return w == rhs.w && z == rhs.z && y == rhs.y && x == rhs.x; } + T& operator[](size_t index) { + return *(&w + index); + } + const T& operator[](size_t index) const { + return *(&w + index); + } + constexpr std::size_t volume() const { return w * z * y * x; } diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp index ed5ba80cb77..aaf889d66be 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp @@ -136,16 +136,20 @@ std::vector ShardedAddrGenArgBuilder::emit_ct_args(Tensor const& t) { t.memory_config().memory_layout); // shard_grid_height (cores) args.push_back(shard_grid_end.y - shard_grid_start.y + 1); + TT_FATAL(args.back() > 0, "Passed shard_grid height == 0 to sharded addrgen, which is invalid"); // shard_grid_width (cores) args.push_back(shard_grid_end.x - shard_grid_start.x + 1); + TT_FATAL(args.back() > 0, "Passed shard_grid width == 0 to sharded addrgen, which is invalid"); // shard_grid_start_y args.push_back(shard_grid_start.y); // shard_grid_start_x args.push_back(shard_grid_start.x); // pages_per_shard_y args.push_back(pages_per_shard_y); + TT_FATAL(args.back() > 0, "Passed pages per shard y == 0 to sharded addrgen, which is invalid"); // pages_per_shard_x args.push_back(pages_per_shard_x); + TT_FATAL(args.back() > 0, "Passed pages per shard x == 0 to sharded addrgen, which is invalid"); // transposed grid args.push_back(static_cast(shard_grid_transposed)); diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp index bf4eddca1e7..83187f3baeb 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp @@ -5,6 +5,7 @@ #pragma once #include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp" namespace ttnn { namespace ccl { diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp index 1b1051916bb..525bb7e5e77 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp @@ -6,26 +6,122 @@ #include #include +#include +#include #include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +// For command dest type +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + namespace ttnn { namespace ccl { +namespace v2 { +struct TensorSlice { + using ords_t = Shape4D; + ords_t tensor_shape; + ords_t tensor_slice_shape; + ords_t tensor_slice_offset; + ords_t worker_slice_shape; + ords_t worker_slice_offset; + + TensorSlice(TensorSlice const& rhs) = default; + TensorSlice(TensorSlice&& rhs) = default; + TensorSlice& operator=(TensorSlice const& rhs) = default; + TensorSlice& operator=(TensorSlice&& rhs) = default; + + TensorSlice() = default; + TensorSlice( + ords_t tensor_shape, + ords_t tensor_slice_shape, + ords_t tensor_slice_offset, + ords_t worker_slice_shape, + ords_t worker_slice_offset) : + tensor_shape(tensor_shape), + tensor_slice_shape(tensor_slice_shape), + tensor_slice_offset(tensor_slice_offset), + worker_slice_shape(worker_slice_shape), + worker_slice_offset(worker_slice_offset) {} +}; +} // namespace v2 + namespace cmd { constexpr std::size_t round_up(std::size_t a, std::size_t multiple) { return ((a + multiple - 1) / multiple) * multiple; } +// for CclCommandStreamTensorToCB and CclCommandStreamCBToTensor +using CclCommandStreamTensorSlice = v2::TensorSlice; +struct CclCommandWaitValue { + uint32_t target_value = 0; +}; +struct CclCommandAtomicInc { + uint32_t value = 1; + uint32_t wrap_value = std::numeric_limits::max(); +}; +struct CclCommandInlineReadWrite { + uint32_t value = 0; +}; +struct CclCommandReadWrite { + uint32_t size_bytes = 0; +}; +using CclCommandArgs = std::variant< + CclCommandStreamTensorSlice, + CclCommandWaitValue, + CclCommandAtomicInc, + CclCommandInlineReadWrite, + CclCommandReadWrite>; + +enum SRC_DEST_TYPE : uint8_t { SRC = 0, DEST = 1 }; + +// Explicitly assigned integer values for easier debug enum class CclCommandArgCode : uint8_t { // If operating on a per page granularity SET_TENSOR_SHAPE_IN_PAGES = 0, - SET_TENSOR_SLICE_SHAPE_IN_PAGES, - SET_TENSOR_SLICE_OFFSET_IN_PAGES, - SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES, - SET_WORKER_PAGES_PER_SLICE, - SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES + SET_TENSOR_SLICE_SHAPE_IN_PAGES = 1, + SET_TENSOR_SLICE_OFFSET_IN_PAGES = 2, + SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES = 3, + SET_WORKER_PAGES_PER_SLICE = 4, + SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES = 5, + + // wait_value, inline read/write + SET_TARGET_VALUE = 6, + SET_ATOMIC_INC_VALUE = 7, + + // addr type commands + SET_ADDRESS_INFO = 8, + + // core descriptor commands + SET_CORE_DESCRIPTOR_INFO = 9, + + INVALID = std::numeric_limits::max(), +}; + +struct CclCommandArgHeader { + CclCommandArgCode code = CclCommandArgCode::INVALID; + uint8_t inline_value0 = 0; + uint8_t inline_value1 = 0; + uint8_t inline_value2 = 0; + + static CclCommandArgHeader from_uint32(uint32_t val) { + CclCommandArgHeader header; + header.code = static_cast(val & 0xFF); + header.inline_value0 = (val >> 8) & 0xFF; + header.inline_value1 = (val >> 16) & 0xFF; + header.inline_value2 = (val >> 24) & 0xFF; + return header; + } + uint32_t to_uint32() const { + uint32_t val = 0; + val |= static_cast(this->code); + val |= static_cast(this->inline_value0) << 8; + val |= static_cast(this->inline_value1) << 16; + val |= static_cast(this->inline_value2) << 24; + return val; + } }; +static_assert(sizeof(CclCommandArgHeader) == sizeof(uint32_t)); struct CclCommandTensor { Shape4D tensor_shape; @@ -35,20 +131,49 @@ struct CclCommandTensor { uint32_t worker_pages_per_slice; }; -template struct command_arg_field { using type = std::nullptr_t; }; -template <> struct command_arg_field { using type = Shape4D; }; -template <> struct command_arg_field { using type = Shape4D; }; -template <> struct command_arg_field { using type = Shape4D; }; -template <> struct command_arg_field { using type = Shape4D; }; -template <> struct command_arg_field { using type = uint32_t; }; -template <> struct command_arg_field { using type = CclCommandTensor; }; - - -template -struct CclCommandArg { - +template +struct command_arg_field { + using type = std::nullptr_t; +}; +template <> +struct command_arg_field { + using type = Shape4D; +}; +template <> +struct command_arg_field { + using type = Shape4D; +}; +template <> +struct command_arg_field { + using type = Shape4D; +}; +template <> +struct command_arg_field { + using type = Shape4D; +}; +template <> +struct command_arg_field { + using type = uint32_t; +}; +template <> +struct command_arg_field { + using type = uint32_t; +}; +template <> +struct command_arg_field { + using type = CclCommandAtomicInc; +}; +template <> +struct command_arg_field { + using type = CclCommandTensor; +}; +template <> +struct command_arg_field { + using type = uint32_t; }; +template +struct CclCommandArg {}; using args_elem_t = uint32_t; template @@ -82,8 +207,10 @@ inline void unpack_field_without_header(volatile args_elem_t const* args, Shape4 void pack_field_without_header(args_elem_t* args, Shape4D const& out); template <> -struct CclCommandArg : public CclCommandArgBase, CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES> { - +struct CclCommandArg + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES> { static void pack_to(args_elem_t* args, CclCommandTensor const& out) { pack_field_without_header(&args[0], out.tensor_shape); } @@ -93,13 +220,17 @@ struct CclCommandArg : public CclC static void unpack(volatile args_elem_t const* args, CclCommandTensor& out) { unpack_field_without_header(&args[0], out.tensor_shape); } - static void unpack(volatile args_elem_t const* args, field_type& out) { unpack_field_without_header(&args[0], out); } + static void unpack(volatile args_elem_t const* args, field_type& out) { + unpack_field_without_header(&args[0], out); + } void unpack(volatile args_elem_t const* args) { unpack_field_without_header(&args[0], this->value); } }; template <> -struct CclCommandArg : public CclCommandArgBase, CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES> { - +struct CclCommandArg + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES> { static void pack_to(args_elem_t* args, CclCommandTensor const& out) { pack_field_without_header(&args[0], out.tensor_slice_shape); } @@ -114,7 +245,10 @@ struct CclCommandArg : publi }; template <> -struct CclCommandArg : public CclCommandArgBase, CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES> { +struct CclCommandArg + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES> { using type = Shape4D; static void pack_to(args_elem_t* args, CclCommandTensor const& out) { @@ -131,7 +265,10 @@ struct CclCommandArg : publ }; template <> -struct CclCommandArg : public CclCommandArgBase, CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES> { +struct CclCommandArg + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES> { using type = Shape4D; static void pack_to(args_elem_t* args, CclCommandTensor const& out) { @@ -148,21 +285,28 @@ struct CclCommandArg -struct CclCommandArg : public CclCommandArgBase, CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE> { +struct CclCommandArg + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE> { using type = uint32_t; static void pack_to(args_elem_t* args, CclCommandTensor const& out) { args[0] = out.worker_pages_per_slice; } static void pack_to(args_elem_t* args, field_type const& out) { args[0] = out; } void pack_to(args_elem_t* args) const { args[0] = this->value; } - static void unpack(volatile args_elem_t const* args, CclCommandTensor& out) { out.worker_pages_per_slice = args[0]; } + static void unpack(volatile args_elem_t const* args, CclCommandTensor& out) { + out.worker_pages_per_slice = args[0]; + } static void unpack(volatile args_elem_t const* args, field_type& out) { out = args[0]; } void unpack(volatile args_elem_t const* args) { this->value = args[0]; } }; template <> struct CclCommandArg - : public CclCommandArgBase, CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES> { + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES> { using type = CclCommandTensor; // considering making this some generator type that implements operator[] @@ -174,16 +318,20 @@ struct CclCommandArg CclCommandArg::pack_to(&args[i], command_tensor.tensor_shape); i += CclCommandArg::size_in_words(); - CclCommandArg::pack_to(&args[i], command_tensor.tensor_slice_shape); + CclCommandArg::pack_to( + &args[i], command_tensor.tensor_slice_shape); i += CclCommandArg::size_in_words(); - CclCommandArg::pack_to(&args[i], command_tensor.tensor_slice_offset); + CclCommandArg::pack_to( + &args[i], command_tensor.tensor_slice_offset); i += CclCommandArg::size_in_words(); - CclCommandArg::pack_to(&args[i], command_tensor.worker_start_offset_in_slice); + CclCommandArg::pack_to( + &args[i], command_tensor.worker_start_offset_in_slice); i += CclCommandArg::size_in_words(); - CclCommandArg::pack_to(&args[i], command_tensor.worker_pages_per_slice); + CclCommandArg::pack_to( + &args[i], command_tensor.worker_pages_per_slice); i += CclCommandArg::size_in_words(); } @@ -191,7 +339,6 @@ struct CclCommandArg CclCommandArg::pack_to(args, this->value); } - // TODO: when kernels get c++20, use std::span static void unpack(volatile args_elem_t const* args, CclCommandTensor& out) { std::size_t i = 0; CclCommandArg::unpack(&args[i], out.tensor_shape); @@ -203,7 +350,8 @@ struct CclCommandArg CclCommandArg::unpack(&args[i], out.tensor_slice_offset); i += CclCommandArg::size_in_words(); - CclCommandArg::unpack(&args[i], out.worker_start_offset_in_slice); + CclCommandArg::unpack( + &args[i], out.worker_start_offset_in_slice); i += CclCommandArg::size_in_words(); CclCommandArg::unpack(&args[i], out.worker_pages_per_slice); @@ -215,6 +363,54 @@ struct CclCommandArg } }; +template <> +struct CclCommandArg : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_TARGET_VALUE> { + static void pack_to(args_elem_t* args, uint32_t value) { args[0] = value; } + void pack_to(args_elem_t* args) { pack_to(&args[0], this->value); } + + static void unpack(volatile args_elem_t const* args, CclCommandTensor& out) { + unpack_field_without_header(&args[0], out.tensor_shape); + } + static void unpack(volatile args_elem_t const* args, field_type& out) { out = args[0]; } + void unpack(volatile args_elem_t const* args) { this->value = args[0]; } +}; + +template <> +struct CclCommandArg + : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_ATOMIC_INC_VALUE> { + static void pack_to(args_elem_t* args, CclCommandAtomicInc const& atomic_inc_args) { + args[0] = atomic_inc_args.value; + args[1] = atomic_inc_args.wrap_value; + } + void pack_to(args_elem_t* args) { pack_to(&args[0], this->value); } + + static void unpack(volatile args_elem_t const* args, CclCommandAtomicInc& out) { + out.value = args[0]; + out.wrap_value = args[1]; + } + void unpack(volatile args_elem_t const* args) { + this->value.value = args[0]; + this->value.wrap_value = args[1]; + } +}; + +template <> +struct CclCommandArg : public CclCommandArgBase< + CclCommandArg, + CclCommandArgCode::SET_ADDRESS_INFO> { + static void pack_to(args_elem_t* args, uint32_t value) { args[0] = value; } + void pack_to(args_elem_t* args) { pack_to(&args[0], this->value); } + + static void unpack(volatile args_elem_t const* args, CclCommandTensor& out) { + unpack_field_without_header(&args[0], out.tensor_shape); + } + static void unpack(volatile args_elem_t const* args, field_type& out) { out = args[0]; } + void unpack(volatile args_elem_t const* args) { this->value = args[0]; } +}; // Convenience type aliases using tensor_shape_command_arg_t = CclCommandArg; @@ -224,45 +420,214 @@ using worker_start_offset_command_arg_t = CclCommandArg; using full_tensor_command_arg_t = CclCommandArg; +enum class CclCommandAddrType : uint8_t { + SEMAPHORE_ID, + CIRCULAR_BUFFER_ID, + ABSOLUTE_ADDRESS, + RELATIVE_ADDRESS, + + // Useful for inline commands (read/write, atomic inc) + NONE +}; +struct CclCommandAddrSemaphoreId { + uint32_t semaphore_id; +}; +struct CclCommandAddrCircularBufferId { + uint32_t circular_buffer_id; +}; +struct CclCommandAddrAbsoluteAddress { + uint32_t absolute_address; +}; +struct CclCommandAddrRelativeAddress { + uint32_t relative_address; +}; +struct CclCommandAddrNone {}; + +using CclCommandAddrArgs = std::variant< + CclCommandAddrSemaphoreId, + CclCommandAddrCircularBufferId, + CclCommandAddrAbsoluteAddress, + CclCommandAddrRelativeAddress, + CclCommandAddrNone>; + +enum class CclCommandCoreDescriptorType : uint8_t { + // Temporary since at the moment, tensor commands have their source/dest type implied + // by the command stream index - the info is all off the addrgen + ADDRGEN = 0, + LOCAL = 1, + NOC_XY = 2, + RECTANGLE = 3 + // Future types may include: list, rectangle_list, etc. +}; +struct CclCommandCoreDescriptorTypeAddrgen {}; +struct CclCommandCoreDescriptorTypeLocal {}; +struct CclCommandCoreDescriptorTypeNocXY { + uint8_t x; + uint8_t y; +}; +// unused atm +struct CclCommandCoreDescriptorTypeMcast { + uint32_t to_uint32() const { + uint32_t value = 0; + value |= (noc0_start_x << 0); + value |= (noc0_start_y << 8); + value |= (noc0_end_x << 16); + value |= (noc0_end_y << 24); + return value; + } + static CclCommandCoreDescriptorTypeMcast from_uint32(uint32_t value) { + CclCommandCoreDescriptorTypeMcast mcast; + mcast.noc0_start_x = (value >> 0) & 0xFF; + mcast.noc0_start_y = (value >> 8) & 0xFF; + mcast.noc0_end_x = (value >> 16) & 0xFF; + mcast.noc0_end_y = (value >> 24) & 0xFF; + return mcast; + } + + uint8_t noc0_start_x; + uint8_t noc0_start_y; + uint8_t noc0_end_x; + uint8_t noc0_end_y; +}; +using CclCommandCoreDescriptorArgs = std::variant< + CclCommandCoreDescriptorTypeAddrgen, + CclCommandCoreDescriptorTypeLocal, + CclCommandCoreDescriptorTypeNocXY, + CclCommandCoreDescriptorTypeMcast>; + // A command is composed of one or more arguments // This enum specifies the high level command // Future commands are to be added and will enable // functionalilty such as synchronizing enum class CclCommandCode : uint8_t { - STREAM_TENSOR_TO_EDM = 0, - STREAM_EDM_TO_TENSOR + STREAM_TENSOR_TO_EDM = 0, // TODO: rename uses of to the below + STREAM_TENSOR_TO_CB = 0, + STREAM_CB_TO_TENSOR = 1, + STREAM_EDM_TO_TENSOR = 2, // TODO: rename uses of to the above + + WAIT_VALUE = 3, + + // value, wrap, dest_type, dest_addr_info + ATOMIC_INC = 4, + + RAW_INLINE_WRITE_BYTES = 5, + + // Behaviour reading/writing to CBs still a little unclear + // This mode isn't actually supported yet + RAW_READ_BYTES = 6, + RAW_WRITE_BYTES = 7, + + INVALID = 8 }; +enum CclCommandDestType : uint8_t { + CHIP_UNICAST = tt::fabric::CHIP_UNICAST, + CHIP_MULTICAST = tt::fabric::CHIP_MULTICAST, + CHIP_LOCAL_ONLY = 2 +}; +static_assert(tt::fabric::CHIP_UNICAST < 2); +static_assert(tt::fabric::CHIP_MULTICAST < 2); +struct DestTypeArgsNull {}; +static_assert(sizeof(DestTypeArgsNull) <= 2); +struct UnicastCommandDestArgs { + uint8_t distance_in_hops; + bool is_forward_direction; +}; +struct MulticastCommandDestArgs { + uint8_t num_targets_forward_direction; + uint8_t num_targets_backward_direction; +}; +using LocalOnlyCommandDestArgs = DestTypeArgsNull; + +// Used only for host code paths +using CclCommandDestArgs = std::variant; + +namespace v2 {}; + struct CclCommandHeader { - CclCommandCode code; + CclCommandCode code : 6; + CclCommandDestType dest_type : 2; // For the time being we have a dedicated arg_count because we assume // we may save args/tensor info from previous command. Up to command sequence // generator to make sure any fields/args not explicitly listed are correct from prior command uint8_t arg_count : 4; - uint8_t reserved1; - uint8_t reserved2; + union { + DestTypeArgsNull null; + UnicastCommandDestArgs unicast; + MulticastCommandDestArgs multicast; + LocalOnlyCommandDestArgs local_only; + } command_dest_args; + + CclCommandHeader() : code(CclCommandCode::INVALID), dest_type(CclCommandDestType::CHIP_LOCAL_ONLY), arg_count(0) {} + CclCommandHeader(CclCommandCode code, CclCommandDestArgs const& args, uint8_t arg_count) : + code(code), arg_count(arg_count) { + if (std::holds_alternative(args)) { + command_dest_args.unicast = std::get(args); + this->dest_type = CclCommandDestType::CHIP_UNICAST; + } else if (std::holds_alternative(args)) { + command_dest_args.multicast = std::get(args); + this->dest_type = CclCommandDestType::CHIP_MULTICAST; + } else if (std::holds_alternative(args)) { + command_dest_args.local_only = std::get(args); + this->dest_type = CclCommandDestType::CHIP_LOCAL_ONLY; + } + } + CclCommandHeader(CclCommandCode code, MulticastCommandDestArgs const& multicast_args, uint8_t arg_count) : + code(code), dest_type(CclCommandDestType::CHIP_MULTICAST), arg_count(arg_count) { + this->command_dest_args.multicast = multicast_args; + } + CclCommandHeader(CclCommandCode code, LocalOnlyCommandDestArgs const& local_only_args, uint8_t arg_count) : + code(code), dest_type(CclCommandDestType::CHIP_LOCAL_ONLY), arg_count(arg_count) { + this->command_dest_args.local_only = local_only_args; + } - static CclCommandHeader from_uint32(uint32_t const& cmd_header) { + static CclCommandHeader from_uint32(uint32_t cmd_header) { CclCommandHeader decoded; - decoded.code = static_cast(cmd_header & 0xFF); - decoded.arg_count = (cmd_header >> 8) & 0xF; + reinterpret_cast(&decoded)[0] = cmd_header; return decoded; + // decoded.code = static_cast(cmd_header & 0xFF); + // decoded.dest_type = static_cast((cmd_header >> 6) & 0x3); + // switch (decoded.dest_type) { + // case CclCommandDestType::CHIP_UNICAST: + // decoded.command_dest_args.unicast = UnicastCommandDestArgs{static_cast((cmd_header >> 16) & + // 0xFF), static_cast((cmd_header >> 24) & 0x1)}; break; + // case CclCommandDestType::CHIP_MULTICAST: + // decoded.command_dest_args.multicast = MulticastCommandDestArgs{static_cast((cmd_header >> + // 16) & 0xFF), static_cast((cmd_header >> 24) & 0xFF)}; break; + // default: + // break; + // } + // decoded.arg_count = (cmd_header >> 8) & 0xF; + // return decoded; } static uint32_t to_uint32(CclCommandHeader const& cmd_header) { uint32_t encoded = 0; - encoded = (uint8_t)(cmd_header.code); - encoded = encoded | (cmd_header.arg_count << 8); + encoded = (uint32_t)(cmd_header.code); + encoded |= (cmd_header.dest_type << 6); + encoded |= (cmd_header.arg_count << 8); + switch (cmd_header.dest_type) { + case CclCommandDestType::CHIP_UNICAST: + encoded |= (cmd_header.command_dest_args.unicast.distance_in_hops << 16); + encoded |= (cmd_header.command_dest_args.unicast.is_forward_direction << 24); + break; + case CclCommandDestType::CHIP_MULTICAST: + encoded |= (cmd_header.command_dest_args.multicast.num_targets_forward_direction << 16); + encoded |= (cmd_header.command_dest_args.multicast.num_targets_backward_direction << 24); + break; + default: break; + }; return encoded; } - uint32_t to_uint32() const { - return *reinterpret_cast(this); - } + uint32_t to_uint32() const { return to_uint32(*this); } + + UnicastCommandDestArgs const& get_unicast_dest_args() const { return command_dest_args.unicast; } + MulticastCommandDestArgs const& get_multicast_dest_args() const { return command_dest_args.multicast; } + LocalOnlyCommandDestArgs const& get_local_only_dest_args() const { return command_dest_args.local_only; } }; static_assert(sizeof(CclCommandHeader) == sizeof(uint32_t)); - } // namespace cmd } // namespace ccl } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp index 3d7d8c91bd9..acdbc515d85 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp @@ -44,19 +44,21 @@ auto build_from_args>(std::size_t &rt_arg_idx) -> Shape4D(arg_idx++)); #ifdef DEBUG_PRINT_ENABLED - DPRINT << "CMD (code=" << (uint32_t)cmd.code << ", arg_count=" << (uint32_t)cmd.arg_count << ")\n"; + DPRINT << "CMD (code=" << (uint32_t)cmd.code << ", dst_t=" << (uint32_t)cmd.dest_type << ", arg_count=" << (uint32_t)cmd.arg_count << ")\n"; #endif - for (std::size_t i = 0; i < cmd.arg_count; i++) { + for (size_t i = 0; i < cmd.arg_count; i++) { // Note that we choose to reinterpret our pointers as volatile so that in the future we can add streaming // of additional commands from some backing memory (e.g. dram or L1), potentially by another core, without // having to track down this code and add volatile casts later (which would be a potentially tricky bug to // root cause). - switch (static_cast(get_arg_val(arg_idx++))) { + const CclCommandArgHeader command_arg_header = CclCommandArgHeader::from_uint32(get_arg_val(arg_idx++)); + const CclCommandArgCode command_arg_code = command_arg_header.code; + switch (command_arg_code) { case CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES: CclCommandArg::unpack(reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_shape); #ifdef DEBUG_PRINT_ENABLED @@ -105,8 +107,12 @@ void update_command_tensor(std::size_t &arg_idx, CclCommandTensor &cmd_tensor) { }; } + return cmd; } + + + } // namespace cmd } // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp new file mode 100644 index 00000000000..af3de7d0e46 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp @@ -0,0 +1,381 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" + +#include "ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "tt_metal/tt_stl/overloaded.hpp" + +#include +namespace ttnn::ccl::cmd { + +// This file defines commands that are resolved on a per worker level. This is the lowest level of +// command description (Intermediate Representation if you will) before being lowered directly to +// Ccl Command Process KernelCommands + +namespace uops { + +CclHostLowLevelWorkerCommand read_tensor_slice_to_cb_for_eventual_fabric_write( + ttnn::ccl::v2::TensorSlice const& slice, size_t cb_id) { + return CclHostLowLevelWorkerCommand{ + CclCommandCode::STREAM_TENSOR_TO_CB, + slice, + // At the moment, we don't support switching tensors from within a command stream + // so we set none because we assume the command stream is fixed/assigned to a given tensor + // based on order: + // - Command stream 0: tensor 0 + // - Command stream 1: tensor 1 + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeAddrgen(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST, + // Hack to add packet header padding when doing reads + ttnn::ccl::cmd::UnicastCommandDestArgs(0, true)}; +}; +CclHostLowLevelWorkerCommand read_tensor_slice_to_cb(ttnn::ccl::v2::TensorSlice const& slice, size_t cb_id) { + return CclHostLowLevelWorkerCommand{ + CclCommandCode::STREAM_TENSOR_TO_CB, + slice, + // At the moment, we don't support switching tensors from within a command stream + // so we set none because we assume the command stream is fixed/assigned to a given tensor + // based on order: + // - Command stream 0: tensor 0 + // - Command stream 1: tensor 1 + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeAddrgen(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()}; +}; + +CclHostLowLevelWorkerCommand local_write_cb_to_tensor_slice(ttnn::ccl::v2::TensorSlice const& slice, size_t cb_id) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::STREAM_CB_TO_TENSOR, + ttnn::ccl::cmd::CclCommandArgs(slice), + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeAddrgen(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +CclHostLowLevelWorkerCommand fabric_write_cb_to_tensor_slice( + ttnn::ccl::v2::TensorSlice const& slice, + size_t cb_id, + std::variant const& dest_args) { + auto const dest_type = std::visit( + tt::stl::overloaded{ + [](ttnn::ccl::cmd::UnicastCommandDestArgs const&) { return CclCommandDestType::CHIP_UNICAST; }, + [](ttnn::ccl::cmd::MulticastCommandDestArgs const&) { return CclCommandDestType::CHIP_MULTICAST; }, + [](auto&&) -> void { + TT_THROW( + "ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice called with unsupported fabric dest_args " + "types. " + "Currently supported types are UnicastCommandDestArgs and MulticastCommandDestArgs"); + }}, + dest_args); + auto dest_args_variant = std::visit( + tt::stl::overloaded{ + [](ttnn::ccl::cmd::UnicastCommandDestArgs const& arg) -> ttnn::ccl::cmd::CclCommandDestArgs { + return ttnn::ccl::cmd::UnicastCommandDestArgs(arg); + }, + [](ttnn::ccl::cmd::MulticastCommandDestArgs const& arg) -> ttnn::ccl::cmd::CclCommandDestArgs { + return ttnn::ccl::cmd::MulticastCommandDestArgs(arg); + }, + [](auto&&) -> void { + TT_THROW( + "ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice called with unsupported fabric dest_args " + "types. " + "Currently supported types are UnicastCommandDestArgs and MulticastCommandDestArgs"); + }}, + dest_args); + + return CclHostLowLevelWorkerCommand( + CclCommandCode::STREAM_CB_TO_TENSOR, + ttnn::ccl::cmd::CclCommandStreamTensorSlice(slice), + // src + ttnn::ccl::cmd::CclCommandAddrType::CIRCULAR_BUFFER_ID, + ttnn::ccl::cmd::CclCommandAddrCircularBufferId{cb_id}, + + // dest + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + + ttnn::ccl::cmd::CclCommandCoreDescriptorType::ADDRGEN, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeAddrgen(), + + dest_type, + dest_args_variant); +} + +static ttnn::ccl::cmd::CclCommandAddrType get_semaphore_addr_type(semaphore_id_t const& semaphore_id) { + return std::visit( + tt::stl::overloaded{ + [](uint32_t) { return ttnn::ccl::cmd::CclCommandAddrType::SEMAPHORE_ID; }, + [](tt::tt_metal::GlobalSemaphore const*) { return ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS; }, + [](auto&&) -> void { + TT_THROW( + "ttnn::ccl::cmd::uops::get_semaphore_addr_type called with unsupported semaphore types. " + "Currently supported types are uint32_t (semaphore ID) and GlobalSemaphore"); + }}, + semaphore_id); +} +static ttnn::ccl::cmd::CclCommandAddrArgs get_semaphore_addr_val(semaphore_id_t const& semaphore_id) { + using ttnn::ccl::cmd::CclCommandAddrArgs; + return std::visit( + tt::stl::overloaded{ + [](uint32_t id) -> CclCommandAddrArgs { return ttnn::ccl::cmd::CclCommandAddrSemaphoreId{id}; }, + [](tt::tt_metal::GlobalSemaphore const* semaphore) -> CclCommandAddrArgs { + TT_FATAL(semaphore != nullptr, "Internal error: GlobalSemaphore pointer is null in call to get_semaphore_addr_val"); + return ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{semaphore->address()}; + }, + [](auto&&) -> void { + TT_THROW( + "ttnn::ccl::cmd::uops::get_semaphore_addr_val called with unsupported semaphore types. " + "Currently supported types are uint32_t (semaphore ID) and GlobalSemaphore"); + } + + }, + semaphore_id); +} + +CclHostLowLevelWorkerCommand local_semaphore_wait(semaphore_id_t const& semaphore_id, size_t value) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::WAIT_VALUE, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandWaitValue{value}), + get_semaphore_addr_type(semaphore_id), + get_semaphore_addr_val(semaphore_id), + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeAddrgen(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +CclHostLowLevelWorkerCommand local_core_semaphore_set(semaphore_id_t const& semaphore_id, size_t value) { + TT_FATAL( + value < std::numeric_limits::max(), + "When invoking: local_core_inline_write. Raw inline writes currently are limited to values no larger than {} " + "due to a command encoding limitation. Support for larger values is not yet added", + std::numeric_limits::max()); + return CclHostLowLevelWorkerCommand( + CclCommandCode::RAW_INLINE_WRITE_BYTES, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandInlineReadWrite{value}), + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone{}, + get_semaphore_addr_type(semaphore_id), + get_semaphore_addr_val(semaphore_id), + ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeLocal(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +CclHostLowLevelWorkerCommand local_core_semaphore_inc(semaphore_id_t const& semaphore_id, size_t value) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandAtomicInc{value}), + // src + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + // dest + get_semaphore_addr_type(semaphore_id), + get_semaphore_addr_val(semaphore_id), + ttnn::ccl::cmd::CclCommandCoreDescriptorType::LOCAL, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeLocal(), + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +CclHostLowLevelWorkerCommand local_chip_noc_semaphore_inc( + size_t dest_noc0_x, + size_t dest_noc0_y, + semaphore_id_t const& semaphore_id, + // size_t semaphore_id, + size_t value) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandAtomicInc{value}), + // src + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + // dest + get_semaphore_addr_type(semaphore_id), + get_semaphore_addr_val(semaphore_id), + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY{dest_noc0_x, dest_noc0_y}, + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +static std::pair optimize_mcast_core_desc_args( + CclCommandCoreDescriptorTypeMcast const& noc_mcast_args) { + bool is_really_a_unicast = noc_mcast_args.noc0_end_x == noc_mcast_args.noc0_start_x && + noc_mcast_args.noc0_end_y == noc_mcast_args.noc0_start_y; + auto core_desc_type = + is_really_a_unicast ? CclCommandCoreDescriptorType::NOC_XY : CclCommandCoreDescriptorType::RECTANGLE; + CclCommandCoreDescriptorArgs core_desc_args = is_really_a_unicast + ? CclCommandCoreDescriptorArgs{CclCommandCoreDescriptorTypeNocXY{ + noc_mcast_args.noc0_start_x, noc_mcast_args.noc0_start_y}} + : CclCommandCoreDescriptorArgs{noc_mcast_args}; + return {core_desc_type, core_desc_args}; +} + +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_unicast_semaphore_inc_mcast( + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + CclCommandCoreDescriptorTypeMcast const& mcast_spec, + UnicastCommandDestArgs const& unicast_args) { + auto const [core_desc_type, core_desc_args] = optimize_mcast_core_desc_args(mcast_spec); + TT_FATAL( + core_desc_type != CclCommandCoreDescriptorType::RECTANGLE, + "semaphore inc commands don't support noc multicast yet"); + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + increment_args, + // src + CclCommandAddrType::NONE, + CclCommandAddrNone(), + // dest + get_semaphore_addr_type(semaphore_dest_args), + get_semaphore_addr_val(semaphore_dest_args), + core_desc_type, + core_desc_args, + CclCommandDestType::CHIP_UNICAST, + UnicastCommandDestArgs(unicast_args)); +} + +[[nodiscard]] CclHostLowLevelWorkerCommand local_chip_semaphore_inc_mcast( + // CclCommandAddrSemaphoreId const& semaphore_dest_args, + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + CclCommandCoreDescriptorTypeMcast const& mcast_spec) { + auto const [core_desc_type, core_desc_args] = optimize_mcast_core_desc_args(mcast_spec); + TT_FATAL( + core_desc_type != CclCommandCoreDescriptorType::RECTANGLE, + "semaphore inc commands don't support noc multicast yet"); + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + increment_args, + // src + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + // dest + get_semaphore_addr_type(semaphore_dest_args), + get_semaphore_addr_val(semaphore_dest_args), + core_desc_type, + core_desc_args, + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +CclHostLowLevelWorkerCommand local_chip_noc_absolute_address_semaphore_inc( + size_t dest_noc0_x, size_t dest_noc0_y, size_t bank_address, size_t value) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandAtomicInc{value}), + + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + + ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS, + ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress{bank_address}, + + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY{dest_noc0_x, dest_noc0_y}, + + ttnn::ccl::cmd::CclCommandDestType::CHIP_LOCAL_ONLY, + ttnn::ccl::cmd::LocalOnlyCommandDestArgs()); +} + +CclHostLowLevelWorkerCommand fabric_multicast_semaphore_inc( + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + size_t dest_noc0_x, + size_t dest_noc0_y, + MulticastCommandDestArgs const& multicast_args) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandAtomicInc{increment_args}), + + // src + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + + // dest + get_semaphore_addr_type(semaphore_dest_args), + get_semaphore_addr_val(semaphore_dest_args), + + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY{dest_noc0_x, dest_noc0_y}, + + ttnn::ccl::cmd::CclCommandDestType::CHIP_MULTICAST, + ttnn::ccl::cmd::MulticastCommandDestArgs(multicast_args)); +} + +CclHostLowLevelWorkerCommand fabric_unicast_semaphore_inc( + // CclCommandAddrSemaphoreId const& semaphore_dest_args, + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + size_t dest_noc0_x, + size_t dest_noc0_y, + UnicastCommandDestArgs const& unicast_args) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandAtomicInc{increment_args}), + + // src + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + + // dest + get_semaphore_addr_type(semaphore_dest_args), + get_semaphore_addr_val(semaphore_dest_args), + + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY{dest_noc0_x, dest_noc0_y}, + + ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST, + ttnn::ccl::cmd::UnicastCommandDestArgs(unicast_args)); +} + +CclHostLowLevelWorkerCommand fabric_unicast_absolute_address_semaphore_inc( + CclCommandAddrAbsoluteAddress const& address_dest_args, + CclCommandAtomicInc const& increment_args, + size_t dest_noc0_x, + size_t dest_noc0_y, + UnicastCommandDestArgs const& unicast_args) { + return CclHostLowLevelWorkerCommand( + CclCommandCode::ATOMIC_INC, + ttnn::ccl::cmd::CclCommandArgs(ttnn::ccl::cmd::CclCommandAtomicInc{increment_args}), + + // src + ttnn::ccl::cmd::CclCommandAddrType::NONE, + ttnn::ccl::cmd::CclCommandAddrNone(), + + // dest + ttnn::ccl::cmd::CclCommandAddrType::ABSOLUTE_ADDRESS, + address_dest_args, + + ttnn::ccl::cmd::CclCommandCoreDescriptorType::NOC_XY, + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY{dest_noc0_x, dest_noc0_y}, + + ttnn::ccl::cmd::CclCommandDestType::CHIP_UNICAST, + ttnn::ccl::cmd::UnicastCommandDestArgs(unicast_args)); +} + +} // namespace uops + +} // namespace ttnn::ccl::cmd diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp new file mode 100644 index 00000000000..8e8c22ea8b5 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" + +namespace ttnn::ccl::cmd { + +// This file defines commands that are resolved on a per worker level. This is the lowest level of +// command description (Intermediate Representation if you will) before being lowered directly to +// Ccl Command Process KernelCommands + +struct CclHostLowLevelWorkerCommand { + ttnn::ccl::cmd::CclCommandCode command_code; + ttnn::ccl::cmd::CclCommandArgs command_args; + + // semaphore ID derived address, absolute address, relative address + ttnn::ccl::cmd::CclCommandAddrType source_addr_type; + ttnn::ccl::cmd::CclCommandAddrArgs source_addr_args; + + ttnn::ccl::cmd::CclCommandAddrType dest_addr_type; + ttnn::ccl::cmd::CclCommandAddrArgs dest_addr_args; + + // resolved core-xy, rectangle (for mcast) + ttnn::ccl::cmd::CclCommandCoreDescriptorType core_desc_type; + ttnn::ccl::cmd::CclCommandCoreDescriptorArgs core_desc_args; + + // unicast, mcast, local_only + ttnn::ccl::cmd::CclCommandDestType fabric_transfer_type; + ttnn::ccl::cmd::CclCommandDestArgs fabric_transfer_args; +}; + +using CclHostLowLevelCommandSequence = std::vector; + +namespace uops { + +using semaphore_id_t = std::variant; + +[[nodiscard]] CclHostLowLevelWorkerCommand read_tensor_slice_to_cb_for_eventual_fabric_write( + ttnn::ccl::v2::TensorSlice const& slice, size_t cb_id); +[[nodiscard]] CclHostLowLevelWorkerCommand read_tensor_slice_to_cb( + ttnn::ccl::v2::TensorSlice const& slice, size_t cb_id); +[[nodiscard]] CclHostLowLevelWorkerCommand local_write_cb_to_tensor_slice( + ttnn::ccl::v2::TensorSlice const& slice, size_t cb_id); +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_write_cb_to_tensor_slice( + ttnn::ccl::v2::TensorSlice const& slice, + size_t cb_id, + std::variant const& dest_args_variant); +[[nodiscard]] CclHostLowLevelWorkerCommand local_semaphore_wait(semaphore_id_t const& semaphore_id, size_t value); +[[nodiscard]] CclHostLowLevelWorkerCommand local_chip_noc_semaphore_inc( + size_t dest_noc0_x, size_t dest_noc0_y, semaphore_id_t const& semaphore_id, size_t value); +[[nodiscard]] CclHostLowLevelWorkerCommand local_core_semaphore_inc(semaphore_id_t const& semaphore_id, size_t value); +[[nodiscard]] CclHostLowLevelWorkerCommand local_core_semaphore_set(semaphore_id_t const& semaphore_id, size_t value); +[[nodiscard]] [[deprecated]] CclHostLowLevelWorkerCommand local_chip_noc_absolute_address_semaphore_inc( + size_t dest_noc0_x, size_t dest_noc0_y, size_t bank_address, size_t value); +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_multicast_semaphore_inc( + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + size_t dest_noc0_x, + size_t dest_noc0_y, + MulticastCommandDestArgs const& multicast_args); +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_unicast_semaphore_inc( + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + size_t dest_noc0_x, + size_t dest_noc0_y, + UnicastCommandDestArgs const& unicast_args); +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_unicast_semaphore_inc_mcast( + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + CclCommandCoreDescriptorTypeMcast const& dest_mcast_spec, + UnicastCommandDestArgs const& unicast_args); +[[nodiscard]] CclHostLowLevelWorkerCommand local_chip_semaphore_inc_mcast( + semaphore_id_t const& semaphore_dest_args, + CclCommandAtomicInc const& increment_args, + CclCommandCoreDescriptorTypeMcast const& dest_mcast_spec); +[[nodiscard]] CclHostLowLevelWorkerCommand fabric_unicast_absolute_address_semaphore_inc( + CclCommandAddrAbsoluteAddress const& address_dest_args, + CclCommandAtomicInc const& increment_args, + size_t dest_noc0_x, + size_t dest_noc0_y, + UnicastCommandDestArgs const& unicast_args); + +}; // namespace uops +}; // namespace ttnn::ccl::cmd diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp index 3f6c480ef48..d954dacb906 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp @@ -7,6 +7,7 @@ #include "common/math.hpp" #include "erisc_datamover_builder.hpp" #include "eth_l1_address_map.h" +#include "sub_device/sub_device_types.hpp" #include "tt_metal/common/assert.hpp" #include "ttnn/operations/ccl/ccl_common.hpp" #include "ttnn/operations/math.hpp" @@ -16,7 +17,13 @@ #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/program/program.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + +#include #include +#include +#include namespace ttnn::ccl { @@ -77,6 +84,62 @@ FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( eth_l1_mem::address_map::MAX_L1_LOADING_SIZE, "Internal error - channel buffers spilled past the end of usable L1 region."); } +void get_runtime_args_for_edm_termination_infos(std::vector const& edm_termination_infos, std::vector& args_out) { + args_out.reserve(args_out.size() + edm_termination_infos.size() * 4 + 1); + args_out.push_back(edm_termination_infos.size()); + for (auto const& info : edm_termination_infos) { + args_out.push_back(info.edm_noc_x); + args_out.push_back(info.edm_noc_y); + args_out.push_back(info.distance); + args_out.push_back(info.termination_addr); + log_trace( + tt::LogTest, + "EDM termination info: x={}, y={}, distance={}, termination_addr={}", + info.edm_noc_x, + info.edm_noc_y, + info.distance, + info.termination_addr); + } +} + +void append_worker_to_fabric_edm_sender_rt_args( + SenderWorkerAdapterSpec const& connection, + size_t sender_worker_flow_control_semaphore_id, + size_t sender_worker_buffer_index_semaphore_id, + std::vector& args_out) { + auto edm_noc_xy = WorkerXY(connection.edm_noc_x, connection.edm_noc_y); + std::vector const values = { + connection.persistent_fabric, + edm_noc_xy.to_uint32(), + connection.edm_buffer_base_addr, + connection.num_buffers_per_channel, + connection.edm_l1_sem_addr, + connection.edm_connection_handshake_addr, + connection.edm_worker_location_info_addr, + connection.buffer_size_bytes, + connection.buffer_index_semaphore_id, + sender_worker_flow_control_semaphore_id, + sender_worker_buffer_index_semaphore_id + }; + args_out.reserve(args_out.size() + (values.size() / sizeof(size_t))); + std::ranges::copy(values, std::back_inserter(args_out)); +} + +size_t log_worker_to_fabric_edm_sender_rt_args(std::vector const& args, size_t starting_arg_idx) { + log_trace(tt::LogOp, "Worker to fabric EDM Sender has {} RT Args: {}", args.size(), args); + log_trace(tt::LogOp, "arg[{}]: edm_noc_xy {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: edm_buffer_base_addr {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: num_buffers_per_channel {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: edm_l1_sem_addr {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: edm_connection_handshake_addr {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: edm_worker_location_info_addr {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: buffer_size_bytes {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: buffer_index_semaphore_id {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: sender_worker_flow_control_semaphore_id {}", starting_arg_idx, args[starting_arg_idx++]); + log_trace(tt::LogOp, "arg[{}]: sender_worker_buffer_index_semaphore_id {}", starting_arg_idx, args[starting_arg_idx++]); + return starting_arg_idx + 10; +} + FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( CoreCoord const& my_eth_core_logical, size_t my_noc_x, @@ -92,7 +155,9 @@ FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( size_t sender_channel_0_buffer_index_semaphore_id, size_t sender_channel_1_buffer_index_semaphore_id, - FabricEriscDatamoverConfig const& config) : + FabricEriscDatamoverConfig const& config, + bool enable_persistent_mode, + bool build_in_worker_connection_mode) : my_eth_core_logical(my_eth_core_logical), my_noc_x(my_noc_x), my_noc_y(my_noc_y), @@ -114,7 +179,7 @@ FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( sender_channel_0_buffer_index_semaphore_id(sender_channel_0_buffer_index_semaphore_id), sender_channel_1_buffer_index_semaphore_id(sender_channel_1_buffer_index_semaphore_id), - receiver_channel_local_buffer_index_addr(FabricEriscDatamoverConfig::receiver_channel_local_buffer_index_addr), + receiver_channel_local_buffer_index_address(FabricEriscDatamoverConfig::receiver_channel_local_buffer_index_address), local_sender_channel_0_buffer_address(config.sender_0_channel_base_address), local_sender_channel_0_connection_info_addr( @@ -124,7 +189,9 @@ FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address), local_receiver_channel_buffer_address(config.receiver_channel_base_address), - termination_signal_ptr(FabricEriscDatamoverConfig::termination_signal_address) {} + termination_signal_ptr(FabricEriscDatamoverConfig::termination_signal_address), + enable_persistent_mode(enable_persistent_mode), + build_in_worker_connection_mode(build_in_worker_connection_mode) {} std::vector FabricEriscDatamoverBuilder::get_compile_time_args() const { const bool is_handshake_master = this->my_chip_id < this->peer_chip_id; @@ -156,7 +223,8 @@ std::vector FabricEriscDatamoverBuilder::get_compile_time_args() const config.sender_0_channel_base_address, config.sender_1_channel_base_address, - this->termination_signal_ptr}; + this->termination_signal_ptr, + this->enable_persistent_mode}; } std::vector FabricEriscDatamoverBuilder::get_runtime_args() const { @@ -172,9 +240,9 @@ std::vector FabricEriscDatamoverBuilder::get_runtime_args() const { this->downstream_edm_semaphore_address.value_or(-1), this->downstream_edm_worker_registration_address.value_or(0), this->downstream_edm_worker_location_info_address.value_or(0), - this->receiver_channel_local_buffer_index_addr, + this->receiver_channel_local_buffer_index_address, // this is the receiver channel's local sem for flow controlling with downstream fabric sender - this->receiver_channel_downstream_flow_control_semaphore_id.value_or(0), + this->receiver_channel_downstream_flow_control_semaphore_id.value_or(-1), this->sender_channel_0_flow_control_semaphore_id, this->sender_channel_1_flow_control_semaphore_id }; @@ -186,40 +254,87 @@ FabricEriscDatamoverBuilder FabricEriscDatamoverBuilder::build( CoreCoord const& ethernet_core, chip_id_t local_chip_id, chip_id_t peer_chip_id, - FabricEriscDatamoverConfig const& config) { - std::optional receiver_channel_downstream_flow_control_semaphore_id = std::nullopt; - auto sender_channel_0_flow_control_semaphore_id = - tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); - auto sender_channel_1_flow_control_semaphore_id = - tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); - auto sender_channel_0_connection_semaphore_id = - tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); - auto sender_channel_1_connection_semaphore_id = - tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); - auto sender_channel_0_buffer_index_semaphore_id = - tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); - auto sender_channel_1_buffer_index_semaphore_id = - tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); - - return FabricEriscDatamoverBuilder( - ethernet_core, - device->ethernet_core_from_logical_core(ethernet_core).x, - device->ethernet_core_from_logical_core(ethernet_core).y, - local_chip_id, - peer_chip_id, - - receiver_channel_downstream_flow_control_semaphore_id, - sender_channel_0_flow_control_semaphore_id, - sender_channel_1_flow_control_semaphore_id, - sender_channel_0_connection_semaphore_id, - sender_channel_1_connection_semaphore_id, - sender_channel_0_buffer_index_semaphore_id, - sender_channel_1_buffer_index_semaphore_id, - - config); + FabricEriscDatamoverConfig const& config, + bool enable_persistent_mode, + bool build_in_worker_connection_mode) { + if (enable_persistent_mode) { + auto sender_channel_0_buffer_index_semaphore_address = + FabricEriscDatamoverConfig::sender_channel_0_buffer_index_semaphore_address; + auto sender_channel_0_flow_control_semaphore_address = + FabricEriscDatamoverConfig::sender_channel_0_local_flow_control_semaphore_address; + auto sender_channel_0_connection_semaphore_address = + FabricEriscDatamoverConfig::sender_channel_0_connection_semaphore_address; + + std::optional receiver_channel_downstream_flow_control_semaphore_address = + build_in_worker_connection_mode ? 0: tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_1_flow_control_semaphore_id = + build_in_worker_connection_mode ? 0: tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_1_connection_semaphore_id = + build_in_worker_connection_mode ? 0: tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_1_buffer_index_semaphore_id = + build_in_worker_connection_mode ? 0: tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + + return FabricEriscDatamoverBuilder( + ethernet_core, + device->ethernet_core_from_logical_core(ethernet_core).x, + device->ethernet_core_from_logical_core(ethernet_core).y, + local_chip_id, + peer_chip_id, + + receiver_channel_downstream_flow_control_semaphore_address, + sender_channel_0_flow_control_semaphore_address, + sender_channel_1_flow_control_semaphore_id, + sender_channel_0_connection_semaphore_address, + sender_channel_1_connection_semaphore_id, + sender_channel_0_buffer_index_semaphore_address, + sender_channel_1_buffer_index_semaphore_id, + + config, + enable_persistent_mode, + build_in_worker_connection_mode); + + } else { + std::optional receiver_channel_downstream_flow_control_semaphore_id = tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_0_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_1_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_0_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_1_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_0_buffer_index_semaphore_id = + tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + auto sender_channel_1_buffer_index_semaphore_id = + tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); + + return FabricEriscDatamoverBuilder( + ethernet_core, + device->ethernet_core_from_logical_core(ethernet_core).x, + device->ethernet_core_from_logical_core(ethernet_core).y, + local_chip_id, + peer_chip_id, + + receiver_channel_downstream_flow_control_semaphore_id, + sender_channel_0_flow_control_semaphore_id, + sender_channel_1_flow_control_semaphore_id, + sender_channel_0_connection_semaphore_id, + sender_channel_1_connection_semaphore_id, + sender_channel_0_buffer_index_semaphore_id, + sender_channel_1_buffer_index_semaphore_id, + + config, + enable_persistent_mode); + } } SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_worker_channel() const { + if (this->enable_persistent_mode) { + log_trace(tt::LogOp, "Building connection to persistent fabric"); + } else { + log_trace(tt::LogOp, "Building connection to non-persistent fabric"); + } + TT_FATAL(sender_channel_0_buffer_index_semaphore_id != sender_channel_0_flow_control_semaphore_id, "Internal error - sender_channel_0_buffer_index_semaphore_id and sender_channel_0_flow_control_semaphore_id aliased eachother"); return SenderWorkerAdapterSpec { this->my_noc_x, this->my_noc_y, @@ -229,7 +344,8 @@ SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_worker_ this->sender_channel_0_connection_semaphore_id, FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, this->config.channel_buffer_size_bytes, - this->sender_channel_0_buffer_index_semaphore_id + this->sender_channel_0_buffer_index_semaphore_id, + this->enable_persistent_mode }; } @@ -244,11 +360,13 @@ SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_fabric_ this->sender_channel_1_connection_semaphore_id, FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, this->config.channel_buffer_size_bytes, - this->sender_channel_1_buffer_index_semaphore_id + this->sender_channel_1_buffer_index_semaphore_id, + false }; } void FabricEriscDatamoverBuilder::connect_to_downstream_edm(FabricEriscDatamoverBuilder const& downstream_edm) { + TT_FATAL(!this->build_in_worker_connection_mode, "Tried to connect two EDMs to each other in worker connection mode"); auto const adapter_spec = downstream_edm.build_connection_to_fabric_channel(); log_trace(tt::LogTest, "Connecting to downstream EDM at x={}, y={}", adapter_spec.edm_noc_x, adapter_spec.edm_noc_y); @@ -262,11 +380,13 @@ void FabricEriscDatamoverBuilder::connect_to_downstream_edm(FabricEriscDatamover this->downstream_sender_channel_buffer_index_semaphore_id = adapter_spec.buffer_index_semaphore_id; } - - -EdmLineFabricOpInterface::EdmLineFabricOpInterface (std::vector const& device_sequence, std::vector const& program_sequence, std::optional desired_num_links) : - device_sequence(device_sequence), - programs(program_sequence) { +EdmLineFabricOpInterface::EdmLineFabricOpInterface( + std::vector const& device_sequence, + std::vector const& program_sequence, + bool enable_persistent_mode, + std::optional desired_num_links, + bool build_in_worker_connection_mode) : + device_sequence(device_sequence), programs(program_sequence) { static constexpr std::size_t edm_buffer_size = 4096 + sizeof(tt::fabric::PacketHeader); auto const config = FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); TT_ASSERT(device_sequence.size() == program_sequence.size()); @@ -274,8 +394,21 @@ EdmLineFabricOpInterface::EdmLineFabricOpInterface (std::vector const& for (size_t i = 0; i < device_sequence.size(); i++) { log_trace(tt::LogOp, "device[{}] id={}", i, device_sequence[i]->id()); } + size_t min_link_count = desired_num_links.value_or(std::numeric_limits::max()); + for (size_t hop = 0; hop < device_sequence.size() - 1; hop++) { + auto src_device = device_sequence[hop]; + auto dest_device = device_sequence[hop + 1]; + auto const& src_device_sockets = src_device->get_ethernet_sockets(dest_device->id());; + auto const& dest_device_sockets = dest_device->get_ethernet_sockets(src_device->id());; + if (src_device_sockets.size() > 0) { + min_link_count = std::min(min_link_count, src_device_sockets.size()); + } + if (src_device_sockets.size() > 0) { + min_link_count = std::min(min_link_count, dest_device_sockets.size()); + } + } - + FabricEriscDatamoverBuilder *a_builder = nullptr; // Construct the builders for (size_t hop = 0; hop < device_sequence.size() - 1; hop++) { auto src_device = device_sequence[hop]; @@ -288,7 +421,7 @@ EdmLineFabricOpInterface::EdmLineFabricOpInterface (std::vector const& std::copy_if(src_device_sockets.begin(), src_device_sockets.end(), std::back_inserter(local_link_cores), [src_device](CoreCoord const& core) { return src_device->is_active_ethernet_core(core, true); }); std::copy_if(dest_device_sockets.begin(), dest_device_sockets.end(), std::back_inserter(remote_link_cores), [dest_device](CoreCoord const& core) { return dest_device->is_active_ethernet_core(core, true); }); - this->num_links = std::min(desired_num_links.value_or(std::numeric_limits::max()), local_link_cores.size()); + this->num_links = min_link_count; TT_ASSERT(local_link_cores.size() == remote_link_cores.size()); @@ -302,7 +435,9 @@ EdmLineFabricOpInterface::EdmLineFabricOpInterface (std::vector const& local_link_cores[l], src_device->id(), dest_device->id(), - config)); + config, + enable_persistent_mode, + build_in_worker_connection_mode)); log_trace(tt::LogOp, "Building backward direction EDM on chip {} on link {}", dest_device->id(), edm_builders_backward_direction[dest_device->id()].size()); edm_builders_backward_direction[dest_device->id()].push_back(FabricEriscDatamoverBuilder::build( @@ -311,44 +446,170 @@ EdmLineFabricOpInterface::EdmLineFabricOpInterface (std::vector const& remote_link_cores[l], dest_device->id(), src_device->id(), - config)); + config, + enable_persistent_mode, + build_in_worker_connection_mode)); + + a_builder = &edm_builders_backward_direction[dest_device->id()].front(); } + + this->buffer_size_bytes = a_builder->channel_buffer_size; } - // Establish local connections between EDMs on the same chips to establish the lin fabric - for (size_t i = 1; i < device_sequence.size() - 1; i++) { - const size_t num_links = edm_builders_forward_direction.at(device_sequence[i]->id()).size(); - auto& forward_direction_edm = edm_builders_forward_direction.at(device_sequence[i]->id()); - auto& backward_direction_edm = edm_builders_backward_direction.at(device_sequence[i]->id()); + if (!build_in_worker_connection_mode) { + // Establish local connections between EDMs on the same chips to establish the lin fabric + for (size_t i = 1; i < device_sequence.size() - 1; i++) { + const size_t num_links = edm_builders_forward_direction.at(device_sequence[i]->id()).size(); + auto& forward_direction_edm = edm_builders_forward_direction.at(device_sequence[i]->id()); + auto& backward_direction_edm = edm_builders_backward_direction.at(device_sequence[i]->id()); - for (size_t l = 0; l < num_links; l++) { - forward_direction_edm.at(l).connect_to_downstream_edm(backward_direction_edm.at(l)); - backward_direction_edm.at(l).connect_to_downstream_edm(forward_direction_edm.at(l)); + for (size_t l = 0; l < num_links; l++) { + forward_direction_edm.at(l).connect_to_downstream_edm(backward_direction_edm.at(l)); + backward_direction_edm.at(l).connect_to_downstream_edm(forward_direction_edm.at(l)); + } } } - } +// Invocable per chip if we want to collectively build the fabric by building this separately per chip +// (and implicitly building the fabric that way) +EdmLineFabricOpInterface::EdmLineFabricOpInterface( + Device* local_device, + std::optional forward_device, + std::optional backward_device, + Program* program, + bool enable_persistent_mode, + std::optional desired_num_links, + bool build_in_worker_connection_mode) : + device_sequence({local_device}), programs({program}) { + static constexpr std::size_t edm_buffer_size = 4096 + sizeof(tt::fabric::PacketHeader); + auto const config = FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); + + log_trace(tt::LogOp, "device id={}", local_device->id()); + log_trace(tt::LogOp, "EDM Fabric Factory ctor on device: {}", local_device->id()); + if (forward_device.has_value()) { + log_trace(tt::LogOp, "\tConnect[FORWARD]: {} -> {}", local_device->id(), forward_device.value()->id()); + } + if (backward_device.has_value()) { + log_trace(tt::LogOp, "\tConnect[BACKWARD]: {} -> {}", local_device->id(), backward_device.value()->id()); + } + + // Construct the builders + std::array>, 2> device_pairs = { + std::pair>{local_device, forward_device}, + std::pair>{local_device, backward_device} + }; + + static_assert(EdmLineFabricOpInterface::Direction::FORWARD < 2); + static_assert(EdmLineFabricOpInterface::Direction::BACKWARD < 2); + std::array>*, 2> edm_builders_maps; + edm_builders_maps[EdmLineFabricOpInterface::Direction::FORWARD] = &this->edm_builders_forward_direction; + edm_builders_maps[EdmLineFabricOpInterface::Direction::BACKWARD] = &this->edm_builders_backward_direction; + + std::optional counted_num_links = std::nullopt; + std::optional obtained_channel_buffer_size = std::nullopt; + const size_t max_num_links = desired_num_links.value_or(std::numeric_limits::max()); + for (size_t i = 0; i < device_pairs.size(); i++) { + if (!device_pairs[i].second.has_value()) { + continue; + } + log_trace(tt::LogOp, "Device {} is connected to {} at index {}", local_device->id(), device_pairs[i].second.value()->id(), i); + auto &edm_builders = *edm_builders_maps[i]; + + Device *remote_device = device_pairs[i].second.value(); + auto const connected_sockets = local_device->get_ethernet_sockets(remote_device->id()); + + TT_FATAL(edm_builders.size() == 0, "EDM builders already exist for this device"); + edm_builders.clear(); + for (const auto& core : local_device->get_ethernet_sockets(remote_device->id())) { + if (!local_device->is_active_ethernet_core(core, true)) { + continue; + } + if (edm_builders[local_device->id()].size() >= max_num_links) { + break; + } + log_trace(tt::LogOp, "DEBUG: build EDM: device: {}, &program: {}: core-logi(x={},y={})", local_device->id(), (void*)program, core.x, core.y); + edm_builders[local_device->id()].push_back( + FabricEriscDatamoverBuilder::build( + local_device, *program, core, + device_pairs[i].first->id(), + device_pairs[i].second.value()->id(), + config, + enable_persistent_mode, + build_in_worker_connection_mode)); + } + if (!counted_num_links.has_value()) { + TT_FATAL(!obtained_channel_buffer_size.has_value(), "No channel buffer size was counted"); + counted_num_links = edm_builders[local_device->id()].size(); + obtained_channel_buffer_size = edm_builders[local_device->id()].front().channel_buffer_size; + } + } + TT_FATAL(counted_num_links.has_value(), "No links were counted"); + this->num_links = counted_num_links.value(); + + TT_FATAL(obtained_channel_buffer_size.has_value(), "No channel buffer size was counted"); + this->buffer_size_bytes = obtained_channel_buffer_size.value(); + + if (!build_in_worker_connection_mode) { + // Establish local connections between EDMs on the same chips to establish the line fabric + if (forward_device.has_value() && backward_device.has_value()) { + auto& forward_direction_edm = edm_builders_forward_direction.at(local_device->id()); + auto& backward_direction_edm = edm_builders_backward_direction.at(local_device->id()); + + for (size_t l = 0; l < this->num_links; l++) { + forward_direction_edm.at(l).connect_to_downstream_edm(backward_direction_edm.at(l)); + backward_direction_edm.at(l).connect_to_downstream_edm(forward_direction_edm.at(l)); + } + } + } +} SenderWorkerAdapterSpec EdmLineFabricOpInterface::uniquely_connect_worker(Device* device, Direction direction) { - TT_ASSERT((direction == FORWARD) ? edm_builders_forward_direction.find(device->id()) != edm_builders_forward_direction.end() - : edm_builders_backward_direction.find(device->id()) != edm_builders_backward_direction.end()); + TT_FATAL((direction == FORWARD) ? edm_builders_forward_direction.find(device->id()) != edm_builders_forward_direction.end() + : edm_builders_backward_direction.find(device->id()) != edm_builders_backward_direction.end(), "Device {} not found in edm builders", device->id()); auto& edm_builders = (direction == FORWARD) ? edm_builders_forward_direction.at(device->id()) : edm_builders_backward_direction.at(device->id()); auto &link_count_map = (direction == FORWARD) ? next_forward_direction_edm_available : next_backward_direction_edm_available; + log_trace(tt::LogOp, "EDM conecting in {} direction", direction == FORWARD ? "FORWARD" : "BACKWARD"); const auto next_link = link_count_map[device->id()]; - link_count_map[device->id()] = next_link + 1; + link_count_map[device->id()] = (next_link + 1) % edm_builders.size(); - TT_ASSERT(edm_builders.size() > 0); - TT_ASSERT(next_link < edm_builders.size()); + TT_FATAL(edm_builders.size() > 0, "No EDM builders found for device {}", device->id()); + TT_FATAL(next_link < edm_builders.size(), "Next link index {} is out of bounds for device {}", next_link, device->id()); return edm_builders.at(next_link).build_connection_to_worker_channel(); } +EdmLineFabricOpInterface EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + std::vector const& device_sequence, + std::vector const& program_sequence, + bool enable_persistent_mode, + std::optional desired_num_links) { + return EdmLineFabricOpInterface(device_sequence, program_sequence, enable_persistent_mode, desired_num_links, true); +} + +EdmLineFabricOpInterface EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + Device* local_device, + std::optional forward_device, + std::optional backward_device, + Program* program, + bool enable_persistent_mode, + std::optional desired_num_links) { + return EdmLineFabricOpInterface(local_device, forward_device, backward_device, program, enable_persistent_mode, desired_num_links, true); +} + void EdmLineFabricOpInterface::build_kernels() const { auto generate_kernels_in_direction = [this](Device *device, Program *program, Direction direction) { auto &edm_builders = direction == FORWARD ? edm_builders_forward_direction : edm_builders_backward_direction; if (edm_builders.find(device->id()) != edm_builders.end()) { for (auto& edm_builder : edm_builders.at(device->id())) { + log_trace( + tt::LogOp, + "Building EDM kernel on device {}, logical-core (y={},x={}), noc_core (y={},x={})", + device->id(), + edm_builder.my_eth_core_logical.y, + edm_builder.my_eth_core_logical.x, + device->ethernet_core_from_logical_core(edm_builder.my_eth_core_logical).y, + device->ethernet_core_from_logical_core(edm_builder.my_eth_core_logical).x); auto local_edm_kernel = ttnn::ccl::generate_edm_kernel( *program, device, @@ -368,7 +629,30 @@ void EdmLineFabricOpInterface::build_kernels() const { } } - +std::vector EdmLineFabricOpInterface::generate_local_chip_fabric_termination_infos(Device *device) const { + auto generate_termination_info = [](FabricEriscDatamoverBuilder const& edm_builder) -> edm_termination_info_t { + return edm_termination_info_t{ + 0, + edm_builder.my_noc_x, + edm_builder.my_noc_y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}; + }; + std::vector edm_termination_infos; + edm_termination_infos.reserve(this->num_links * 2); + if (edm_builders_backward_direction.find(device->id()) != edm_builders_backward_direction.end()) { + std::ranges::transform( + edm_builders_backward_direction.at(device->id()), + std::back_inserter(edm_termination_infos), + generate_termination_info); + } + if (edm_builders_forward_direction.find(device->id()) != edm_builders_forward_direction.end()) { + std::ranges::transform( + edm_builders_forward_direction.at(device->id()), + std::back_inserter(edm_termination_infos), + generate_termination_info); + } + return edm_termination_infos; +} std::vector EdmLineFabricOpInterface::generate_ordered_termination_info_farthest_to_nearest() const { TT_ASSERT(device_sequence.size() > 0); @@ -411,7 +695,92 @@ std::vector EdmLineFabricOpInterface::generate_ordered_t } +void FabricEriscDatamoverBuilder::teardown_from_host(Device *d, tt::fabric::TerminationSignal termination_signal) const { + std::vector val(1, termination_signal); + d->push_work([&](){tt::tt_metal::detail::WriteToDeviceL1( + d, + d->logical_core_from_ethernet_core(CoreCoord(this->my_noc_x, this->my_noc_y)), + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address, + val, + CoreType::ETH);}, true); +} + +void EdmLineFabricOpInterface::teardown_from_host(tt::fabric::TerminationSignal termination_signal) const { + for (Device *d : this->device_sequence) { + if (edm_builders_forward_direction.find(d->id()) != edm_builders_forward_direction.end()) { + for (auto& edm_builder : edm_builders_forward_direction.at(d->id())) { + edm_builder.teardown_from_host(d, termination_signal); + } + } + if (edm_builders_backward_direction.find(d->id()) != edm_builders_backward_direction.end()) { + for (auto& edm_builder : edm_builders_backward_direction.at(d->id())) { + edm_builder.teardown_from_host(d, termination_signal); + } + } + } +} + +void initialize_edm_fabric(distributed::MeshDevice* mesh_device) { + + std::vector row_fabric_lines; + row_fabric_lines.reserve(mesh_device->get_view().get_row_views().size()); + std::vector col_fabric_lines; + col_fabric_lines.reserve(mesh_device->get_view().get_column_views().size()); + + size_t num_rows = mesh_device->get_view().get_row_views().size(); + size_t num_cols = mesh_device->get_view().get_column_views().size(); + std::vector> programs(num_rows); + for (size_t r = 0; r < num_rows; r++) { + programs[r].resize(num_cols); + } + + for (size_t i = 0; i < num_rows; i++) { + std::vector program_ptrs; + program_ptrs.reserve(num_cols); + std::transform(programs[i].begin(), programs[i].end(), std::back_inserter(program_ptrs), [](Program& p) { return &p; }); + row_fabric_lines.push_back(EdmLineFabricOpInterface(mesh_device->get_view().get_row_views()[i], program_ptrs, true)); + } + + for (size_t i = 0; i < num_cols; i++) { + std::vector program_ptrs; + program_ptrs.reserve(num_rows); + for (size_t r = 0; r < num_rows; r++) { + program_ptrs.push_back(&programs[r][i]); + } + col_fabric_lines.push_back(EdmLineFabricOpInterface(mesh_device->get_view().get_column_views()[i], program_ptrs, true)); + } + + std::for_each(row_fabric_lines.begin(), row_fabric_lines.end(), [](auto& line) { line.build_kernels(); }); + std::for_each(col_fabric_lines.begin(), col_fabric_lines.end(), [](auto& line) { line.build_kernels(); }); + + for (size_t r = 0; r < num_rows; r++) { + for (size_t c = 0; c < num_cols; c++) { + log_info(tt::LogAlways, "Compile EDM program"); + Device *device = mesh_device->get_device(r, c); + auto& program = programs.at(r).at(c); + device->push_work([&](){tt::tt_metal::detail::CompileProgram(device, program);}, false); + device->push_work([&](){tt::tt_metal::EnqueueProgram(device->command_queue(), program, false);}, true); + } + } +} +void teardown_edm_fabric(distributed::MeshDevice* mesh_device) { + auto teardown = [](std::vector const& line_view) { + std::vector programs(line_view.size()); + std::vector program_ptrs; + program_ptrs.reserve(programs.size()); + std::transform(programs.begin(), programs.end(), std::back_inserter(program_ptrs), [](Program& p) { return &p; }); + EdmLineFabricOpInterface edm_fabric(line_view, program_ptrs, true); + edm_fabric.teardown_from_host(tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); + }; + + for (auto const &row_view : mesh_device->get_view().get_row_views()) { + teardown(row_view); + } + for (auto const &col_view : mesh_device->get_view().get_column_views()) { + teardown(col_view); + } +} } // namespace ttnn::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp index bf330aca910..14ab290f2b9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -9,16 +9,19 @@ #include #include "eth_l1_address_map.h" +#include "ttnn/distributed/types.hpp" #include "umd/device/types/cluster_descriptor_types.h" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" - +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/program/program.hpp" #include #include +#include + namespace ttnn { namespace ccl { @@ -30,30 +33,50 @@ struct FabricEriscDatamoverConfig { // Global static constexpr std::size_t eth_channel_sync_size = 16; - static constexpr std::size_t handshake_addr = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + static constexpr std::size_t handshake_addr = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE/* + 1024*/; static constexpr std::size_t edm_channel_ack_addr = handshake_addr + eth_channel_sync_size; static constexpr std::size_t termination_signal_address = edm_channel_ack_addr + (2 * eth_channel_sync_size); // pad extra bytes to match old EDM so handshake logic will still work - // Sender Channel 0 + // ----------- Sender Channel 0 static constexpr std::size_t sender_channel_0_buffer_index_address = termination_signal_address + field_size; static constexpr std::size_t sender_channel_0_worker_connection_info_address = sender_channel_0_buffer_index_address + field_size; + static constexpr std::size_t sender_channel_0_local_flow_control_semaphore_address = + sender_channel_0_worker_connection_info_address + field_size; + // persistent mode field + static constexpr std::size_t sender_channel_0_connection_semaphore_address = + sender_channel_0_local_flow_control_semaphore_address + field_size; + // persistent mode field + static constexpr std::size_t sender_channel_0_buffer_index_semaphore_address = + sender_channel_0_connection_semaphore_address + field_size; + static_assert(field_size >= sizeof(tt::fabric::EDMChannelWorkerLocationInfo)); - // Sender Channel 1 + // ----------- Sender Channel 1 static constexpr std::size_t sender_channel_1_buffer_index_address = - sender_channel_0_worker_connection_info_address + field_size; + sender_channel_0_buffer_index_semaphore_address + field_size; static constexpr std::size_t sender_channel_1_worker_connection_info_address = sender_channel_1_buffer_index_address + field_size; - - // Receiver Channel - static constexpr std::size_t receiver_channel_local_buffer_index_addr = + static constexpr std::size_t sender_channel_1_local_flow_control_semaphore_address = sender_channel_1_worker_connection_info_address + field_size; + // persistent mode field + static constexpr std::size_t sender_channel_1_connection_semaphore_address = + sender_channel_1_local_flow_control_semaphore_address + field_size; + // persistent mode field + static constexpr std::size_t sender_channel_1_buffer_index_semaphore_address = + sender_channel_1_connection_semaphore_address + field_size; + + // ----------- Receiver Channel + static constexpr std::size_t receiver_channel_local_buffer_index_address = + sender_channel_1_buffer_index_semaphore_address + field_size; + // persistent mode field + static constexpr std::size_t receiver_channel_downstream_flow_control_semaphore_address = + receiver_channel_local_buffer_index_address + field_size; // Channel Allocations static constexpr std::size_t buffer_region_start = - (receiver_channel_local_buffer_index_addr + field_size + buffer_alignment) & ~(buffer_alignment - 1); // Align + (receiver_channel_downstream_flow_control_semaphore_address + field_size + buffer_alignment) & ~(buffer_alignment - 1); // Align static constexpr std::size_t available_channel_buffering_space = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - buffer_region_start; @@ -86,7 +109,21 @@ struct SenderWorkerAdapterSpec { size_t edm_worker_location_info_addr = 0; // The EDM's location for `EDMChannelWorkerLocationInfo` size_t buffer_size_bytes = 0; size_t buffer_index_semaphore_id = 0; // the semaphore ID on the EDM, not the worker + bool persistent_fabric = false; +}; + + +struct edm_termination_info_t { + uint32_t distance = 0; + uint32_t edm_noc_x = 0; + uint32_t edm_noc_y = 0; + uint32_t termination_addr = 0; }; + +void get_runtime_args_for_edm_termination_infos(std::vector const& edm_termination_infos, std::vector& args_out); +void append_worker_to_fabric_edm_sender_rt_args(SenderWorkerAdapterSpec const& connection, size_t sender_worker_flow_control_semaphore_id, size_t sender_worker_buffer_index_semaphore_id, std::vector& args_out); +size_t log_worker_to_fabric_edm_sender_rt_args(std::vector const& args, size_t starting_arg_idx = 0); + class FabricEriscDatamoverBuilder { public: FabricEriscDatamoverBuilder( @@ -104,7 +141,9 @@ class FabricEriscDatamoverBuilder { size_t sender_channel_0_buffer_index_semaphore_id, size_t sender_channel_1_buffer_index_semaphore_id, - FabricEriscDatamoverConfig const& config); + FabricEriscDatamoverConfig const& config, + bool enable_persistent_mode, + bool build_in_worker_connection_mode=false); static FabricEriscDatamoverBuilder build( tt::tt_metal::Device* device, @@ -112,7 +151,9 @@ class FabricEriscDatamoverBuilder { CoreCoord const& ethernet_core, chip_id_t local_chip_id, chip_id_t peer_chip_id, - FabricEriscDatamoverConfig const& config); + FabricEriscDatamoverConfig const& config, + bool enable_persistent_mode, + bool build_in_worker_connection_mode=false); [[nodiscard]] SenderWorkerAdapterSpec build_connection_to_worker_channel() const; [[nodiscard]] SenderWorkerAdapterSpec build_connection_to_fabric_channel() const; @@ -127,7 +168,9 @@ class FabricEriscDatamoverBuilder { // TODO } - private: + void teardown_from_host(Device *d, tt::fabric::TerminationSignal termination_signal = tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE) const; + +// protected: friend class EdmLineFabricOpInterface; CoreCoord my_eth_core_logical; size_t my_noc_x = 0; @@ -161,7 +204,7 @@ class FabricEriscDatamoverBuilder { size_t sender_channel_1_connection_semaphore_id = 0; size_t sender_channel_0_buffer_index_semaphore_id = 0; size_t sender_channel_1_buffer_index_semaphore_id = 0; - size_t receiver_channel_local_buffer_index_addr = 0; + size_t receiver_channel_local_buffer_index_address = 0; std::optional downstream_edm_noc_x; std::optional downstream_edm_noc_y; @@ -170,17 +213,14 @@ class FabricEriscDatamoverBuilder { std::optional downstream_edm_worker_registration_address; std::optional downstream_edm_worker_location_info_address; std::optional downstream_sender_channel_buffer_index_semaphore_id; + bool enable_persistent_mode = false; + bool build_in_worker_connection_mode = false; }; -struct edm_termination_info_t { - uint32_t distance = 0; - uint32_t edm_noc_x = 0; - uint32_t edm_noc_y = 0; - uint32_t termination_addr = 0; -}; -struct EdmLineFabricOpInterface { +class EdmLineFabricOpInterface { + public: enum Direction { // Ascending chips in the sequence FORWARD, @@ -189,22 +229,16 @@ struct EdmLineFabricOpInterface { BACKWARD, }; - // Device ID -> EDM Builders - std::unordered_map> edm_builders_forward_direction; - std::unordered_map> edm_builders_backward_direction; - - // Device ID -> link index - std::unordered_map next_forward_direction_edm_available; - std::unordered_map next_backward_direction_edm_available; - - std::vector device_sequence; - std::vector programs; - - size_t num_links = 0; // The constructor will assemble/connect the line across the specified device sequence, for all available links. - EdmLineFabricOpInterface (std::vector const& device_sequence, std::vector const& program_sequence, std::optional desired_num_links = std::nullopt); + EdmLineFabricOpInterface (std::vector const& device_sequence, std::vector const& program_sequence, bool enable_persistent_mode, std::optional desired_num_links = std::nullopt, bool build_in_worker_connection_mode = false); + + // Invocable per chip if we want to collectively build the fabric by building this separately per chip + // (and implicitly building the fabric that way) + EdmLineFabricOpInterface (Device* local_device, std::optional forward_device, std::optional backward_device, Program* program, bool enable_persistent_mode, std::optional desired_num_links, bool build_in_worker_connection_mode = false); + static EdmLineFabricOpInterface build_program_builder_worker_connection_fabric(std::vector const& device_sequence, std::vector const& program_sequence, bool enable_persistent_mode, std::optional desired_num_links = std::nullopt); + static EdmLineFabricOpInterface build_program_builder_worker_connection_fabric(Device* local_device, std::optional forward_device, std::optional backward_device, Program* program, bool enable_persistent_mode, std::optional desired_num_links = std::nullopt); // Will create a connection adapter for a worker which can be used to pass args to the worker kernel talking to the // corresponding fabric endpoint. This interface will guarantee unique connections only so requesting more unique connections @@ -222,7 +256,50 @@ struct EdmLineFabricOpInterface { // and so a termination signal may be sent to our link first before the other eth core links // on the chip so multi-link isn't officially supported yet std::vector generate_ordered_termination_info_farthest_to_nearest() const; + + // Generates a list of termination infos for the local chip's EDMs + std::vector generate_local_chip_fabric_termination_infos(Device *device) const; + + // Accessors + size_t get_num_links() const { return num_links; } + + size_t get_device_count() const { return device_sequence.size(); } + + size_t get_index_of_device(Device *device) const { + for (size_t i = 0; i < device_sequence.size(); i++) { + if (device_sequence[i] == device) { + return i; + } + } + TT_THROW("Device {} not found in device sequence of line fabric", device->id()); + return -1; + } + + size_t get_edm_buffer_size_bytes() const { return buffer_size_bytes; } + + void teardown_from_host(tt::fabric::TerminationSignal termination_signal = tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE) const; + + static void launch_mesh_fabric(MeshDevice *mesh_device); + static void teardown_edm_fabric(MeshDevice *mesh_device); + + // Device ID -> EDM Builders + std::unordered_map> edm_builders_forward_direction; + std::unordered_map> edm_builders_backward_direction; + private: + + // Device ID -> link index + std::unordered_map next_forward_direction_edm_available; + std::unordered_map next_backward_direction_edm_available; + + std::vector device_sequence; + std::vector programs; + + size_t num_links; + size_t buffer_size_bytes; }; +void initialize_edm_fabric(distributed::MeshDevice* mesh_device); +void teardown_edm_fabric(distributed::MeshDevice* mesh_device); + }; // namespace ccl }; // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp index 1252ac5f9d1..f7db1ac813f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp @@ -21,6 +21,7 @@ static FORCE_INLINE coord_t coord_from_args(std::size_t& arg_idx) { } enum EDM_IO_BLOCKING_MODE { + FLUSH_BLOCKING, BLOCKING, NON_BLOCKING }; @@ -64,7 +65,10 @@ FORCE_INLINE void send_chunk( cb_wait_front(cb_id, num_pages); uint32_t l1_read_addr = get_read_ptr(cb_id); noc_async_write(l1_read_addr, remote_l1_write_addr, page_size * num_pages); - if constexpr (blocking_mode == ttnn::ccl::EDM_IO_BLOCKING_MODE::BLOCKING) { + if constexpr (blocking_mode == ttnn::ccl::EDM_IO_BLOCKING_MODE::FLUSH_BLOCKING) { + noc_async_writes_flushed(); + cb_pop_front(cb_id, num_pages); + } else if constexpr (blocking_mode == ttnn::ccl::EDM_IO_BLOCKING_MODE::BLOCKING) { noc_async_write_barrier(); cb_pop_front(cb_id, num_pages); } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/edm_handshake.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/edm_handshake.hpp index 971b3f37228..e2dad353ecc 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/edm_handshake.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/edm_handshake.hpp @@ -65,9 +65,10 @@ FORCE_INLINE void initialize_edm_common_datastructures(std::uint32_t handshake_r * As the designated master EDM core, initiate a handshake by sending a packet to reserved * memory region. */ -FORCE_INLINE void sender_side_start(std::uint32_t handshake_register_address) { +FORCE_INLINE void sender_side_start( + std::uint32_t handshake_register_address, size_t HS_CONTEXT_SWITCH_TIMEOUT = A_LONG_TIMEOUT_BEFORE_CONTEXT_SWITCH) { initialize_edm_common_datastructures(handshake_register_address); - eth_wait_receiver_done(A_LONG_TIMEOUT_BEFORE_CONTEXT_SWITCH); + eth_wait_receiver_done(HS_CONTEXT_SWITCH_TIMEOUT); while (eth_txq_reg_read(0, ETH_TXQ_CMD) != 0) { asm volatile("nop"); } @@ -77,8 +78,9 @@ FORCE_INLINE void sender_side_start(std::uint32_t handshake_register_address) { /* * As the designated master EDM core, wait for the acknowledgement from the slave EDM core */ -FORCE_INLINE void sender_side_finish(std::uint32_t handshake_register_address) { - eth_wait_for_receiver_done(A_LONG_TIMEOUT_BEFORE_CONTEXT_SWITCH); +FORCE_INLINE void sender_side_finish( + std::uint32_t handshake_register_address, size_t HS_CONTEXT_SWITCH_TIMEOUT = A_LONG_TIMEOUT_BEFORE_CONTEXT_SWITCH) { + eth_wait_for_receiver_done(HS_CONTEXT_SWITCH_TIMEOUT); } FORCE_INLINE void receiver_side_start(std::uint32_t handshake_register_address) { @@ -96,8 +98,9 @@ FORCE_INLINE bool receiver_side_can_finish() { return eth_bytes_are_available_on * The slave EDM core shall only acknowledge after receiving the initial handshake packet * from the master EDM core. */ -FORCE_INLINE void receiver_side_finish(std::uint32_t handshake_register_address) { - eth_wait_for_bytes(16, A_LONG_TIMEOUT_BEFORE_CONTEXT_SWITCH); +FORCE_INLINE void receiver_side_finish( + std::uint32_t handshake_register_address, size_t HS_CONTEXT_SWITCH_TIMEOUT = A_LONG_TIMEOUT_BEFORE_CONTEXT_SWITCH) { + eth_wait_for_bytes(16, HS_CONTEXT_SWITCH_TIMEOUT); while (eth_txq_reg_read(0, ETH_TXQ_CMD) != 0) { asm volatile("nop"); } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp index ef0f73d302b..a28263af09c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp @@ -10,55 +10,94 @@ #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp" #include "debug/assert.h" +#include "debug/dprint.h" #include - namespace tt::fabric { -struct WorkerToFabricEdmSender{ - +struct WorkerToFabricEdmSender { static constexpr uint32_t open_connection_value = 1; static constexpr uint32_t close_connection_value = 0; - WorkerToFabricEdmSender () : worker_sem_addr(nullptr) {} + WorkerToFabricEdmSender() : worker_sem_addr(nullptr) {} + + template + static WorkerToFabricEdmSender build_from_args(std::size_t& arg_idx) { + bool is_persistent_fabric = get_arg_val(arg_idx++); + WorkerXY const edm_worker_xy = WorkerXY::from_uint32(get_arg_val(arg_idx++)); + auto const edm_buffer_base_addr = get_arg_val(arg_idx++); + uint8_t const num_buffers_per_channel = get_arg_val(arg_idx++); + size_t const edm_l1_sem_id = get_arg_val(arg_idx++); + auto const edm_connection_handshake_l1_addr = get_arg_val(arg_idx++); + auto const edm_worker_location_info_addr = get_arg_val(arg_idx++); + uint16_t const buffer_size_bytes = get_arg_val(arg_idx++); + auto const edm_buffer_index_addr = get_arg_val(arg_idx++); + auto writer_send_sem_addr = + reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + auto const worker_buffer_index_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); + ASSERT( + (my_core_type == ProgrammableCoreType::TENSIX && worker_buffer_index_semaphore_addr < 1499136) || + (my_core_type == ProgrammableCoreType::ACTIVE_ETH && worker_buffer_index_semaphore_addr < 262144)); + ASSERT( + (my_core_type == ProgrammableCoreType::TENSIX && (uint32_t)writer_send_sem_addr < 1499136) || + (my_core_type == ProgrammableCoreType::ACTIVE_ETH && (uint32_t)writer_send_sem_addr < 262144)); + ASSERT(edm_buffer_index_addr < 262144); + return WorkerToFabricEdmSender( + is_persistent_fabric, + edm_worker_xy.x, + edm_worker_xy.y, + edm_buffer_base_addr, + num_buffers_per_channel, + edm_l1_sem_id, + edm_connection_handshake_l1_addr, + edm_worker_location_info_addr, // The EDM's location for `EDMChannelWorkerLocationInfo` + buffer_size_bytes, + edm_buffer_index_addr, + writer_send_sem_addr, + worker_buffer_index_semaphore_addr); + } - WorkerToFabricEdmSender ( - size_t edm_worker_x, - size_t edm_worker_y, + WorkerToFabricEdmSender( + bool connected_to_persistent_fabric, + uint8_t edm_worker_x, + uint8_t edm_worker_y, std::size_t edm_buffer_base_addr, - std::size_t num_buffers_per_channel, - std::size_t edm_l1_sem_id, - std::size_t edm_connection_handshake_l1_addr, - std::size_t edm_worker_location_info_addr, // The EDM's location for `EDMChannelWorkerLocationInfo` - std::size_t buffer_size_bytes, - std::size_t edm_buffer_index_addr, - volatile uint32_t * const worker_sem_addr, - uint32_t local_buffer_index_addr - ) : - edm_buffer_addr(get_noc_addr(edm_worker_x, edm_worker_y, edm_buffer_base_addr)), - edm_semaphore_addr(get_noc_addr(edm_worker_x, edm_worker_y, get_semaphore(edm_l1_sem_id))), - edm_connection_handshake_l1_addr(edm_connection_handshake_l1_addr), + uint8_t num_buffers_per_channel, + size_t edm_l1_sem_id, // may also be an address + std::size_t edm_connection_handshake_l1_id, + std::size_t edm_worker_location_info_addr, // The EDM's location for `EDMChannelWorkerLocationInfo` + uint16_t buffer_size_bytes, + size_t edm_buffer_index_id, + volatile uint32_t* const worker_sem_addr, + uint32_t local_buffer_index_addr) : + edm_buffer_addr(edm_buffer_base_addr), + edm_semaphore_addr( + connected_to_persistent_fabric ? edm_l1_sem_id + : get_semaphore(edm_l1_sem_id)), + edm_connection_handshake_l1_addr( + connected_to_persistent_fabric + ? edm_connection_handshake_l1_id + : get_semaphore(edm_connection_handshake_l1_id)), edm_worker_location_info_addr(edm_worker_location_info_addr), - edm_buffer_index_addr(edm_buffer_index_addr), + edm_buffer_index_addr( + connected_to_persistent_fabric ? edm_buffer_index_id + : get_semaphore(edm_buffer_index_id)), worker_sem_addr(worker_sem_addr), edm_buffer_base_addr(edm_buffer_base_addr), + buffer_index_ptr(reinterpret_cast(local_buffer_index_addr)), + buffer_size_bytes(buffer_size_bytes), num_buffers_per_channel(num_buffers_per_channel), last_buffer_index(num_buffers_per_channel - 1), - edm_l1_sem_addr(get_semaphore(edm_l1_sem_id)), - buffer_size_bytes(buffer_size_bytes), - buffer_index_ptr(reinterpret_cast(local_buffer_index_addr)) - { + edm_noc_x(edm_worker_x), + edm_noc_y(edm_worker_y) { ASSERT(buffer_size_bytes > 0); } - [[nodiscard]] FORCE_INLINE bool consumer_has_space() const { - return *this->worker_sem_addr == 1; - } - FORCE_INLINE void clear_flow_control_semaphore() const { - noc_semaphore_set(this->worker_sem_addr, 0); - } + [[nodiscard]] FORCE_INLINE bool consumer_has_space() const { return *this->worker_sem_addr == 1; } + FORCE_INLINE void clear_flow_control_semaphore() const { noc_semaphore_set(this->worker_sem_addr, 0); } FORCE_INLINE void wait_for_empty_write_slot() const { + DPRINT << "Wait for write slot @ " << (uint32_t)this->worker_sem_addr << "\n"; noc_semaphore_wait(this->worker_sem_addr, 1); } @@ -74,6 +113,15 @@ struct WorkerToFabricEdmSender{ /* * No CB */ + FORCE_INLINE void send_packet_header_and_notify_fabric_flush_blocking(uint32_t source_address) { + send_packet_header_and_notify_fabric(source_address); + } + FORCE_INLINE void send_payload_without_header_non_blocking_from_address(uint32_t source_address, size_t size_bytes) { + send_payload_without_header_from_address_impl(source_address, size_bytes); + } + FORCE_INLINE void send_payload_flush_blocking_from_address(uint32_t source_address, size_t size_bytes) { + send_payload_from_address_impl(source_address, size_bytes); + } FORCE_INLINE void send_payload_blocking_from_address(uint32_t source_address, size_t size_bytes) { send_payload_from_address_impl(source_address, size_bytes); } @@ -86,24 +134,11 @@ struct WorkerToFabricEdmSender{ send_payload_from_address_impl(source_address, size_bytes); } - // Layout - // |-----------------------| - // | EDM Handshake | 16B - // |-----------------------| - // | EDM Ack Channel Sync | 16B - // |-----------------------| - - // | Connection Semaphore | 16B | - // |-----------------------| | - // | Buffer Index | 16B >- Per Sender Channel (On EDM) - // |-----------------------| | - // | Worker Connection Info| 16B |worker - // |-----------------------| -/ - // |-----------------------| - // static constexpr size_t edm_sender_channel_field_stride_bytes = 16; FORCE_INLINE void open() { - const auto dest_noc_addr_coord_only = this->edm_semaphore_addr & ~(uint64_t)NOC_COORDINATE_MASK; + const auto dest_noc_addr_coord_only = + get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr) & ~(uint64_t)NOC_COORDINATE_MASK; const uint64_t remote_buffer_index_addr = dest_noc_addr_coord_only | edm_buffer_index_addr; ASSERT(remote_buffer_index_addr > 0); @@ -112,15 +147,18 @@ struct WorkerToFabricEdmSender{ const uint64_t dest_edm_location_info_addr = dest_noc_addr_coord_only | edm_worker_location_info_addr; // TODO: Need to change byte enable to be word enable noc_inline_dw_write(dest_edm_location_info_addr, reinterpret_cast(worker_sem_addr)); - noc_inline_dw_write(dest_edm_location_info_addr + sizeof(uint32_t), ttnn::ccl::WorkerXY(my_x[0], my_y[0]).to_uint32()); + noc_inline_dw_write( + dest_edm_location_info_addr + sizeof(uint32_t), ttnn::ccl::WorkerXY(my_x[0], my_y[0]).to_uint32()); const uint64_t edm_connection_handshake_noc_addr = dest_noc_addr_coord_only | edm_connection_handshake_l1_addr; noc_inline_dw_write(edm_connection_handshake_noc_addr, open_connection_value); noc_async_read_barrier(); + ASSERT(*this->buffer_index_ptr < 20); } FORCE_INLINE void close() { - const auto dest_noc_addr_coord_only = this->edm_semaphore_addr & ~(uint64_t)NOC_COORDINATE_MASK; + const auto dest_noc_addr_coord_only = + get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr) & ~(uint64_t)NOC_COORDINATE_MASK; const uint64_t dest_edm_connection_state_addr = dest_noc_addr_coord_only | edm_connection_handshake_l1_addr; noc_inline_dw_write(dest_edm_connection_state_addr, close_connection_value); @@ -129,65 +167,80 @@ struct WorkerToFabricEdmSender{ const uint64_t remote_buffer_index_addr = dest_noc_addr_coord_only | edm_buffer_index_addr; noc_inline_dw_write(remote_buffer_index_addr, *this->buffer_index_ptr); + // Need to wait for the ack from edm + wait_for_empty_write_slot(); + noc_async_write_barrier(); } - uint64_t edm_buffer_addr; - uint64_t edm_semaphore_addr; + uint32_t edm_buffer_addr; + uint32_t edm_semaphore_addr; size_t edm_connection_handshake_l1_addr; size_t edm_worker_location_info_addr; size_t edm_buffer_index_addr; - volatile uint32_t * const worker_sem_addr; - std::size_t edm_buffer_base_addr; - std::size_t num_buffers_per_channel; - std::size_t last_buffer_index; - std::size_t edm_l1_sem_addr; - std::size_t buffer_size_bytes; - std::size_t *buffer_index_ptr; - - private: - template + volatile uint32_t* worker_sem_addr; + size_t edm_buffer_base_addr; + size_t* buffer_index_ptr; + uint16_t buffer_size_bytes; + uint8_t num_buffers_per_channel; + uint8_t last_buffer_index; + uint8_t edm_noc_x; + uint8_t edm_noc_y; + +private: + + template + FORCE_INLINE void send_packet_header_and_notify_fabric(uint32_t source_address) { + this->clear_flow_control_semaphore(); + uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + + send_chunk_from_address(source_address, 1, sizeof(tt::fabric::PacketHeader), buffer_address); + auto const noc_sem_addr = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr); + noc_semaphore_inc(noc_sem_addr, 1); + *this->buffer_index_ptr = + (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + } + template + FORCE_INLINE void send_payload_without_header_from_address_impl(uint32_t source_address, size_t size_bytes) { + uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + + // skip past the first part of the buffer which will be occupied by the packet header + send_chunk_from_address(source_address, 1, size_bytes, buffer_address + sizeof(tt::fabric::PacketHeader)); + } + + template FORCE_INLINE void send_payload_from_address_impl(uint32_t source_address, size_t size_bytes) { this->clear_flow_control_semaphore(); - uint64_t buffer_address = this->edm_buffer_addr + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); ASSERT(size_bytes <= this->buffer_size_bytes); - /*{ // For debug purposes only. Useful to permanently backup the packet somewhere we can inspect with ttx-status - uint32_t dram_noc_x = my_y[0] == 1 ? 0 : 0; - uint32_t dram_noc_y = my_y[0] == 1 ? 0 : 5; - // noc_inline_dw_write(get_noc_addr(dram_noc_x, dram_noc_y, storage_offset), 0x0F); - // noc_async_writes_flushed(); - // noc_inline_dw_write(get_noc_addr(dram_noc_x, dram_noc_y, storage_offset + 4), 0); - // auto pkthdr_size_words = sizeof(tt::fabric::PacketHeader) >> 2; - // for (size_t i = 0; i < pkthdr_size_words; i++) { - // reinterpret_cast(source_address)[pkthdr_size_words - i] = - // reinterpret_cast(source_address)[pkthdr_size_words - 1 - i]; - // } - // reinterpret_cast(source_address)[0] = 0xc0ffee; - // DPRINT << "NEXT STORAGE OFF: " << (uint32_t)storage_offset << "\n"; - noc_async_write(source_address, get_noc_addr(dram_noc_x, dram_noc_y, storage_offset), size_bytes); - storage_offset += size_bytes; - storage_offset += 64; - storage_offset = storage_offset & (~0x1F); - }*/ - ASSERT(tt::fabric::is_valid(*const_cast(reinterpret_cast(source_address)))); + DPRINT << "SND PKT TO @ " << (uint64_t)buffer_address << "\n"; + ASSERT(tt::fabric::is_valid(*const_cast( + reinterpret_cast(source_address)))); send_chunk_from_address(source_address, 1, size_bytes, buffer_address); - noc_semaphore_inc(edm_semaphore_addr, 1); + auto const noc_sem_addr = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr); + DPRINT << "\tSEMINC TO @ " << (uint64_t)noc_sem_addr << "\n"; + noc_semaphore_inc(noc_sem_addr, 1); - *this->buffer_index_ptr = (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + *this->buffer_index_ptr = + (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; } - template + template FORCE_INLINE void send_payload_impl(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { this->clear_flow_control_semaphore(); - uint64_t buffer_address = this->edm_buffer_addr + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); ASSERT(num_pages * page_size <= this->buffer_size_bytes); send_chunk(cb_id, num_pages, page_size, buffer_address); - noc_semaphore_inc(edm_semaphore_addr, 1); - *this->buffer_index_ptr = (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + noc_semaphore_inc(get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_semaphore_addr), 1); + *this->buffer_index_ptr = + (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; } }; - -} // namespace tt::fabric +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp index 37210c2d012..8c4d073aeb8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp @@ -29,7 +29,7 @@ enum CommandType : uint8_t { // 1 bit enum ChipSendType : uint8_t { CHIP_UNICAST = 0, - CHIP_MULTICAST = 1 + CHIP_MULTICAST = 1, }; enum NocSendType : uint8_t { NOC_UNICAST = 0, @@ -53,6 +53,7 @@ union RoutingFields { static_assert(sizeof(RoutingFields) == sizeof(UnicastRoutingCommandHeader), "RoutingFields size is not 1 bytes"); struct NocUnicastCommandHeader { + // TODO: just encode the noc_addr as uint64_t directly uint32_t address; uint32_t size; uint8_t noc_x; @@ -120,7 +121,7 @@ struct PacketHeader { uint8_t reserved : 4; RoutingFields routing_fields; - uint16_t reserved2; + uint16_t reserved2; // can be tagged with src device for debug CommandFields command_fields; // Sort of hack to work-around DRAM read alignment issues that must be 32B aligned @@ -202,6 +203,64 @@ struct PacketHeader { this->command_fields.mcast_seminc = noc_multicast_atomic_inc_command_header; return *this; } + inline volatile PacketHeader* to_write() volatile { this->command_type = WRITE; return this; } + inline volatile PacketHeader* to_atomic_inc() volatile { this->command_type = ATOMIC_INC; return this; } + + inline volatile PacketHeader *to_chip_unicast(UnicastRoutingCommandHeader const &chip_unicast_command_header) volatile { + this->chip_send_type = CHIP_UNICAST; + this->routing_fields.chip_unicast.distance_in_hops = chip_unicast_command_header.distance_in_hops; + return this; + } + inline volatile PacketHeader *to_chip_multicast(MulticastRoutingCommandHeader const &chip_multicast_command_header) volatile { + this->chip_send_type = CHIP_MULTICAST; + this->routing_fields.chip_mcast.range_hops = chip_multicast_command_header.range_hops; + this->routing_fields.chip_mcast.start_distance_in_hops = chip_multicast_command_header.start_distance_in_hops; + return this; + } + inline volatile PacketHeader *to_noc_unicast(NocUnicastCommandHeader const &noc_unicast_command_header) volatile { + this->noc_send_type = NOC_UNICAST; + this->command_fields.unicast_write.address = noc_unicast_command_header.address; + this->command_fields.unicast_write.size = noc_unicast_command_header.size; + this->command_fields.unicast_write.noc_x = noc_unicast_command_header.noc_x; + this->command_fields.unicast_write.noc_y = noc_unicast_command_header.noc_y; + + return this; + } + inline volatile PacketHeader *to_noc_multicast(NocMulticastCommandHeader const &noc_multicast_command_header) volatile { + this->noc_send_type = NOC_MULTICAST; + this->command_fields.mcast_write.mcast_rect_size_x = noc_multicast_command_header.mcast_rect_size_x; + this->command_fields.mcast_write.mcast_rect_size_y = noc_multicast_command_header.mcast_rect_size_y; + this->command_fields.mcast_write.noc_x_start = noc_multicast_command_header.noc_x_start; + this->command_fields.mcast_write.noc_y_start = noc_multicast_command_header.noc_y_start; + this->command_fields.mcast_write.size = noc_multicast_command_header.size; + this->command_fields.mcast_write.address = noc_multicast_command_header.address; + + return this; + } + inline volatile PacketHeader *to_noc_unicast_atomic_inc( + NocUnicastAtomicIncCommandHeader const &noc_unicast_atomic_inc_command_header) volatile { + this->noc_send_type = NOC_UNICAST; + this->command_fields.unicast_seminc.address = noc_unicast_atomic_inc_command_header.address; + this->command_fields.unicast_seminc.noc_x = noc_unicast_atomic_inc_command_header.noc_x; + this->command_fields.unicast_seminc.noc_y = noc_unicast_atomic_inc_command_header.noc_y; + this->command_fields.unicast_seminc.val = noc_unicast_atomic_inc_command_header.val; + this->command_fields.unicast_seminc.wrap = noc_unicast_atomic_inc_command_header.wrap; + + return this; + } + inline volatile PacketHeader *to_noc_multicast_atomic_inc( + NocMulticastAtomicIncCommandHeader const &noc_multicast_atomic_inc_command_header) volatile { + this->noc_send_type = NOC_MULTICAST; + this->command_fields.mcast_seminc.address = noc_multicast_atomic_inc_command_header.address; + this->command_fields.mcast_seminc.noc_x_start = noc_multicast_atomic_inc_command_header.noc_x_start; + this->command_fields.mcast_seminc.noc_y_start = noc_multicast_atomic_inc_command_header.noc_y_start; + this->command_fields.mcast_seminc.size_x = noc_multicast_atomic_inc_command_header.size_x; + this->command_fields.mcast_seminc.size_y = noc_multicast_atomic_inc_command_header.size_y; + this->command_fields.mcast_seminc.val = noc_multicast_atomic_inc_command_header.val; + this->command_fields.mcast_seminc.wrap = noc_multicast_atomic_inc_command_header.wrap; + + return this; + } }; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp index 1e25898f003..41f3608559c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp @@ -1,4 +1,3 @@ - // SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 @@ -11,9 +10,18 @@ #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" #include +// If the hop/distance counter equals to the below value, it indicates that it has +// arrived at (atleast one of) the intended destination(s) +static constexpr size_t DESTINATION_HOP_COUNT = 1; +// TODO: make 0 and the associated field to num mcast destinations +static constexpr size_t LAST_MCAST_DESTINATION = 1; + + void write_unicast_blocking(uint32_t local_address, uint64_t dest_address, uint32_t size_bytes) { + // TODO - PERF: noc_async_write + // Don't do it yet because we want to sweep perf on buffer size noc_async_write(local_address, dest_address, size_bytes); - noc_async_writes_flushed(); + noc_async_write_barrier(); } void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packet_start) { @@ -60,7 +68,8 @@ void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) { auto const& header = *packet_start; DPRINT << "PKT: cmd_t:" << (uint32_t) packet_start->command_type << ", csnd_t:" << (uint32_t) packet_start->chip_send_type << - ", nsnd_t:" << (uint32_t) packet_start->noc_send_type << "\n"; + ", nsnd_t:" << (uint32_t) packet_start->noc_send_type << + ", src_chip:" << (uint32_t) packet_start->reserved2 << "\n"; print_pkt_hdr_routing_fields(packet_start); print_pkt_header_noc_fields(packet_start); } @@ -77,6 +86,8 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const case tt::fabric::CommandType::WRITE: { switch (noc_send_type) { case tt::fabric::NocSendType::NOC_UNICAST: { + DPRINT << "C_UNI to y|x" << (uint32_t)((header.command_fields.unicast_write.noc_y << 16) | header.command_fields.unicast_write.noc_x) << + ", " << (uint32_t)header.command_fields.unicast_write.address << "\n"; auto const dest_address = get_noc_addr( header.command_fields.unicast_write.noc_x, header.command_fields.unicast_write.noc_y, @@ -96,7 +107,7 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const auto const num_dests = header.command_fields.mcast_write.mcast_rect_size_x * header.command_fields.mcast_write.mcast_rect_size_y; auto const size = header.command_fields.mcast_write.size - sizeof(tt::fabric::PacketHeader); noc_async_write_multicast_one_packet(payload_start_address, mcast_dest_address, size, num_dests); - noc_async_writes_flushed(); + noc_async_write_barrier(); }break; default: { @@ -106,6 +117,7 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const break; } case tt::fabric::CommandType::ATOMIC_INC: { + DPRINT << "C_AT_INC\n"; switch (noc_send_type) { case tt::fabric::NocSendType::NOC_UNICAST: { auto const dest_address = get_noc_addr( @@ -113,6 +125,10 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const header.command_fields.unicast_seminc.noc_y, header.command_fields.unicast_seminc.address); auto const increment = header.command_fields.unicast_seminc.val; + DPRINT << "\tx=" << (uint32_t)header.command_fields.unicast_seminc.noc_x << + ", y=" << (uint32_t)header.command_fields.unicast_seminc.noc_y << + ", addr=" << (uint32_t)header.command_fields.unicast_seminc.address << + ", inc=" << (uint32_t)increment << "\n"; noc_semaphore_inc(dest_address, increment); }break; @@ -140,12 +156,15 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const void update_packet_header_for_next_hop(volatile tt::fabric::PacketHeader * packet_header) { switch (packet_header->chip_send_type) { case tt::fabric::CHIP_UNICAST: { + ASSERT(packet_header->routing_fields.chip_unicast.distance_in_hops > 0); packet_header->routing_fields.chip_unicast.distance_in_hops--; } break; case tt::fabric::CHIP_MULTICAST: { - if (packet_header->routing_fields.chip_mcast.start_distance_in_hops == 0) { + if (packet_header->routing_fields.chip_mcast.start_distance_in_hops == DESTINATION_HOP_COUNT) { + ASSERT(packet_header->routing_fields.chip_mcast.range_hops > 0); packet_header->routing_fields.chip_mcast.range_hops--; } else { + ASSERT(packet_header->routing_fields.chip_mcast.start_distance_in_hops > 0); packet_header->routing_fields.chip_mcast.start_distance_in_hops--; } } break; @@ -164,14 +183,15 @@ tt::fabric::SendStatus forward_payload_to_downstream_edm( volatile tt::fabric::PacketHeader *packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface ) { - // SHOULD BE ABLE TO ASSERT ON THIS SINCE WE CHECK FOR THIS IN THE CALLER - // TODO: PERF + DPRINT << "Fwding pkt to downstream\n"; + // TODO: PERF - this should already be getting checked by the caller so this should be redundant make it an ASSERT bool safe_to_send = downstream_edm_interface.consumer_has_space(); if (!safe_to_send) { return tt::fabric::SendStatus::NOT_SENT; } - // print_pkt_header(packet_header); + // This is a good place to print the packet header for debug if you are trying to inspect packets + // because it is before we start manipulating the header for forwarding update_packet_header_for_next_hop(packet_header); downstream_edm_interface.send_payload_blocking_from_address( @@ -181,22 +201,14 @@ tt::fabric::SendStatus forward_payload_to_downstream_edm( return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC; } -void execute_chip_multicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start) { - ASSERT(false); -} -bool packet_must_be_consumed_locally(tt::fabric::PacketHeader const& packet_header) { +bool packet_must_be_consumed_locally(volatile tt::fabric::PacketHeader const& packet_header) { switch (packet_header.chip_send_type) { case tt::fabric::ChipSendType::CHIP_UNICAST: { - // TODO: does it make more sense to have 0 as the terminating distance or 1? - // depends where we want to do the decrement and what the starting value - // is expected to be for worker - // Maybe at API level we just always decrement by 1 under the hood - // so user can call `fabric_send_packet(payload_addr, size, n_hops=1) - return packet_header.routing_fields.chip_unicast.distance_in_hops == 0; + return packet_header.routing_fields.chip_unicast.distance_in_hops == DESTINATION_HOP_COUNT; } case tt::fabric::ChipSendType::CHIP_MULTICAST: { - return packet_header.routing_fields.chip_mcast.start_distance_in_hops == 0; + return packet_header.routing_fields.chip_mcast.start_distance_in_hops == DESTINATION_HOP_COUNT; } default: { ASSERT(false); @@ -206,18 +218,13 @@ bool packet_must_be_consumed_locally(tt::fabric::PacketHeader const& packet_head } -bool packet_must_be_forwarded_to_next_chip(tt::fabric::PacketHeader const& packet_header) { +bool packet_must_be_forwarded_to_next_chip(volatile tt::fabric::PacketHeader const& packet_header) { switch (packet_header.chip_send_type) { case tt::fabric::ChipSendType::CHIP_UNICAST: - // TODO: does it make more sense to have 0 as the terminating distance or 1? - // depends where we want to do the decrement and what the starting value - // is expected to be for worker - // Maybe at API level we just always decrement by 1 under the hood - // so user can call `fabric_send_packet(payload_addr, size, n_hops=1) - return packet_header.routing_fields.chip_unicast.distance_in_hops != 0; + return packet_header.routing_fields.chip_unicast.distance_in_hops != DESTINATION_HOP_COUNT; case tt::fabric::ChipSendType::CHIP_MULTICAST: - return packet_header.routing_fields.chip_mcast.range_hops != 0; + return packet_header.routing_fields.chip_mcast.range_hops != LAST_MCAST_DESTINATION; default: ASSERT(false); diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp index 244b327a7ec..08abace276e 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp @@ -258,7 +258,7 @@ enum PacketLocalForwardType : uint8_t { PACKET_FORWARD_LOCAL_AND_REMOTE = 0x3 }; -static constexpr uint32_t SWITCH_INTERVAL = 4000; +static constexpr uint32_t SWITCH_INTERVAL = 0; static constexpr size_t ETH_BYTES_TO_WORDS_SHIFT = 4; static constexpr size_t NUM_SENDER_CHANNELS = 2; static constexpr size_t num_workers_ctor = 1; @@ -326,7 +326,7 @@ tt::fabric::SendStatus send_next_data( payload_size >> ETH_BYTES_TO_WORDS_SHIFT); bool sent_payload_and_channel_sync_in_one_shot = - payload_size == sender_buffer_channel.get_channel_buffer_max_size_in_bytes(); + payload_size == sender_buffer_channel.get_current_max_eth_payload_size(); if (!sent_payload_and_channel_sync_in_one_shot) { // We weren't able to send the channel_sync_t in one shot with the payload so we need to send a second // packet @@ -410,6 +410,8 @@ void receiver_send_received_ack( reinterpret_cast(local_receiver_buffer_channel.get_current_bytes_sent_address()) ->receiver_ack == 0); + DPRINT << "EDMR rsa to " << (uint32_t)sender_buffer_channel.get_current_bytes_sent_address() << "\n"; + ASSERT(!eth_txq_is_busy()); internal_::eth_send_packet_unsafe( 0, @@ -431,6 +433,8 @@ FORCE_INLINE void receiver_send_completion_ack( *(local_receiver_buffer_channel.get_current_bytes_acked_address()) = 0; ASSERT(src_sender_channel < NUM_SENDER_CHANNELS); + DPRINT << "EDMR rsc to " << (uint32_t)remote_sender_channels[src_sender_channel].get_current_bytes_sent_address() << "\n"; + ASSERT(!eth_txq_is_busy()); internal_::eth_send_packet_unsafe( 0, @@ -443,7 +447,7 @@ FORCE_INLINE void receiver_send_completion_ack( } -PacketLocalForwardType get_packet_local_forward_type(const tt::fabric::PacketHeader &packet_header) { +PacketLocalForwardType get_packet_local_forward_type(const volatile tt::fabric::PacketHeader &packet_header) { const bool local_chip_is_packet_destination = packet_must_be_consumed_locally(packet_header); const bool packet_needs_forwarding = packet_must_be_forwarded_to_next_chip(packet_header); PacketLocalForwardType forward_type = @@ -452,9 +456,9 @@ PacketLocalForwardType get_packet_local_forward_type(const tt::fabric::PacketHea } FORCE_INLINE bool can_forward_packet_completely( - const tt::fabric::PacketHeader &packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) { + const volatile tt::fabric::PacketHeader &packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) { auto forward_status = get_packet_local_forward_type(packet_header); - bool can_send = true; + switch (forward_status) { case PACKET_FORWARD_INVALID: return false; case PACKET_FORWARD_LOCAL_ONLY: return true; @@ -505,13 +509,19 @@ bool run_sender_channel_state_machine_step( tt::fabric::EthChannelBuffer &local_sender_channel, tt::fabric::EdmChannelWorkerInterface &local_sender_channel_worker_interface, tt::fabric::EthChannelBuffer &remote_receiver_channel, - SenderState *const sender_state_out) { + bool graceful_termination_mode, + SenderState *const sender_state_out, + uint8_t sender_channel_index) { bool incr_sender_channel_index = true; switch (*sender_state_out) { case SenderState::SENDER_WAITING_FOR_WORKER: { bool able_to_send = local_sender_channel_worker_interface.has_payload() && !eth_txq_is_busy() && local_sender_channel.eth_is_receiver_channel_send_done(); if (able_to_send) { + DPRINT << "EDMS " << (uint32_t)sender_channel_index << "\n"; + DPRINT << "\taddress: " << (uint32_t)local_sender_channel.get_current_buffer_address() << "\n"; + DPRINT << "\t1st 8B: " << (uint64_t)*reinterpret_cast(local_sender_channel.get_current_buffer_address()) << "\n"; + DPRINT << "\tsend to " << (uint32_t)remote_receiver_channel.get_current_buffer_address() << "\n"; auto send_status = send_next_data(local_sender_channel, remote_receiver_channel); // TODO: align the enums and state values so I can just do // sender_states[sender_channel_index] += send_status :) @@ -525,8 +535,8 @@ bool run_sender_channel_state_machine_step( // and not be able to send the channel sync for the packet we just sent, which overall negatively // impact latency incr_sender_channel_index = send_status != tt::fabric::SendStatus::SENT_PAYLOAD_ONLY; - } else { - if (local_sender_channel_worker_interface.has_worker_teardown_request()) { + } else if (!graceful_termination_mode) { + if (!local_sender_channel_worker_interface.has_payload() && local_sender_channel_worker_interface.has_worker_teardown_request()) { local_sender_channel_worker_interface.teardown_connection(); *sender_state_out = SenderState::SENDER_WAIT_WORKER_HANDSHAKE; } @@ -538,6 +548,9 @@ bool run_sender_channel_state_machine_step( bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() || local_sender_channel.eth_is_receiver_channel_send_done(); if (is_safe_to_receive_next_message) { + DPRINT << "EDM ch " << (uint32_t)sender_channel_index << " wkr con ntfy wrkr\n"; + DPRINT << "\tl1 worker info ptr: " << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr << "\n"; + DPRINT << "\tworker.x=" << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr->worker_xy.x << ", .y=" << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr->worker_xy.y << ", sem_addr=" << (uint32_t)local_sender_channel_worker_interface.worker_location_info_ptr->worker_semaphore_address << "\n"; sender_notify_workers_if_buffer_available_sequence(local_sender_channel_worker_interface); *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER; } else { @@ -549,6 +562,7 @@ bool run_sender_channel_state_machine_step( case SenderState::SENDER_SEND_CHANNEL_SYNC: { bool can_send_channel_sync_without_blocking = !eth_txq_is_busy(); if (can_send_channel_sync_without_blocking) { + DPRINT << "EDMS send channel sync\n"; send_channel_sync(local_sender_channel, remote_receiver_channel); local_sender_channel.advance_buffer_index(); remote_receiver_channel.advance_buffer_index(); @@ -561,6 +575,7 @@ bool run_sender_channel_state_machine_step( local_sender_channel.eth_is_receiver_channel_send_done(); if (is_safe_to_receive_next_message) { // This also notifies workers in the same call + DPRINT << "EDMS:\n"; sender_eth_check_receiver_ack_sequence(local_sender_channel, local_sender_channel_worker_interface); *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER; } @@ -584,6 +599,9 @@ void run_receiver_channel_state_machine_step( if (got_payload) { bool can_ack = !eth_txq_is_busy(); if (can_ack) { + DPRINT << "EDMR got pkt @: " << (uint32_t)reinterpret_cast(local_receiver_channel.get_current_packet_header()) << "\n"; + DPRINT << "EDMR got pkt 0 : " << (uint64_t) reinterpret_cast(local_receiver_channel.get_current_packet_header())[0] << "\n"; + DPRINT << "EDMR got pkt 1: " << (uint64_t) reinterpret_cast(local_receiver_channel.get_current_packet_header())[1] << "\n"; ASSERT(tt::fabric::is_valid( *const_cast(local_receiver_channel.get_current_packet_header()))); receiver_send_received_ack(remote_sender_channnels, local_receiver_channel); @@ -601,11 +619,11 @@ void run_receiver_channel_state_machine_step( } break; case ReceiverState::RECEIVER_SENDING_PAYLOAD: { - auto packet_header = - *const_cast(local_receiver_channel.get_current_packet_header()); + auto& packet_header = *local_receiver_channel.get_current_packet_header(); bool can_send_to_all_local_chip_receivers = can_forward_packet_completely(packet_header, downstream_edm_interface); if (can_send_to_all_local_chip_receivers) { + DPRINT << "EDMR writing pkt\n"; receiver_forward_packet(local_receiver_channel.get_current_packet_header(), downstream_edm_interface); *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH; } @@ -641,11 +659,15 @@ FORCE_INLINE bool got_termination_signal(volatile tt::fabric::TerminationSignal template bool all_channels_drained(tt::fabric::EthChannelBuffer &local_receiver_channel, - std::array, NUM_SENDER_CHANNELS> &local_sender_channels) { - // Unfortunately have to do this for now instead of only conditionally checking - // each undrained channel due to code size issues... - return local_sender_channels[0].all_buffers_drained() && local_sender_channels[1].all_buffers_drained() && - local_receiver_channel.all_buffers_drained(); + std::array, NUM_SENDER_CHANNELS> &local_sender_channels, + std::array &local_sender_channel_worker_interfaces) { + + bool eth_buffers_drained = local_sender_channels[0].all_buffers_drained() && local_sender_channels[1].all_buffers_drained() && local_receiver_channel.all_buffers_drained(); + + bool sender0_has_unsent_packets = (local_sender_channel_worker_interfaces[0].has_payload()); + bool sender1_has_unsent_packets = (local_sender_channel_worker_interfaces[1].has_payload()); + + return eth_buffers_drained && !sender0_has_unsent_packets && !sender1_has_unsent_packets; } /* @@ -663,7 +685,6 @@ void run_fabric_edm_main_loop( std::array, NUM_SENDER_CHANNELS> &remote_sender_channels, tt::fabric::EthChannelBuffer &remote_receiver_channel, volatile tt::fabric::TerminationSignal *termination_signal_ptr) { - std::array sender_states = { SenderState::SENDER_WAIT_WORKER_HANDSHAKE, SenderState::SENDER_WAIT_WORKER_HANDSHAKE}; ReceiverState receiver_state = ReceiverState::RECEIVER_WAITING_FOR_ETH; @@ -672,16 +693,22 @@ void run_fabric_edm_main_loop( *termination_signal_ptr = tt::fabric::TerminationSignal::KEEP_RUNNING; while (!got_immediate_termination_signal(termination_signal_ptr)) { - if (got_graceful_termination_signal(termination_signal_ptr)) { + bool got_graceful_termination = got_graceful_termination_signal(termination_signal_ptr); + if (got_graceful_termination) { + DPRINT << "EDM Graceful termination\n"; + DPRINT << "EDMS0 ST: " << (uint32_t)sender_states[0] << "\n"; bool all_drained = all_channels_drained( - local_receiver_channel, local_sender_channels); + local_receiver_channel, local_sender_channels, local_sender_channel_worker_interfaces); if (all_drained) { return; } } - // // TODO + // Capture these to see if we made progress + auto old_send_state = sender_states[sender_channel_index]; + auto old_recv_state = receiver_state; + auto &local_sender_channel = local_sender_channels[sender_channel_index]; auto &local_sender_channel_worker_interface = local_sender_channel_worker_interfaces[sender_channel_index]; // There are some cases, mainly for performance, where we don't want to switch between sender channels @@ -690,7 +717,10 @@ void run_fabric_edm_main_loop( local_sender_channel, local_sender_channel_worker_interface, remote_receiver_channel, - &(sender_states[sender_channel_index])); + got_graceful_termination, + &(sender_states[sender_channel_index]), + sender_channel_index); + bool did_something_sender = old_send_state != sender_states[sender_channel_index]; if (incr_sender_channel_index) { // TODO: this can probably be optimized sender_channel_index = 1 - sender_channel_index; @@ -699,11 +729,18 @@ void run_fabric_edm_main_loop( run_receiver_channel_state_machine_step( local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, &receiver_state); - if (did_nothing_count++ > SWITCH_INTERVAL) { + bool did_something = did_something_sender || old_recv_state != receiver_state; + + if (did_something) { did_nothing_count = 0; - run_routing(); + } else { + if (did_nothing_count++ > SWITCH_INTERVAL) { + did_nothing_count = 0; + run_routing(); + } } } + DPRINT << "EDM Terminating\n"; } void kernel_main() { @@ -715,9 +752,12 @@ void kernel_main() { *reinterpret_cast(handshake_addr) = 0; auto eth_transaction_ack_word_addr = handshake_addr + sizeof(eth_channel_sync_t); + static constexpr size_t DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT = 0; if constexpr (is_handshake_sender) { - erisc::datamover::handshake::sender_side_start(handshake_addr); + // DPRINT << "EDM Starting handshake as sender\n"; + erisc::datamover::handshake::sender_side_start(handshake_addr, DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT); } else { + // DPRINT << "EDM Starting handshake as receiver\n"; erisc::datamover::handshake::receiver_side_start(handshake_addr); } @@ -740,6 +780,10 @@ void kernel_main() { // TODO: CONVERT TO SEMAPHORE volatile auto termination_signal_ptr = reinterpret_cast(get_compile_time_arg_val(13)); + // In persistent mode, we must rely on static addresses for our local semaphores that are locally + // initialized, rather than metal device APIs. This way different subdevice programs can reliably + // resolve the semaphore addresses on the EDM core + static constexpr bool persistent_mode = get_compile_time_arg_val(14) != 0; static_assert(SENDER_NUM_BUFFERS > 0, "compile time argument [1]: SENDER_NUM_BUFFERS must be > 0"); static_assert(RECEIVER_NUM_BUFFERS > 0, "compile time argument [2]: RECEIVER_NUM_BUFFERS must be > 0"); @@ -750,16 +794,17 @@ void kernel_main() { /////////////////////// const size_t local_sender_channel_0_connection_semaphore_addr = + persistent_mode ? get_arg_val(arg_idx++) : get_semaphore(get_arg_val(arg_idx++)); const size_t local_sender_channel_1_connection_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); // unused - can later remove const size_t local_sender_channel_0_connection_buffer_index_addr = + persistent_mode ? get_arg_val(arg_idx++) : get_semaphore(get_arg_val(arg_idx++)); - const size_t local_sender_channel_1_connection_buffer_index_addr = - get_semaphore(get_arg_val(arg_idx++)); + const size_t local_sender_channel_1_connection_buffer_index_id = get_arg_val(arg_idx++); // downstream EDM semaphore location @@ -770,8 +815,7 @@ void kernel_main() { // remote address for flow control const auto downstream_edm_semaphore_id = get_arg_val(arg_idx++); // TODO: Convert to semaphore ID - const auto downstream_edm_worker_registration_address = - get_semaphore(get_arg_val(arg_idx++)); + const auto downstream_edm_worker_registration_id = get_arg_val(arg_idx++); const auto downstream_edm_worker_location_info_address = get_arg_val(arg_idx++); const auto downstream_noc_interface_buffer_index_local_addr = get_arg_val(arg_idx++); @@ -785,12 +829,17 @@ void kernel_main() { // Sender runtime args //////////////////////// auto sender0_worker_semaphore_ptr = reinterpret_cast( + persistent_mode ? get_arg_val(arg_idx++) : get_semaphore(get_arg_val(arg_idx++))); auto sender1_worker_semaphore_ptr = reinterpret_cast( get_semaphore(get_arg_val(arg_idx++))); - *sender0_worker_semaphore_ptr = 0; - *sender1_worker_semaphore_ptr = 0; + if constexpr (persistent_mode) { + // initialize the statically allocated "semaphores" + *reinterpret_cast(local_sender_channel_0_connection_semaphore_addr) = 0; + *reinterpret_cast(local_sender_channel_0_connection_buffer_index_addr) = 0; + *sender0_worker_semaphore_ptr = 0; + } ////////////////////////////// ////////////////////////////// // Object Setup @@ -813,15 +862,18 @@ void kernel_main() { auto downstream_edm_noc_interface = has_downstream_edm_buffer_connection ? tt::fabric::WorkerToFabricEdmSender( + //persistent_mode -> hardcode to false because for EDM -> EDM + // connections we must always use semaphore lookup + false, downstream_edm_noc_x, downstream_edm_noc_y, downstream_edm_buffer_base_address, SENDER_NUM_BUFFERS, downstream_edm_semaphore_id, - downstream_edm_worker_registration_address, // edm_connection_handshake_addr, + downstream_edm_worker_registration_id, downstream_edm_worker_location_info_address, channel_buffer_size, - local_sender_channel_1_connection_buffer_index_addr, // our downstream is channel 1 + local_sender_channel_1_connection_buffer_index_id, reinterpret_cast(edm_forwarding_semaphore_address), downstream_noc_interface_buffer_index_local_addr) : tt::fabric::WorkerToFabricEdmSender(); @@ -862,21 +914,23 @@ void kernel_main() { auto connection_worker_info_ptr = reinterpret_cast( local_sender_connection_info_addresses[i]); new (&local_sender_channel_worker_interfaces[i]) tt::fabric::EdmChannelWorkerInterface( - connection_worker_info_ptr, // worker_location_info_ptr, + connection_worker_info_ptr, reinterpret_cast( - local_sender_flow_control_semaphores[i]), // local_semaphore_address, + local_sender_flow_control_semaphores[i]), reinterpret_cast(connection_live_semaphore_ptr)); } + if (has_downstream_edm_buffer_connection) { downstream_edm_noc_interface.open(); } if constexpr (is_handshake_sender) { - erisc::datamover::handshake::sender_side_finish(handshake_addr); + erisc::datamover::handshake::sender_side_finish(handshake_addr, DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT); } else { - erisc::datamover::handshake::receiver_side_finish(handshake_addr); + erisc::datamover::handshake::receiver_side_finish(handshake_addr, DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT); } + DPRINT << "EDM Core y|x " << (uint32_t)((my_y[0] << 16) | my_x[0]) << "\n"; ////////////////////////////// ////////////////////////////// @@ -893,5 +947,14 @@ void kernel_main() { termination_signal_ptr); + if constexpr (persistent_mode) { + // we force these values to a non-zero value so that if we run the fabric back to back, + // and we can reliably probe from host that this kernel has initialized properly. + *reinterpret_cast(local_sender_channel_0_connection_semaphore_addr) = 99; + *reinterpret_cast(local_sender_channel_0_connection_buffer_index_addr) = 99; + *sender0_worker_semaphore_ptr = 99; + } + + DPRINT << "EDM DONE\n"; WAYPOINT("DONE"); } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp index ae241fb8599..58a509fa1fa 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp @@ -213,6 +213,7 @@ struct EdmChannelWorkerInterface { uint64_t worker_semaphore_address = get_noc_addr( (uint32_t)worker_info.worker_xy.x, (uint32_t)worker_info.worker_xy.y, worker_info.worker_semaphore_address); + DPRINT << "EDM ntf wrkr sem @" << (uint64_t)worker_semaphore_address << "\n"; noc_semaphore_inc(worker_semaphore_address, 1); } diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp index d0208e20fe2..9e60612239c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp @@ -11,7 +11,7 @@ #include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" #include "ttnn/operations/ccl/ccl_common.hpp" -using namespace tt::tt_metal; +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" namespace ttnn { namespace ccl { @@ -414,6 +414,8 @@ static void convert_slices_to_ccl_commands() { } +// Moved to (and updated in) ccl_worker_builder.cpp +/* void emit_ccl_send_slice_sequence_commands(std::vector const& slices, std::vector& args_out) { for (std::size_t i = 0; i < slices.size(); i++) { auto const& slice = slices[i]; @@ -555,6 +557,7 @@ void emit_ccl_send_slice_sequence_commands(std::vector const& slice } } } +*/ std::vector ReduceScatterWorkerArgBuilder::generate_line_start_sender_kernel_rt_args( WorkerEdmInterfaceArgs const& edm_interface, @@ -618,7 +621,7 @@ std::vector ReduceScatterWorkerArgBuilder::generate_line_start_sender_ log_trace(tt::LogOp, "ccl_send arg[{}]: semaphore_id {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; log_trace(tt::LogOp, "Generating {} ccl send commands", slices.size()); - emit_ccl_send_slice_sequence_commands(slices, args); + ttnn::ccl::worker_detail::emit_ccl_send_slice_sequence_commands(slices, args); log_trace(tt::LogOp, "Reduce Scatter Sender Worker has {} RT Args: {}", args.size(), args); diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp index 0008def47b9..47d6d03434e 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp @@ -27,7 +27,6 @@ class WorkerEdmInterfaceArgs; namespace reduce_scatter_detail { -void emit_ccl_send_slice_sequence_commands(std::vector const& slices, std::vector& args_out); struct ReduceScatterWorkerArgBuilder { ReduceScatterWorkerArgBuilder ( diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp index 16af5cb0652..e7f98604275 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp @@ -103,14 +103,14 @@ void py_bind_reduce_scatter(pybind11::module& module) { Args: input_tensor (ttnn.Tensor): multi-device tensor dim (int): Dimension to perform operation - cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-all-gather operation on. - mesh_device (MeshDevice): Device mesh to perform the line-all-gather operation on. + cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-reduce-scatter operation on. + mesh_device (MeshDevice): Device mesh to perform the line-reduce-scatter operation on. * cluster_axis and mesh_device parameters are applicable only for Linear Topology. Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md Keyword Args: - num_links (int, optional): Number of links to use for the all-gather operation. Defaults to `1`. + num_links (int, optional): Number of links to use for the reduce0scatter operation. Defaults to `1`. memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `input tensor memory config`. num_workers (int, optional): Number of workers to use for the operation. Defaults to `None`. num_buffers_per_channel (int, optional): Number of buffers per channel to use for the operation. Defaults to `None`. diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp index 7e11f71d793..02416215d11 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -8,6 +8,7 @@ // #include #include #include +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" /* * ------ ATTENTION ATTENTION ATTENTION ATTENTION ATTENTION ------ @@ -55,6 +56,7 @@ struct WorkerXY { constexpr WorkerXY(uint16_t x, uint16_t y) : x(x), y(y) {} constexpr uint32_t to_uint32() const { return (y << 16) | x; } + static constexpr WorkerXY from_uint32(uint32_t v) { return WorkerXY(v & 0xFFFF, (v >> 16) & 0xFFFF); } constexpr bool operator==(const WorkerXY &rhs) const { return x == rhs.x && y == rhs.y; } constexpr bool operator!=(const WorkerXY &rhs) const { return !(*this == rhs); } @@ -114,7 +116,151 @@ inline coord_t advance_wrapped_slice_row_major( return coord_t(next_offset_x, next_offset_y); } +namespace v2 { +inline size_t flattened_index (const ttnn::ccl::Shape4D& shape, const ttnn::ccl::Shape4D& index) { + std::size_t offset = index.x; + std::size_t inner_volume = shape.x; + offset += index.y * inner_volume; + inner_volume *= shape.y; + offset += index.z * inner_volume; + inner_volume *= shape.z; + offset += index.w * inner_volume; + return offset; +} + +// Increments the index into the input (global) tensor, while respecting the tensor slice, for wrapped worker slice +// that is internal to the tensor slice. +[[nodiscard]] inline bool advance_worker_global_page ( + uint32_t &curr_page_idx, + uint32_t &offset_into_worker_slice, // local to the worker chunk + ttnn::ccl::Shape4D const& offset_worker_slice, // local to the tensor slice + + ttnn::ccl::Shape4D const &worker_slice_shape, // worker chunk shape + ttnn::ccl::Shape4D const &tensor_slice_shape, // tensor slice shape (per device) + + ttnn::ccl::Shape4D const &tensor_shape, // full tensor shape + + const uint32_t stride + ) { + bool may_wrap_multiple_times = stride > tensor_slice_shape.x; + bool outer_dims_gt_1 = tensor_slice_shape.z > 1 || tensor_slice_shape.w > 1; + bool end_of_worker_slice_row = false; + auto next_offset_into_worker_slice = curr_page_idx + stride; + end_of_worker_slice_row = next_offset_into_worker_slice == worker_slice_shape.volume(); + if (may_wrap_multiple_times || !outer_dims_gt_1) { + + uint32_t prev_offset_into_worker_slice = offset_into_worker_slice; + uint32_t flattened_offset_worker_slice = flattened_index(tensor_slice_shape, offset_worker_slice); + + // Calculate the number of wrap arounds (cast to uint32_t to **round down**) + uint32_t prev_num_wrap_around = (flattened_offset_worker_slice + prev_offset_into_worker_slice) / tensor_slice_shape.x; + uint32_t curr_num_wrap_around = (flattened_offset_worker_slice + next_offset_into_worker_slice) / tensor_slice_shape.x; + uint32_t num_wrap_around = curr_num_wrap_around - prev_num_wrap_around; + + // Check for wrap around + if (num_wrap_around > 0) { // wrap around wrt to global tensor + curr_page_idx += num_wrap_around * (tensor_shape.x - tensor_slice_shape.x) + stride; + } else { + curr_page_idx += stride; + } + + } else { + // can wrap at-most one time. For now since we only have the flat index, we are going to brute force + // it. Future work to optimize this - a lot can be done: + // 1) Carry around the 4D index and also carry around subvolumes + // 2) Precompute the "inner"/"outer" volumes for each dimension so they are precomputed - this will save + // on 4 sums + multiplies per call + // 3) possibly update address-generators to support n-dimensional indices which may further reduce the number + // of operations required (otherwise we still need to eventually convert between flat and) + // of each dimension so we can more quickly do the striding + + size_t y_x = tensor_slice_shape.y * tensor_slice_shape.x; + size_t z_y_x = tensor_slice_shape.z * y_x; + + // Calculate the 4D coordinates + size_t index = next_offset_into_worker_slice; + size_t new_w = index / z_y_x; + index -= new_w * z_y_x; + + size_t new_z = index / y_x; + index -= new_z * y_x; + + size_t new_y = index / tensor_slice_shape.x; + size_t new_x = index - new_y * tensor_slice_shape.x; + + curr_page_idx = flattened_index(tensor_shape, tensor_slice_shape + Shape4D{new_x, new_y, new_z, new_w}); + } + + return end_of_worker_slice_row; +} +[[nodiscard]] inline bool advance_worker_global_page ( + uint32_t &curr_page_idx, + uint32_t &offset_into_worker_slice, // local to the worker chunk + ttnn::ccl::Shape4D const& offset_worker_slice, // local to the tensor slice + + size_t const worker_slice_volume, // worker chunk shape + ttnn::ccl::Shape4D const &tensor_slice_shape, // tensor slice shape (per device) + ttnn::ccl::Shape4D const &tensor_slice_base_offset, // tensor slice shape (per device) + + ttnn::ccl::Shape4D const &tensor_shape, // full tensor shape + + const uint32_t stride + ) { + bool may_wrap_multiple_times = stride > tensor_slice_shape.x; + bool outer_dims_gt_1 = tensor_slice_shape.z > 1 || tensor_slice_shape.w > 1; + bool end_of_worker_slice_row = false; + auto next_offset_into_worker_slice = offset_into_worker_slice + stride; + end_of_worker_slice_row = next_offset_into_worker_slice == worker_slice_volume; + if (may_wrap_multiple_times || !outer_dims_gt_1) { + uint32_t prev_offset_into_worker_slice = offset_into_worker_slice; + offset_into_worker_slice += stride; + + uint32_t flattened_offset_worker_slice = flattened_index(tensor_slice_shape, offset_worker_slice); + + // Calculate the number of wrap arounds (cast to uint32_t to **round down**) + uint32_t prev_num_wrap_around = (flattened_offset_worker_slice + prev_offset_into_worker_slice) / tensor_slice_shape.x; + uint32_t curr_num_wrap_around = (flattened_offset_worker_slice + offset_into_worker_slice) / tensor_slice_shape.x; + uint32_t num_wrap_around = curr_num_wrap_around - prev_num_wrap_around; + + bool end_of_worker_slice_row = offset_into_worker_slice == worker_slice_volume; + // Check for wrap around + if (num_wrap_around > 0) { // wrap around wrt to global tensor + curr_page_idx += num_wrap_around * (tensor_shape.x - tensor_slice_shape.x) + stride; + } else { + curr_page_idx += stride; + } + } else { + // can wrap at-most one time. For now since we only have the flat index, we are going to brute force + // it. Future work to optimize this - a lot can be done: + // 1) Carry around the 4D index and also carry around subvolumes + // 2) Precompute the "inner"/"outer" volumes for each dimension so they are precomputed - this will save + // on 4 sums + multiplies per call + // 3) possibly update address-generators to support n-dimensional indices which may further reduce the number + // of operations required (otherwise we still need to eventually convert between flat and) + // of each dimension so we can more quickly do the striding + + offset_into_worker_slice += stride; + uint32_t y_x = tensor_slice_shape.y * tensor_slice_shape.x; + uint32_t z_y_x = tensor_slice_shape.z * y_x; + + // Calculate the 4D coordinates + uint32_t index = next_offset_into_worker_slice; + uint32_t new_w = index / z_y_x; + index -= new_w * z_y_x; + + uint32_t new_z = index / y_x; + index -= new_z * y_x; + + uint32_t new_y = index / tensor_slice_shape.x; + uint32_t new_x = index - new_y * tensor_slice_shape.x; + + curr_page_idx = flattened_index(tensor_shape, tensor_slice_base_offset + Shape4D{new_w, new_z, new_y, new_x}); + } + return end_of_worker_slice_row; +} + +} // Increments the index into the input (global) tensor, while respecting the tensor slice, for wrapped worker slice // that is internal to the tensor slice. inline void advance_worker_global_page_interleaved ( diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp index 58fe3f1898a..a51a6eff900 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp @@ -43,6 +43,26 @@ struct WorkerToNocCoordLookup { }; +struct VirtualCoordWormholeWorkerToNocLookup + : address_generators::WorkerToNocCoordLookup { + VirtualCoordWormholeWorkerToNocLookup() : address_generators::WorkerToNocCoordLookup() {} + noc_grid_index_t get_noc_x_from_worker_x(noc_grid_index_t worker_x) const { + return worker_x + #if defined(KERNEL_BUILD) + + VIRTUAL_TENSIX_START_X + #endif + ; + } + + noc_grid_index_t get_noc_y_from_worker_y(noc_grid_index_t worker_y) const { + return worker_y + #if defined(KERNEL_BUILD) + + VIRTUAL_TENSIX_START_Y + #endif + ; + } +}; + /* A worker coord to noc coord lookup * It is marked "Harvested" in the type name because a non-harvested Wormhole part has a * fixed coordinate mapping, whereas the harvested part has potentially unique mapping per device @@ -51,8 +71,8 @@ struct WorkerToNocCoordLookup { * used on both harvested and non-harvested parts. */ struct HarvestedWormholeWorkerToNocLookup : WorkerToNocCoordLookup{ - HarvestedWormholeWorkerToNocLookup(uint32_t nrows, const uint32_t *const row_map, uint32_t ncols, const uint32_t *const col_map) : - nrows(nrows), row_map(row_map), ncols(ncols), col_map(col_map) {} + HarvestedWormholeWorkerToNocLookup(uint8_t nrows, const uint32_t *const row_map, uint8_t ncols, const uint32_t *const col_map) : + row_map(row_map), col_map(col_map), nrows(nrows), ncols(ncols) {} noc_grid_index_t get_noc_x_from_worker_x(noc_grid_index_t worker_x) const { // ASSERT worker_x < worker_to_routing_x_wormhole.size() @@ -64,10 +84,10 @@ struct HarvestedWormholeWorkerToNocLookup : WorkerToNocCoordLookup struct WidthShardedAddressGenerator { worker_to_noc_lookup_t worker_to_noc_lookup; DEVICE_SHARD_SPEC_T tensor_shard_spec; - uint32_t page_size; uint32_t bank_base_address; + uint16_t page_size; public: - constexpr WidthShardedAddressGenerator(worker_to_noc_lookup_t lookup, DEVICE_SHARD_SPEC_T const& tensor_shard_spec, uint32_t page_size, uint32_t base_address) : worker_to_noc_lookup(lookup), tensor_shard_spec(tensor_shard_spec), page_size(page_size), bank_base_address(base_address) {} + constexpr WidthShardedAddressGenerator(worker_to_noc_lookup_t lookup, DEVICE_SHARD_SPEC_T const& tensor_shard_spec, uint16_t page_size, uint32_t base_address) : worker_to_noc_lookup(lookup), tensor_shard_spec(tensor_shard_spec), bank_base_address(base_address), page_size(page_size) {} /* * This function is an alternative API that allows the caller to implement a more efficient traversal/iteration of their tensor @@ -243,7 +263,7 @@ struct WidthShardedAddressGenerator { noc_grid_index_t noc_x = worker_to_noc_lookup.get_noc_x_from_worker_x(worker_x_logical); noc_grid_index_t noc_y = worker_to_noc_lookup.get_noc_y_from_worker_y(worker_y_logical); - return test_shard_location_with_contig_t{device_core_location_t{noc_y, noc_x}, page_offset_in_shard, tensor_shard_spec.get_pages_per_shard_x() - page_in_shard_x}; + return test_shard_location_with_contig_t{device_core_location_t{noc_y, noc_x}, page_offset_in_shard, static_cast(tensor_shard_spec.get_pages_per_shard_x() - page_in_shard_x)}; } /* @@ -251,7 +271,7 @@ struct WidthShardedAddressGenerator { * iterating through the tensor in a row-major order. */ test_shard_location_t get_page_location(std::uint32_t global_page_id) const { - auto const& result = get_page_location_with_contiguous_pages_in_row_in_bank(global_page_id); + auto const result = get_page_location_with_contiguous_pages_in_row_in_bank(global_page_id); return test_shard_location_t{result.core_location, result.page_offset}; } @@ -315,11 +335,11 @@ template struct HeightShardedAddressGenerator { worker_to_noc_lookup_t worker_to_noc_lookup; DEVICE_SHARD_SPEC_T tensor_shard_spec; - uint32_t page_size; uint32_t bank_base_address; + uint16_t page_size; public: - constexpr HeightShardedAddressGenerator(worker_to_noc_lookup_t lookup, DEVICE_SHARD_SPEC_T const& tensor_shard_spec, uint32_t page_size, uint32_t base_address) : worker_to_noc_lookup(lookup), tensor_shard_spec(tensor_shard_spec), page_size(page_size), bank_base_address(base_address) {} + constexpr HeightShardedAddressGenerator(worker_to_noc_lookup_t lookup, DEVICE_SHARD_SPEC_T const& tensor_shard_spec, uint32_t page_size, uint32_t base_address) : worker_to_noc_lookup(lookup), tensor_shard_spec(tensor_shard_spec), bank_base_address(base_address), page_size(page_size) {} /* * This function is an alternative API that allows the caller to implement a more efficient traversal/iteration of their tensor @@ -367,7 +387,7 @@ struct HeightShardedAddressGenerator { noc_grid_index_t noc_x = worker_to_noc_lookup.get_noc_x_from_worker_x(worker_x_logical); noc_grid_index_t noc_y = worker_to_noc_lookup.get_noc_y_from_worker_y(worker_y_logical); - return test_shard_location_with_contig_t{device_core_location_t{noc_y, noc_x}, page_offset_in_shard, 1};//tensor_shard_spec.get_pages_per_shard_x() - page_in_shard_x}; + return test_shard_location_with_contig_t{device_core_location_t{noc_y, noc_x}, page_offset_in_shard, static_cast(1)};//tensor_shard_spec.get_pages_per_shard_x() - page_in_shard_x}; } /* @@ -375,7 +395,7 @@ struct HeightShardedAddressGenerator { * iterating through the tensor in a row-major order. */ test_shard_location_t get_page_location(std::uint32_t global_page_id) const { - auto const& result = get_page_location_with_contiguous_pages_in_row_in_bank(global_page_id); + auto const result = get_page_location_with_contiguous_pages_in_row_in_bank(global_page_id); return test_shard_location_t{result.core_location, result.page_offset}; } @@ -440,11 +460,11 @@ template struct BlockShardedAddressGenerator { worker_to_noc_lookup_t worker_to_noc_lookup; DEVICE_SHARD_SPEC_T tensor_shard_spec; - uint32_t page_size; uint32_t bank_base_address; + uint16_t page_size; public: - constexpr BlockShardedAddressGenerator(worker_to_noc_lookup_t lookup, DEVICE_SHARD_SPEC_T const& tensor_shard_spec, uint32_t page_size, uint32_t base_address) : worker_to_noc_lookup(lookup), tensor_shard_spec(tensor_shard_spec), page_size(page_size), bank_base_address(base_address) {} + constexpr BlockShardedAddressGenerator(worker_to_noc_lookup_t lookup, DEVICE_SHARD_SPEC_T const& tensor_shard_spec, uint32_t page_size, uint32_t base_address) : worker_to_noc_lookup(lookup), tensor_shard_spec(tensor_shard_spec), bank_base_address(base_address), page_size(page_size) {} /* * This function is an alternative API that allows the caller to implement a more efficient traversal/iteration of their tensor @@ -499,7 +519,7 @@ struct BlockShardedAddressGenerator { noc_grid_index_t noc_x = worker_to_noc_lookup.get_noc_x_from_worker_x(worker_x_logical); noc_grid_index_t noc_y = worker_to_noc_lookup.get_noc_y_from_worker_y(worker_y_logical); - return test_shard_location_with_contig_t{device_core_location_t{noc_y, noc_x}, page_offset_in_shard, tensor_shard_spec.get_pages_per_shard_x() - page_offset_in_shard_x}; + return test_shard_location_with_contig_t{device_core_location_t{noc_y, noc_x}, page_offset_in_shard, static_cast(tensor_shard_spec.get_pages_per_shard_x() - page_offset_in_shard_x)}; } /* @@ -507,7 +527,7 @@ struct BlockShardedAddressGenerator { * iterating through the tensor in a row-major order. */ test_shard_location_t get_page_location(std::uint32_t global_page_id) const { - auto const& result = get_page_location_with_contiguous_pages_in_row_in_bank(global_page_id); + auto const result = get_page_location_with_contiguous_pages_in_row_in_bank(global_page_id); return test_shard_location_t{result.core_location, result.page_offset}; } @@ -586,6 +606,10 @@ using DefaultWidthShardedAddressGenerator = WidthShardedAddressGenerator; using DefaultBlockShardedAddressGenerator = BlockShardedAddressGenerator; +using DefaultVirtualCoordWidthShardedAddressGenerator = WidthShardedAddressGenerator; +using DefaultVirtualCoordHeightShardedAddressGenerator = HeightShardedAddressGenerator; +using DefaultVirtualCoordBlockShardedAddressGenerator = BlockShardedAddressGenerator; + } // namespace address_generators } // namespace tt_metal diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt index 82767c44a09..851de9b13ad 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt @@ -1,5 +1,6 @@ set(CCL_EXPERIMENTAL_TTNN_SRCS #Experimental Ops + ${CMAKE_CURRENT_SOURCE_DIR}/ccl_experimental_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/all_gather_matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/all_gather_matmul_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -7,6 +8,14 @@ set(CCL_EXPERIMENTAL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce/all_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce/all_reduce_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce/device/all_reduce_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter_async/device/reduce_scatter_async_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter_async/device/reduce_scatter_async_program.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter_async/reduce_scatter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter_async/reduce_scatter_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/all_gather_async.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/all_gather_async_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_program.cpp CACHE INTERNAL "CCL Experimental sources to reuse in ttnn build" ) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp new file mode 100644 index 00000000000..7ce729ed1b7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "all_gather_async.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +#include "ttnn/distributed/types.hpp" + +namespace ttnn::operations::experimental::ccl { + +ttnn::Tensor ExecuteAllGatherAsync::invoke( + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t num_links, + const std::optional& memory_config, + const ttnn::ccl::Topology topology, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) { + return ttnn::operations::experimental::ccl::all_gather_async( + input_tensor, + dim, + num_links, + memory_config, + topology, + subdevice_id, + enable_persistent_fabric_mode, + create_semaphore_handles); +} + +ttnn::Tensor ExecuteAllGatherAsync::invoke( + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const std::optional& memory_config, + const std::optional num_preferred_links, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) { + return ttnn::operations::experimental::ccl::all_gather_async( + input_tensor, + dim, + cluster_axis, + mesh_device, + topology, + memory_config, + num_preferred_links, + subdevice_id, + enable_persistent_fabric_mode, + create_semaphore_handles); +} + +} // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp new file mode 100644 index 00000000000..26f39484078 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/decorators.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" + +namespace ttnn { +namespace operations::experimental::ccl { + +struct ExecuteAllGatherAsync { + static ttnn::Tensor invoke( + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt, + const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring, + std::optional subdevice_id = std::nullopt, + bool enable_persistent_fabric_mode = false, + bool create_semaphore_handles = true); + + static ttnn::Tensor invoke( + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const std::optional& memory_config = std::nullopt, + const std::optional num_preferred_links = std::nullopt, + std::optional subdevice_id = std::nullopt, + bool enable_persistent_fabric_mode = false, + bool create_semaphore_handles = true); +}; + +} // namespace operations::experimental::ccl + +namespace experimental { + +constexpr auto all_gather_async = ttnn::register_operation< + "ttnn::experimental::all_gather_async", + ttnn::operations::experimental::ccl::ExecuteAllGatherAsync>(); + +} // namespace experimental +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp new file mode 100644 index 00000000000..dd670165700 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp @@ -0,0 +1,139 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "all_gather_async_pybind.hpp" + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/distributed/types.hpp" + +namespace ttnn::operations::experimental::ccl { + +namespace detail { + +template +void bind_all_gather_async(pybind11::module& module, const ccl_operation_t& operation, const char* doc) { + // namespace py = pybind11; + + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t num_links, + const std::optional& memory_config, + const ttnn::ccl::Topology topology, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) -> ttnn::Tensor { + return self( + input_tensor, + dim, + num_links, + memory_config, + topology, + subdevice_id, + enable_persistent_fabric_mode, + create_semaphore_handles); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::kw_only(), + py::arg("num_links") = 1, + py::arg("memory_config") = std::nullopt, + py::arg("topology") = ttnn::ccl::Topology::Ring, + py::arg("subdevice_id") = std::nullopt, + py::arg("enable_persistent_fabric_mode") = false, + py::arg("create_semaphore_handles") = true}, + + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const std::optional& memory_config, + const std::optional num_preferred_links, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) -> ttnn::Tensor { + return self( + input_tensor, + dim, + cluster_axis, + mesh_device, + topology, + memory_config,// = std::nullopt, + num_preferred_links,// = std::nullopt, + subdevice_id,// = std::nullopt, + enable_persistent_fabric_mode,// = false, + create_semaphore_handles); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::arg("cluster_axis"), + py::arg("mesh_device"), + py::arg("topology"), + py::kw_only(), + py::arg("num_links") = 1, + py::arg("memory_config") = std::nullopt, + py::arg("subdevice_id") = std::nullopt, + py::arg("enable_persistent_fabric_mode") = false, + py::arg("create_semaphore_handles") = true}); + +} + +} // namespace detail + +void py_bind_all_gather_async(pybind11::module& module) { + detail::bind_all_gather_async( + module, + ttnn::experimental::all_gather_async, + R"doc( + + Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. + + Args: + input_tensor (ttnn.Tensor): multi-device tensor. + dim (int): Dimension to perform operation. + cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-all-gather operation on. + mesh_device (MeshDevice): Device mesh to perform the line-all-gather operation on. + * cluster_axis and mesh_device parameters are applicable only for Linear Topology. + + Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md + + Keyword Args: + num_links (int, optional): Number of links to use for the all-gather operation. Defaults to `1`. + memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `input tensor memory config`. + topology (ttnn.Topology, optional): The topology configuration to run the operation in. Valid options are Ring and Linear. Defaults to `ttnn.Topology.Ring`. + + Returns: + ttnn.Tensor: the output tensor. + + Example: + >>> full_tensor = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16) + >>> physical_device_ids = ttnn.get_t3k_physical_device_ids_ring() + >>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8), physical_device_ids=physical_device_ids[:8]) + >>> ttnn_tensor = ttnn.from_torch( + full_tensor, + dtype=input_dtype, + device=mesh_device, + layout=layout, + memory_config=mem_config, + mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(1, 8), dims=(-1, -2))) + >>> ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device) + >>> output = ttnn.all_gather(ttnn_tensor, dim=0, topology=ttnn.Topology.Ring) + + )doc"); +} + +} // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.hpp new file mode 100644 index 00000000000..29bd4ff9f18 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace ttnn::operations::experimental::ccl { + +void py_bind_all_gather_async(pybind11::module& module); + +} // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp new file mode 100644 index 00000000000..15dcfee3671 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -0,0 +1,334 @@ +/// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "all_gather_async_op.hpp" +#include "ttnn/operations/math.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" + +#include "tt_metal/host_api.hpp" + +#include "ttnn/tensor/tensor_utils.hpp" + +#include "eth_l1_address_map.h" + +namespace ttnn { +namespace ccl { +namespace all_gather_detail { + +AllGatherAsync create_all_gather_async_struct( + const Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links, + const std::optional& memory_config, + const std::vector& devices, + const ttnn::ccl::Topology topology, + const std::optional>>& semaphore_handles, + bool enable_persistent_fabric_mode) { + uint32_t num_devices = devices.size(); + + std::optional forward_device = std::nullopt; + std::optional backward_device = std::nullopt; + std::shared_ptr semaphore_handle = nullptr; + uint32_t device_index = 0; // Initialize device index + for (uint32_t i = 0; i < num_devices; ++i) { + if (devices.at(i) == input_tensor.device()) { + device_index = i; + if (semaphore_handles.has_value()) { + semaphore_handle = semaphore_handles.value().at(i); // Get raw pointer + } + if (i != 0) { + backward_device = devices.at(i - 1); + } + if (i != num_devices - 1) { + forward_device = devices.at(i + 1); + } + } + } + + return ttnn::AllGatherAsync{ + forward_device, + backward_device, + dim, + num_links, + num_devices, + device_index, + memory_config.value_or(input_tensor.memory_config()), + topology, + semaphore_handle, + enable_persistent_fabric_mode}; +} + +std::optional>> get_global_semaphores( + const std::vector& devices, + const CoreRange& core_range, + std::optional subdevice_id, + bool create_semaphore_handles) { + std::optional>> semaphore_handles_opt; + if (create_semaphore_handles) { + std::vector> semaphore_handles; + for (const auto& device : devices) { + auto subdevice_span = subdevice_id.has_value() ? tt::stl::Span{subdevice_id.value()} + : tt::stl::Span{}; + + auto handle = GlobalSemaphore::create(device, core_range, 0, BufferType::L1, subdevice_span); + log_trace( + tt::LogOp, "Created semaphore handle at address {} for device {}", handle->address(), device->id()); + semaphore_handles.push_back(handle); + } + // HACK: assert every handle address is the same + TT_FATAL( + std::all_of( + semaphore_handles.begin(), + semaphore_handles.end(), + [&](const auto& handle) { return handle->address() == semaphore_handles.front()->address(); }), + "[Hack] All semaphore handles should have the same address"); + semaphore_handles_opt = semaphore_handles; + } else { + semaphore_handles_opt = std::nullopt; + } + + return semaphore_handles_opt; +} + +} // namespace all_gather_detail +} // namespace ccl + +void AllGatherAsync::validate(const std::vector& input_tensors) const { + TT_FATAL(input_tensors.size() == 1, "Error, Input tensor size should be 1 but has {}", input_tensors.size()); + const auto& input_tensor = input_tensors[0]; + const auto& layout = input_tensors[0].get_layout(); + const auto& dtype = input_tensors[0].get_dtype(); + const auto& page_size = input_tensors[0].buffer()->page_size(); + TT_FATAL(page_size % input_tensors[0].buffer()->alignment() == 0, "All Gather currently requires aligned pages"); + + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_gather need to be on device!"); + TT_FATAL(input_tensor.buffer() != nullptr, "Operands to all_gather need to be allocated in buffers on device!"); + TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but has {}", this->num_links); + TT_FATAL( + this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, + "Worker cores used by links are parallelizaed over rows"); + + TT_FATAL( + input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Unsupported memory layout {}.", + input_tensor.memory_config().memory_layout); +} + +static void validate_output_tensor_allocation(const std::vector& output_tensors) { + for (const auto& output_tensor : output_tensors) { + const auto& buffers = output_tensor.buffers(); + const auto first_address = buffers.front()->address(); + TT_FATAL( + std::all_of( + buffers.begin(), + buffers.end(), + [&first_address](const auto& buffer) { + return buffer != nullptr && buffer->address() == first_address; + }), + "Output buffers for all_gather async must be lock-step allocated but some of the tensors were allocated at " + "different addresses across devices."); + } +} + +std::vector AllGatherAsync::compute_output_shapes(const std::vector& input_tensors) const { + auto shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape() + shape[this->dim] *= this->ring_size; + return std::vector(input_tensors.size(), shape); +} + +std::vector AllGatherAsync::create_output_tensors(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors[0]; + auto output_tensors = std::vector(); + output_tensors.reserve(1); + auto tile = input_tensor.get_tensor_spec().tile(); + if (this->output_mem_config.is_sharded()) { + output_tensors.push_back(create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + this->output_mem_config, + tile)); + } else { + output_tensors = operation::generic_create_output_tensors( + *this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, tile); + } + log_debug(tt::LogOp, "DEBUG: output_tensors[0] address: {}", output_tensors.at(0).buffer()->address()); + return output_tensors; +} + +operation::ProgramWithCallbacks AllGatherAsync::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + tt::log_debug(tt::LogOp, "DEBUG: create_program is called"); + return all_gather_async_multi_core_with_workers( + input_tensors[0], + this->forward_device, + this->backward_device, + output_tensors[0], + this->dim, + this->num_links, + this->ring_size, + this->ring_index, + this->topology, + this->semaphore_handle, + this->enable_persistent_fabric_mode); +} + +const operation::Hash AllGatherAsync::compute_program_hash(const std::vector& input_tensors) const { + return operation::hash_operation( + this->dim, this->num_links, this->ring_size, this->ring_index, this->output_mem_config, this->topology); +} + + + +namespace operations { +namespace experimental { +namespace ccl { + +Tensor all_gather_async( + const Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links, + const std::optional& memory_config, + const ttnn::ccl::Topology topology, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) { + TT_FATAL( + std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, + "all_gather_async op is only supported for Fast Dispatch"); + auto devices = input_tensor.get_workers(); + uint32_t num_devices = devices.size(); + TT_FATAL(num_devices > 1, "all_gather_async op will only work for num_devices > 1, but has {}", num_devices); + ttnn::ccl::Topology ccl_topology = topology; + + if (num_devices == 2) { + ccl_topology = ttnn::ccl::Topology::Linear; + } + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + + tt::log_debug( + tt::LogOp, "DEBUG: creating line_fabric with num devices: {}, num links: {}", devices.size(), num_links); + tt::log_debug(tt::LogOp, "DEBUG: line_fabric is created"); + + // create this semaphore for all cores since we don't know which core will be used for teardown draining + CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); + auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + + std::optional>> semaphore_handles_opt = + ttnn::ccl::all_gather_detail::get_global_semaphores(devices, core_grid, subdevice_id, create_semaphore_handles); + + operation::launch_op( + [dim, + num_links, + num_devices, + memory_config, + devices, + ccl_topology, + semaphore_handles_opt, + enable_persistent_fabric_mode]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& input_tensor = input_tensors.at(0); + + return operation::run( + ttnn::ccl::all_gather_detail::create_all_gather_async_struct( + input_tensor, + dim, + num_links, + memory_config, + devices, + ccl_topology, + semaphore_handles_opt, + enable_persistent_fabric_mode), + {input_tensor}); + }, + {input_tensor}, + output_tensors); + return output_tensors.at(0); +} + +Tensor all_gather_async( + const Tensor& input_tensor, + const int32_t dim, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const std::optional& memory_config, + const std::optional num_preferred_links, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) { + TT_FATAL( + topology == ttnn::ccl::Topology::Linear, + "This all_gather API with cluster_axis is currently supported only for the Linear topology"); + const auto mesh_view = mesh_device.get_view(); + auto devices = input_tensor.get_workers(); + std::size_t num_devices = (cluster_axis == 0) ? mesh_view.num_rows() : mesh_view.num_cols(); + + int32_t rank = input_tensor.get_logical_shape().rank(); + + int32_t gather_dim = (dim < 0) ? rank + dim : dim; + + TT_FATAL( + gather_dim >= -rank && gather_dim <= rank - 1, + "Dimension input should be in between -{} and {}, but has {}", + rank, + rank - 1, + dim); + + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); + auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + std::optional>> semaphore_handles_opt = + ttnn::ccl::all_gather_detail::get_global_semaphores(devices, core_grid, subdevice_id, create_semaphore_handles); + + operation::launch_op( + [gather_dim, + num_preferred_links, + memory_config, + mesh_view, + cluster_axis, + num_devices, + topology, + semaphore_handles_opt, + enable_persistent_fabric_mode]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& input_device_tensor = input_tensors.at(0); + + const auto coordinate = mesh_view.find_device(input_device_tensor.device()->id()); + std::vector devices = (cluster_axis == 0) ? mesh_view.get_devices_on_column(coordinate.col) + : mesh_view.get_devices_on_row(coordinate.row); + + const auto& input_tensor = input_tensors.at(0); + + return operation::run( + ttnn::ccl::all_gather_detail::create_all_gather_async_struct( + input_device_tensor, + gather_dim, + num_preferred_links.has_value() ? num_preferred_links.value() : 1, + memory_config, + devices, + topology, + semaphore_handles_opt, + enable_persistent_fabric_mode), + {input_tensor}); + + }, + {input_tensor}, + output_tensors); + return output_tensors.at(0); +} + +} // namespace ccl +} // namespace experimental +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp new file mode 100644 index 00000000000..b5bc4095f2f --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "common/core_coord.hpp" +#include "impl/buffers/buffer.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" + +#include "ttnn/run_operation.hpp" + +#include +#include + +namespace ttnn { + +using ccl::EriscDatamoverBuilder; + +struct AllGatherAsync { + std::optional forward_device; + std::optional backward_device; + const uint32_t dim; + const uint32_t num_links; + const uint32_t ring_size; + const uint32_t ring_index; + const MemoryConfig output_mem_config; + const ccl::Topology topology; + std::optional> semaphore_handle; + bool enable_persistent_fabric_mode; + + AllGatherAsync( + std::optional forward_device, + std::optional backward_device, + uint32_t dim, + uint32_t num_links, + uint32_t ring_size, + uint32_t ring_index, + MemoryConfig output_mem_config, + ccl::Topology topology, + std::optional> semaphore_handle, + bool enable_persistent_fabric_mode) : + forward_device(forward_device), + backward_device(backward_device), + dim(dim), + num_links(num_links), + ring_size(ring_size), + ring_index(ring_index), + output_mem_config(output_mem_config), + topology(topology), + semaphore_handle(semaphore_handle), + enable_persistent_fabric_mode(enable_persistent_fabric_mode) {} + + // Add attributes method for reflection + auto attributes() const { + using tt::stl::reflection::Attribute; + std::vector> attrs; + + attrs.emplace_back("dim", dim); + attrs.emplace_back("num_links", num_links); + attrs.emplace_back("ring_size", ring_size); + attrs.emplace_back("ring_index", ring_index); + attrs.emplace_back("output_mem_config", output_mem_config); + attrs.emplace_back("topology", topology); + attrs.emplace_back("semaphore_handle", semaphore_handle.has_value() ? semaphore_handle.value().get() : nullptr); + + return attrs; + } + + void validate(const std::vector& input_tensors) const; + std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector create_output_tensors(const std::vector& input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; + const operation::Hash compute_program_hash(const std::vector& input_tensors) const; +}; + +namespace ccl { +namespace all_gather_async_detail { +AllGatherAsync create_all_gather_async_struct( + const Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links, + const std::optional& memory_config, + const std::vector& devices, + const ccl::Topology topology, + const std::optional>& semaphore_handles, + bool enable_persistent_fabric_mode); +} // namespace all_gather_async_detail +} // namespace ccl + +// All Gather Variants +operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const std::optional>& semaphore_handle_opt, + bool enable_persistent_fabric_mode); + +namespace operations { +namespace experimental { +namespace ccl { + +Tensor all_gather_async( + const Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt, + const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring, + std::optional subdevice_id = std::nullopt, + bool enable_persistent_fabric_mode = false, + bool create_semaphore_handles = true); // TODO make reference + +Tensor all_gather_async( + const Tensor& input_tensor, + const int32_t dim, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const std::optional& memory_config = std::nullopt, + const std::optional num_preferred_links = std::nullopt, + std::optional subdevice_id = std::nullopt, + bool enable_persistent_fabric_mode = false, + bool create_semaphore_handles = true); + +} // namespace ccl +} // namespace experimental +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp new file mode 100644 index 00000000000..dc83794cdb8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp @@ -0,0 +1,410 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +/// +#include + +#include "tt_metal/common/core_coord.hpp" +#include "eth_l1_address_map.h" +#include "impl/buffers/buffer.hpp" +#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/math.hpp" +#include "tt_metal/common/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" + +#include +#include +#include + +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" + +#include +using namespace tt::constants; + +namespace ttnn { + +using namespace ccl; + +static void print_tensor_slice(const ttnn::ccl::v2::TensorSlice& slice_v2) { + log_trace(tt::LogOp, "TensorSlice:"); + log_trace( + tt::LogOp, + " tensor_shape: [w={}, z={}, y={}, x={}]", + slice_v2.tensor_shape.w, + slice_v2.tensor_shape.z, + slice_v2.tensor_shape.y, + slice_v2.tensor_shape.x); + log_trace( + tt::LogOp, + " tensor_slice_shape: [w={}, z={}, y={}, x={}]", + slice_v2.tensor_slice_shape.w, + slice_v2.tensor_slice_shape.z, + slice_v2.tensor_slice_shape.y, + slice_v2.tensor_slice_shape.x); + log_trace( + tt::LogOp, + " tensor_slice_offset: [w={}, z={}, y={}, x={}]", + slice_v2.tensor_slice_offset.w, + slice_v2.tensor_slice_offset.z, + slice_v2.tensor_slice_offset.y, + slice_v2.tensor_slice_offset.x); + log_trace( + tt::LogOp, + " worker_slice_shape: [w={}, z={}, y={}, x={}]", + slice_v2.worker_slice_shape.w, + slice_v2.worker_slice_shape.z, + slice_v2.worker_slice_shape.y, + slice_v2.worker_slice_shape.x); + log_trace( + tt::LogOp, + " worker_slice_offset: [w={}, z={}, y={}, x={}]", + slice_v2.worker_slice_offset.w, + slice_v2.worker_slice_offset.z, + slice_v2.worker_slice_offset.y, + slice_v2.worker_slice_offset.x); +} + +std::tuple> choose_worker_cores( + size_t num_links, size_t num_workers_per_link, bool persistent_fabric_mode, Device* device) { + std::tuple> result; + CoreRangeSet sender_worker_core_range; + if (persistent_fabric_mode) { + const size_t num_workers_preferred = num_workers_per_link * num_links; + const auto available_cores = + device->worker_cores(HalProgrammableCoreType::TENSIX, device->get_sub_device_ids().at(0)); + if (available_cores.num_cores() < num_workers_preferred) { + log_warning( + tt::LogOp, + "AllGather is being launched on a subdevice with fewer worker cores available than ideal. Ideally {} " + "cores ({} per link and {} links) are made available but only {} are available. This may lead to " + "performance loss.", + num_workers_preferred, + num_workers_per_link, + num_links, + available_cores.num_cores()); + } + for (const auto& cr : available_cores.ranges()) { + auto start = cr.start_coord; + auto end = cr.end_coord; + for (size_t y = start.y; y <= end.y; y++) { + for (size_t x = start.x; x <= end.x; x++) { + sender_worker_core_range = + sender_worker_core_range.merge(CoreRangeSet(CoreRange(CoreCoord(x, y), CoreCoord(x, y)))); + if (sender_worker_core_range.num_cores() == num_workers_preferred) { + break; + } + } + if (sender_worker_core_range.num_cores() == num_workers_preferred) { + break; + } + } + if (sender_worker_core_range.num_cores() == num_workers_preferred) { + break; + } + } + } else { + sender_worker_core_range = + CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(num_workers_per_link - 1, num_links - 1))); + } + return {sender_worker_core_range, corerange_to_cores(sender_worker_core_range, std::nullopt, true)}; +} + +// For ring all-gather, we can send sub-sections of input tensor in opposite directions +// For linear all-gather though, we must ensure we send full tensors in BOTH directions +// (in other words, disable the "bidirectional" send flag) +operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const std::optional>& semaphore_handle_opt, + bool enable_persistent_fabric_mode) { + tt::tt_metal::Program program{}; + const bool enable_async_output_tensor = false; + + TT_FATAL(semaphore_handle_opt.has_value(), "Semaphore handle is required for compile time"); + + auto semaphore_handle = semaphore_handle_opt.value(); + + Device* device = input_tensor.device(); + bool is_first_chip = ring_index == 0; + bool is_last_chip = ring_index == ring_size - 1; + log_trace( + tt::LogOp, + "DEBUG: device: {}, is_first_chip: {}, is_last_chip: {}", + input_tensor.device()->id(), + is_first_chip, + is_last_chip); + + std::optional local_fabric_handle = + enable_persistent_fabric_mode + ? ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links) + : ccl::EdmLineFabricOpInterface( + device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links); + + LineTopology line_topology(ring_size, ring_index); + + std::unique_ptr input_tensor_config = + ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); + std::unique_ptr output_tensor_config = + ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); + + bool is_sharded = input_tensor.is_sharded(); + + const auto input_buffer = input_tensor.buffer(); + const auto output_buffer = output_tensor.buffer(); + + // Get OP Config, topology config + std::vector input_tensors = {input_tensor}; + std::vector output_tensors = {output_tensor}; + const auto& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + + // Get worker cores, assuming 1 worker per link + uint32_t num_workers_per_link = 1; + const auto [sender_worker_core_range, sender_worker_cores] = + choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device); + + // L1 Scratch CB Creation + const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes(); + uint32_t l1_scratch_cb_page_size_bytes = op_config.get_page_size(); + uint32_t num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; + uint32_t cb_num_pages = 3 * num_pages_per_packet; // tripple buffering + uint32_t src0_cb_index = tt::CB::c_in0; + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_num_pages * l1_scratch_cb_page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, l1_scratch_cb_page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); + + // Create Tensor slicer + // read the entire input tensor (partition size = 1, partition index = 0) + // write to the output tensor on its corresponding partition (partition size = ring_size, partition index = + // ring_index) + auto input_tensor_slicer = ttnn::ccl::GenericWrappedTensorSlicerV2( + input_tensor, + dim, + 0, // partition index + 1, // partition size + num_links // num_workers_per_slicer, set 1 per link for now + ); + auto output_tensor_slicer = ttnn::ccl::GenericWrappedTensorSlicerV2( + output_tensor, + dim, + ring_index, // partition index + ring_size, // partition size + num_links // num_workers_per_slicer, set 1 per link for now + ); + + // KERNEL CREATION + const auto& worker_defines = op_config.emit_worker_defines(); + static const std::string& sender_kernel_reader_path = + "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp"; + static const std::string& sender_kernel_writer_path = + "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp"; + + KernelHandle worker_sender_reader_kernel_id = + ttnn::ccl::worker_detail::generate_multi_command_stream_kernel_ct_args( + program, + {src0_cb_index}, + {&input_tensor}, + sender_worker_core_range, + tt::tt_metal::ReaderDataMovementConfig{}, + 1, // num_command_streams + device->id()); + + KernelHandle worker_sender_writer_kernel_id = + ttnn::ccl::worker_detail::generate_multi_command_stream_kernel_ct_args( + program, + {src0_cb_index}, + {&output_tensor}, + sender_worker_core_range, + tt::tt_metal::WriterDataMovementConfig{}, + 1, // num_command_streams + device->id()); + + const size_t forward_direction_distance_to_end_of_line = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); + const size_t backward_direction_distance_to_end_of_line = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); + + ttnn::ccl::cmd::MulticastCommandDestArgs mcast_dest_args = { + forward_direction_distance_to_end_of_line, backward_direction_distance_to_end_of_line}; + log_trace( + tt::LogOp, + "[mcast_dest_args] num target forward: {}, num target backward: {}", + mcast_dest_args.num_targets_forward_direction, + mcast_dest_args.num_targets_backward_direction); + + auto reader_tensor_slices = + ttnn::ccl::cmd::builder::generate_worker_tensor_slices(1, input_tensor, num_workers_per_link * num_links, dim); + log_trace(tt::LogOp, "reader_tensor_slices size: {}", reader_tensor_slices.size()); + log_trace(tt::LogOp, "reader_tensor_slices[0] size: {}", reader_tensor_slices[0].size()); + + CoreCoord drain_sync_core; + for (std::size_t link = 0; link < num_links; link++) { + CoreCoord core = {num_workers_per_link - 1, link}; + if (link == 0) { + // drain sync core is the first worker core + drain_sync_core = device->worker_core_from_logical_core(core); + } + std::size_t worker_tensor_slice_index = link; + + const auto& input_worker_slice_v2 = input_tensor_slicer.get_worker_slice_v2(worker_tensor_slice_index); + const auto& output_worker_slice_v2 = output_tensor_slicer.get_worker_slice_v2(worker_tensor_slice_index); + + log_trace(tt::LogOp, "DEBUG: input tensor slice v2:"); + print_tensor_slice(input_worker_slice_v2); + log_trace(tt::LogOp, "DEBUG: output tensor slice v2:"); + print_tensor_slice(output_worker_slice_v2); + + std::optional forward_fabric_connection = + line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::FORWARD)); + std::optional backward_fabric_connection = + line_topology.is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::BACKWARD)); + + log_trace( + tt::LogOp, + "DEBUG: line_index: {}, line_size: {}, forward_fabric_connection: {}", + line_topology.line_index(), + line_topology.line_size(), + forward_fabric_connection.has_value()); + log_trace( + tt::LogOp, + "DEBUG: line_index: {}, line_size: {}, backward_fabric_connection: {}", + line_topology.line_index(), + line_topology.line_size(), + backward_fabric_connection.has_value()); + + // READER COMMAND STREAM and RT ARGS + std::vector reader_cmd_stream; + reader_cmd_stream.push_back( // use the reader_tensor_slices after the bug is fixed + ttnn::ccl::cmd::uops::read_tensor_slice_to_cb_for_eventual_fabric_write( + input_worker_slice_v2, src0_cb_index)); + + ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( + program, + worker_sender_reader_kernel_id, + {&input_tensor}, + {op_config.get_page_size()}, + input_tensor.device(), + num_pages_per_packet, + {core}, + reader_cmd_stream, + std::nullopt, + std::nullopt, + std::nullopt); + + // WRITER COMMAND STREAM and RT ARGS + std::vector writer_cmd_stream; + // 1, do mcast of the tensor slice to all the destinations + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice( + output_worker_slice_v2, src0_cb_index, mcast_dest_args)); + // 2, mcast the semaphore to all dest for teardown + TT_FATAL( + semaphore_handle != nullptr, + "Internal error during all-=gather fatcory. Global semaphore for fabric teardown not properly " + "initialized for non-persistent fabric mode"); + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_multicast_semaphore_inc( + semaphore_handle.get(), + ttnn::ccl::cmd::CclCommandAtomicInc{1}, + drain_sync_core.x, + drain_sync_core.y, + mcast_dest_args)); + if (!enable_async_output_tensor) { + // 3, wait for n_chip*num_links number of semaphore at teardown semaphore address for first chip, and + // n_chip*num_links+1 for other chips + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_semaphore_wait( + semaphore_handle.get(), + is_first_chip ? ring_size * num_links : ring_size * num_links + !enable_persistent_fabric_mode)); + } + + bool generate_teardown_commands = !enable_persistent_fabric_mode && link == 0; + if (generate_teardown_commands) { + // 4, send semaphore unicast to forward device except for the last chip + if (!is_last_chip) { + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_unicast_semaphore_inc( + semaphore_handle.get(), + ttnn::ccl::cmd::CclCommandAtomicInc{1}, + drain_sync_core.x, + drain_sync_core.y, + ttnn::ccl::cmd::UnicastCommandDestArgs{1, true})); + } + // 5, increment the termination semaphore for local device for local teardown only for the drain sync core + auto termination_infos = local_fabric_handle->generate_local_chip_fabric_termination_infos(device); + for (auto& info : termination_infos) { + if (info.distance != 0) { + continue; + } + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_chip_noc_absolute_address_semaphore_inc( + info.edm_noc_x, info.edm_noc_y, info.termination_addr, 1)); + } + // 6. (drain sync core) reset semaphore to 0 + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_core_semaphore_set(semaphore_handle.get(), 0)); + } + + // set the rt args + ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args( + program, + worker_sender_writer_kernel_id, + {&output_tensor}, + {op_config.get_page_size()}, + output_tensor.device(), + num_pages_per_packet, // num_pages_per_edm_buffer + {core}, + writer_cmd_stream, + std::nullopt, + {forward_fabric_connection}, + {backward_fabric_connection}); + } + + if (!enable_persistent_fabric_mode) { + local_fabric_handle->build_kernels(); + } + + auto override_runtime_arguments_callback = + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore_handle, sender_worker_cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; + + // update senders + auto& worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); + auto& worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); + for (const auto& core : sender_worker_cores) { + // reader + auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; + worker_reader_sender_runtime_args.at(0) = input.buffer()->address(); + // writer + auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; + worker_writer_sender_runtime_args.at(0) = output.buffer()->address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.cpp new file mode 100644 index 00000000000..02d57b213fd --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.cpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.hpp" +#include "ttnn/operations/experimental/ccl/all_reduce/all_reduce_pybind.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.hpp" +#include "ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::experimental::ccl { + +void py_module(py::module& module) { + ccl::py_bind_all_gather_matmul(module); + ccl::py_bind_all_reduce(module); + ccl::py_bind_all_gather_async(module); + ccl::py_bind_reduce_scatter_async(module); +} + +} // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.hpp new file mode 100644 index 00000000000..91c022ea0fd --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_pybind.hpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace py = pybind11; + +namespace ttnn::operations::experimental::ccl { + +void py_module(py::module& module); + +} // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp new file mode 100644 index 00000000000..27053e75455 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp @@ -0,0 +1,374 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp" +#include "sub_device/sub_device_types.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" + +#include +#include +#include +#include + +namespace ttnn { +namespace ccl { +namespace reduce_scatter_detail { + +ReduceScatterAsync create_reduce_scatter_struct( + const Tensor& input_tensor, + const ttnn::operations::binary::BinaryOpType binary_op_type, + const uint32_t scatter_dim, + const MemoryConfig& output_mem_config, + const std::vector& devices, + const ttnn::ccl::Topology topology, + std::optional> forward_output_tensors, + std::optional> backward_output_tensors, + std::optional num_links_preferred, + const std::optional>>& from_remote_sems, + const std::optional>>& to_remote_sems, + std::optional sub_device_id, + std::optional& fabric_handle) { + uint32_t num_devices = devices.size(); + + auto [device_index, sender_device_id, receiver_device_id] = + get_device_index_and_sender_receiver_ids(input_tensor, devices, topology); + + TT_FATAL( + receiver_device_id != std::nullopt || sender_device_id != std::nullopt, + "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be " + "identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be " + "incorrect"); + + auto find_device = [](const std::vector& devices, std::optional id) -> std::optional { + if (id == std::nullopt) { + return std::nullopt; + } + auto device = std::find_if( + devices.begin(), devices.end(), [id_ = id.value()](Device const* d) { return d->id() == id_; }); + TT_FATAL( + device != devices.end(), + "Device with ID {} not found in the list of devices, but it should be here since it was provided " + "previously", + id.value()); + return *device; + }; + + std::optional> from_remote_sem = std::nullopt; + std::optional> to_remote_sem = std::nullopt; + if (from_remote_sems.has_value()) { + from_remote_sem = from_remote_sems.value().at(device_index); + } + if (to_remote_sems.has_value()) { + to_remote_sem = to_remote_sems.value().at(device_index); + } + + return ttnn::ReduceScatterAsync{ + binary_op_type, + scatter_dim, + num_devices, + device_index, + find_device(devices, receiver_device_id), + find_device(devices, sender_device_id), + output_mem_config, + topology, + forward_output_tensors, + backward_output_tensors, + num_links_preferred, + from_remote_sem, + to_remote_sem, + sub_device_id, + fabric_handle}; +} +} // namespace reduce_scatter_detail +} // namespace ccl + +void ReduceScatterAsync::validate(const std::vector& input_tensors) const { + for (auto const& t : input_tensors) { + TT_FATAL( + t.get_legacy_shape()[this->scatter_dim] / this->ring_size > 0, + "Reduce scatter input tensor shape on dim {} must be divisible by ring size", + this->scatter_dim); + TT_FATAL( + t.get_legacy_shape()[this->scatter_dim] % this->ring_size == 0, + "Reduce scatter input tensor shape on dim {} must be divisible by ring size", + this->scatter_dim); + } +} + +std::vector ReduceScatterAsync::compute_output_shapes( + const std::vector& input_tensors) const { + auto shape = input_tensors[0].get_logical_shape(); + TT_FATAL( + shape[this->scatter_dim] % this->ring_size == 0, + "The size of the scatter dimension must be a multiple of the ring size. Dimension size: {}, ring Size: {}", + shape[this->scatter_dim], + this->ring_size); + shape[this->scatter_dim] /= this->ring_size; + return std::vector(input_tensors.size(), shape); +} + +std::vector ReduceScatterAsync::create_output_tensors(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + // output tensors + // 0. final (real) output_tensor + // 1. input_tensor_from_remote_forward_direction (shape of input tensor) + // 2. input_tensor_from_remote_backward_direction (shape of input tensor) + // 3. partial_output_tensor_forward_direction (shape of output tensor) + // 4. partial_output_tensor_backward_direction (shape of output tensor) + + bool is_tile_layout = input_tensor.get_layout() == Layout::TILE; + std::optional tile = + is_tile_layout ? input_tensor.get_tensor_spec().tile() : std::optional(std::nullopt); + + std::vector output_tensors; + output_tensors.reserve(5); + // real_output_tensor + output_tensors.emplace_back(create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + this->output_mem_config, + tile)); + // temporary_input_from_remote_tensor_for_forward_direction + output_tensors.emplace_back(create_device_tensor( + input_tensor.shape(), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + input_tensor.memory_config(), + tile)); + // temporary_input_from_remote_tensor_for_backward_direction + output_tensors.emplace_back(create_device_tensor( + input_tensor.shape(), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + input_tensor.memory_config(), + tile)); + // temporary_partial_output_tensor_for_forward_direction + output_tensors.emplace_back(create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + this->output_mem_config, + tile)); + // temporary_partial_output_tensor_for_backward_direction + output_tensors.emplace_back(create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + this->output_mem_config, + tile)); + + return output_tensors; +} + +operation::ProgramWithCallbacks ReduceScatterAsync::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + std::optional foreward_direction_remote_output_tensor = std::nullopt; + std::optional backward_direction_remote_output_tensor = std::nullopt; + return ccl::reduce_scatter_detail::build_reduce_scatter_async_program( + input_tensors.at(0), // true input_tensor + output_tensors.at(0), // final output_tensor + output_tensors.at(1), // input_tensor_from_remote_forward_direction + output_tensors.at(2), // input_tensor_from_remote_backward_direction + output_tensors.at(3), // partial_output_tensor_forward_direction + output_tensors.at(4), // partial_output_tensor_backward_direction + foreward_direction_remote_output_tensor, + backward_direction_remote_output_tensor, + this->forward_device, + this->backward_device, + this->binary_op_type, + this->scatter_dim, + this->ring_size, + this->ring_index, + this->topology, + this->num_links_preferred, + this->from_remote_sem, + this->to_remote_sem, + this->fabric_handle); +} + +operation::Hash ReduceScatterAsync::compute_program_hash(const std::vector& input_tensors) const { + return operation::hash_operation( + this->binary_op_type, + this->scatter_dim, + this->ring_size, + this->ring_index, + this->topology, + this->from_remote_sem.has_value() ? this->from_remote_sem.value().get() : nullptr, + this->to_remote_sem.has_value() ? this->to_remote_sem.value().get() : nullptr); +} + +namespace { +namespace CMAKE_UNIQUE_NAMESPACE { +ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type( + ttnn::operations::reduction::ReduceType reduce_op) { + // Leaving switch statement for future support of additional types. + switch (reduce_op) { + case ttnn::operations::reduction::ReduceType::Sum: return ttnn::operations::binary::BinaryOpType::ADD; + default: + TT_THROW("Reduce scatter only supports reduce_type Sum. Op type {} not supported.", reduce_op); + return ttnn::operations::binary::BinaryOpType::ADD; + } +} +} // namespace CMAKE_UNIQUE_NAMESPACE +} // namespace + +std::vector> create_global_semaphores( + const std::vector& devices, std::optional worker_subdevice_id_opt = std::nullopt) { + std::vector> semaphores; + auto worker_cores = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(6, 6))); + for (Device* d : devices) { + CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); + auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + auto worker_subdevice_id = worker_subdevice_id_opt.has_value() + ? tt::stl::Span{{worker_subdevice_id_opt.value()}} + : tt::stl::Span{}; + auto sem = CreateGlobalSemaphore(d, core_grid, 0, BufferType::L1, worker_subdevice_id); + semaphores.push_back(sem); + } + + auto first_addr = semaphores.front()->address(); + bool all_same = std::all_of( + semaphores.begin(), semaphores.end(), [first_addr](const auto& sem) { return sem->address() == first_addr; }); + + if (!all_same) { + DeviceAddr highest_addr = semaphores.front()->address(); + for (auto i = 1; i < semaphores.size(); i++) { + if (semaphores[i]->address() > highest_addr) { + highest_addr = semaphores[i]->address(); + } + }; + for (auto i = 0; i < semaphores.size(); i++) { + size_t attempts = 1000; + size_t attempt = 0; + std::vector> garbage; + CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); + auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + while (semaphores[i]->address() != highest_addr) { + auto worker_subdevice_id = worker_subdevice_id_opt.has_value() + ? tt::stl::Span{worker_subdevice_id_opt.value()} + : tt::stl::Span{}; + auto sem = CreateGlobalSemaphore(devices[i], core_grid, 0, BufferType::L1, worker_subdevice_id); + if (sem->address() == highest_addr) { + semaphores[i] = sem; + } else { + garbage.push_back(std::move(sem)); + attempt++; + } + + if (attempt > attempts) { + TT_THROW("Failed to create global semaphores with the same address"); + } + } + } + } + return semaphores; +} + +namespace operations { +namespace experimental { +namespace ccl { +Tensor reduce_scatter( + const Tensor& input_tensor, + const int32_t dim, + ttnn::operations::reduction::ReduceType math_op, + const MemoryConfig& output_mem_config, + ttnn::ccl::Topology topology, + const std::optional num_links_preferred, + std::optional worker_subdevice_id_opt, + bool create_semaphore_handles, + std::optional fabric_handle) { + using namespace CMAKE_UNIQUE_NAMESPACE; + ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); + TT_FATAL( + std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "reduce_scatter op is only supported for Fast Dispatch"); + + ttnn::ccl::Topology ccl_topology = topology; + auto devices = input_tensor.get_workers(); + uint32_t num_devices = devices.size(); + TT_FATAL(num_devices > 1, "reduce_scatter op will only work for num_devices > 1, but has {}", num_devices); + if (num_devices == 2) { + ccl_topology = ttnn::ccl::Topology::Linear; + } + + int16_t rank = input_tensor.get_logical_shape().rank(); + int16_t scatter_dim = (dim < 0) ? rank + dim : dim; + TT_FATAL( + scatter_dim >= -rank && scatter_dim <= rank - 1, + "Dimension input should be in between -{} and {}, but has {}", + rank, + rank - 1, + dim); + + std::optional>> from_remote_inputs_semaphores_opt; + std::optional>> to_remote_inputs_semaphores_opt; + if (create_semaphore_handles) { + const auto from_remote_inputs_semaphores = create_global_semaphores(devices, worker_subdevice_id_opt); + const auto to_remote_inputs_semaphores = create_global_semaphores(devices, worker_subdevice_id_opt); + from_remote_inputs_semaphores_opt = from_remote_inputs_semaphores; + to_remote_inputs_semaphores_opt = to_remote_inputs_semaphores; + } else { + from_remote_inputs_semaphores_opt = std::nullopt; + to_remote_inputs_semaphores_opt = std::nullopt; + } + + std::vector output_tensors = { + Tensor(operation::get_workers_for_op_output({input_tensor})), + Tensor(operation::get_workers_for_op_output({input_tensor})), + Tensor(operation::get_workers_for_op_output({input_tensor})), + Tensor(operation::get_workers_for_op_output({input_tensor})), + Tensor(operation::get_workers_for_op_output({input_tensor}))}; + TT_FATAL( + output_tensors.size() == 5, + "Reduce scatter requires 5 output tensors. 1 is real and the others are temporaries"); + operation::launch_op( + [binary_op_type, + from_remote_inputs_semaphores_opt, + to_remote_inputs_semaphores_opt, + scatter_dim, + output_mem_config, + ccl_topology, + devices, + num_links_preferred, + output_tensors, + worker_subdevice_id_opt, + fabric_handle]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& input_tensor = input_tensors.at(0); + + return operation::run( + ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct( + input_tensor, + binary_op_type, + scatter_dim, + output_mem_config, + devices, + ccl_topology, + std::nullopt, + std::nullopt, + num_links_preferred, + from_remote_inputs_semaphores_opt, + to_remote_inputs_semaphores_opt, + worker_subdevice_id_opt, + fabric_handle), + {input_tensor}); + }, + {input_tensor}, + output_tensors); + return output_tensors.at(0); +} + +} // namespace ccl +} // namespace experimental +} // namespace operations + +}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp new file mode 100644 index 00000000000..5d60c264b2a --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "sub_device/sub_device_types.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/reduction/generic/generic_reductions.hpp" +#include "ttnn/operations/eltwise/binary/binary.hpp" + +namespace ttnn { +struct ReduceScatterAsync { + ReduceScatterAsync( + const ttnn::operations::binary::BinaryOpType binary_op_type, + const uint32_t scatter_dim, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional forward_device, + const std::optional backward_device, + const MemoryConfig& output_mem_config, + const ttnn::ccl::Topology topology, + std::optional>& foreward_output_tensors, + std::optional>& backward_output_tensors, + std::optional num_links_preferred, + const std::optional>& from_remote_sem, + const std::optional>& to_remote_sem, + std::optional& sub_device_id, + std::optional& fabric_handle) : + binary_op_type(binary_op_type), + scatter_dim(scatter_dim), + ring_size(ring_size), + ring_index(ring_index), + forward_device(forward_device), + backward_device(backward_device), + output_mem_config(output_mem_config), + topology(topology), + foreward_output_tensors(foreward_output_tensors), + backward_output_tensors(backward_output_tensors), + num_links_preferred(num_links_preferred), + from_remote_sem(from_remote_sem), + to_remote_sem(to_remote_sem), + fabric_handle(fabric_handle), + sub_device_id(sub_device_id) {} + + const ttnn::operations::binary::BinaryOpType binary_op_type; + const uint32_t scatter_dim; + const uint32_t ring_size; + const uint32_t ring_index; + const std::optional forward_device; + const std::optional backward_device; + const MemoryConfig output_mem_config; + const ttnn::ccl::Topology topology; + // const + std::optional> foreward_output_tensors; + std::optional> backward_output_tensors; + std::optional num_links_preferred; + std::optional> from_remote_sem; + std::optional> to_remote_sem; + std::optional& fabric_handle; + std::optional sub_device_id; + + auto attributes() const { + using tt::stl::reflection::Attribute; + std::vector> attrs; + + attrs.emplace_back("binary_op_type", binary_op_type); + attrs.emplace_back("dim", scatter_dim); + attrs.emplace_back("ring_size", ring_size); + attrs.emplace_back("ring_index", ring_index); + attrs.emplace_back("forward_device", forward_device); + attrs.emplace_back("backward_device", backward_device); + attrs.emplace_back("num_links_preferred", num_links_preferred); + attrs.emplace_back("output_mem_config", output_mem_config); + attrs.emplace_back("topology", topology); + + return attrs; + } + + void validate(const std::vector& input_tensors) const; + std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector create_output_tensors(const std::vector& input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; + operation::Hash compute_program_hash(const std::vector& input_tensors) const; +}; + +namespace ccl { +namespace reduce_scatter_detail { +operation::ProgramWithCallbacks build_reduce_scatter_async_program( + const Tensor& input_tensor, + Tensor& local_output_tensor, + Tensor& input_tensor_from_remote_forward_direction, + Tensor& input_tensor_from_remote_backward_direction, + Tensor& partial_output_tensor_to_forward_direction, + Tensor& partial_output_tensor_to_backward_direction, + std::optional& foreward_direction_remote_output_tensor, + std::optional& backward_direction_remote_output_tensor, + std::optional forward_device, + std::optional backward_device, + ttnn::operations::binary::BinaryOpType reduce_op, + const uint32_t dim, + const uint32_t line_size, + const uint32_t line_index, + ttnn::ccl::Topology topology, + std::optional num_links_preferred, + const std::optional>& from_remote_sem_opt, + const std::optional>& to_remote_sem_opt, + std::optional& fabric_handle); +} +}; // namespace ccl + +namespace ccl { +namespace reduce_scatter_detail { +ReduceScatterAsync create_reduce_scatter_struct( + const Tensor& input_tensor, + const ttnn::operations::binary::BinaryOpType binary_op_type, + const uint32_t dim, + const MemoryConfig& output_mem_config, + const std::vector& devices, + const ttnn::ccl::Topology topology, + std::optional> foreward_output_tensors, + std::optional> backward_output_tensors, + std::optional num_links_preferred, + const std::optional>>& from_remote_sems, + const std::optional>>& to_remote_sems, + std::unordered_map& sub_device_id_map, + std::optional& fabric_handle); +} // namespace reduce_scatter_detail +} // namespace ccl + +namespace operations { +namespace experimental { +namespace ccl { +Tensor reduce_scatter( + const Tensor& input_tensor, + const int32_t dim, + ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear, + const std::optional num_preferred_links = std::nullopt, + std::optional worker_subdevice_id_opt = std::nullopt, // TODO make reference + bool create_semaphore_handles = true, + std::optional fabric_handle = std::nullopt); // TODO make reference + +} // namespace ccl +} // namespace experimental +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp new file mode 100644 index 00000000000..b3464e21ad9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp @@ -0,0 +1,2172 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC. +// +// SPDX-License-Identifier: Apache-2.0 +/// + +#include +#include +#include +#include +#include +#include +#include "common/core_coord.hpp" +#include "common/logger.hpp" +#include "device/device.hpp" +#include "kernels/kernel_types.hpp" +#include "span.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/operation.hpp" + +// For reduction op +#include "ttnn/operations/ccl/common/uops/ccl_host_commands.hpp" +#include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp" +#include "ttnn/operations/eltwise/binary/common/binary_op_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "tt_metal/tt_stl/overloaded.hpp" + +/* + * This file contains the program factory for reduce scatter operation implemented on line (and soon, ring) topologies. + * The current implementation is fairly memory inefficient, however, even when optimized the general approach is as + follows: + * + * Lo + * + * IN 0 IN 1 IN 2 IN 3 OUT 0 OUT 1 OUT 2 OUT 3 + * C0 C1 C2 C3 C0 C1 C2 C3 + * ┌────┐ ┌────┐ ┌────┐ ┌────┐ ┌────┐ ...... ...... ...... + * │ │ │ │ │ │ │ │ │////│ . . . . . . + * │ │ │ │ │ │ │ │ │////│ . . . . . . + * │ │ │ │ │ │ │ │ │////│ . . . . . . + * ├────┤ ├────┤ ├────┤ ├────┤ └────┘ ┌────┐ ...... ...... + * │ │ │ │ │ │ │ │ . . │////│ . . . . + * │ │ │ │ │ │ │ │ . . │////│ . . . . + * │ │ │ │ │ │ │ │ . . │////│ . . . . + * ├────┤ ├────┤ ├────┤ ├────┤ ────► ...... └────┘ ┌────┐ ...... + * │ │ │ │ │ │ │ │ . . . . │////│ . . + * │ │ │ │ │ │ │ │ . . . . │////│ . . + * │ │ │ │ │ │ │ │ . . . . │////│ . . + * ├────┤ ├────┤ ├────┤ ├────┤ ...... ...... └────┘ ┌────┐ + * │ │ │ │ │ │ │ │ . . . . . . │////│ + * │ │ │ │ │ │ │ │ . . . . . . │////│ + * │ │ │ │ │ │ │ │ . . . . . . │////│ + * └────┘ └────┘ └────┘ └────┘ ...... ...... ...... └────┘ + * + * + * + * + + + + * + * ┌────┐ ┌────┐ ┌────┐ ┌────┐ + * ├─►+◄┼──────┼─ +◄┼─────┼──+◄┼─────┼── │ + * │ │ │ ▲ │ │ ▲ │ │ │ + * │ │ │ └ │ │ └ │ │ │ + * ┼────┼ ┼────┼ ┼────┼ ┼────┼ + * │ │ │ ┌ │ │ ┌ │ │ │ + * │ │ │ ▼ │ │ ▼ │ │ │ + * │ ───┼──────┼►+◄─┼─────┼──+◄┼─────┼── │ + * ┼────┼ ┼────┼ ┼────┼ ┼────┼ + * │ │ │ ┌ │ │ ┌ │ │ │ + * │ │ │ ▼ │ │ ▼ │ │ │ + * │ ──┼──────┼►+──┼─────┼─►+◄┼─────┼── │ + * ┼────┼ ┼────┼ ┼────┼ ┼────┼ + * │ │ │ ┌ │ │ ┌ │ │ ┌ │ + * │ │ │ ▼ │ │ ▼ │ │ ▼ │ + * │ ──┼──────┼►+──┼─────┼►+ ─┼─────┼►+ │ + * └────┘ └────┘ └────┘ └────┘ + * + * + */ + +namespace ttnn::ccl::reduce_scatter_detail { + +using ttnn::ccl::Shape4D; +using ttnn::ccl::cmd::CclCommandTensor; + +enum fabric_lifetime_mode { + // The fabric's lifetime exceed (before and after) the lifetime of the op + // so the op should not in any way manage fabric lifetime + PERSISTENT, + + // The fabric is brought up and torn down for each op invocation + TRANSIENT +}; + +enum LineDirection { FORWARD, BACKWARD }; +static_assert( + static_cast(LineDirection::FORWARD) == + static_cast(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD)); +static_assert( + static_cast(LineDirection::BACKWARD) == + static_cast(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD)); + +constexpr LineDirection relay_to_final_output_dir = LineDirection::FORWARD; +// TODO: promote to header + +struct ReduceScatterCircularBuffers { + uint32_t reader_to_writer_shortcut_cb = -1; + uint32_t reader_to_math_operand0_cb = -1; + uint32_t reader_to_math_operand1_cb = -1; + uint32_t math_to_writer_cb = -1; + CBHandle reader_to_writer_shortcut_cb_handle = -1; + CBHandle reader_to_math_operand0_cb_handle = -1; + CBHandle reader_to_math_operand1_cb_handle = -1; + CBHandle math_to_writer_cb_handle = -1; +}; + +struct CircularBufferSpec { + size_t cb_size = 0; + size_t page_size = 0; + uint32_t cb_index = 0; + tt::DataFormat df = tt::DataFormat::Invalid; +}; + +struct ReduceScatterKernelHandles { + KernelHandle reader = -1; + KernelHandle math = -1; + KernelHandle writer = -1; +}; + +// We really need something like a graph here to describe the dependencies generically but for +// now we keep it very simple and constrained +struct TensorSyncSpec { + static constexpr int UNINITIALIZED_DEST_NOC = -1; + struct target_rect { + int dest_noc0_x_start = UNINITIALIZED_DEST_NOC; + int dest_noc0_y_start = UNINITIALIZED_DEST_NOC; + int dest_noc0_x_end = UNINITIALIZED_DEST_NOC; + int dest_noc0_y_end = UNINITIALIZED_DEST_NOC; + }; + // always equal to number of slices for now + std::vector> semaphore_ids; + std::vector completion_target_value_per_semaphore; + std::vector targets; + + ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast get_target(size_t i) const { + return ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast{ + static_cast(targets.at(i).dest_noc0_x_start), + static_cast(targets.at(i).dest_noc0_y_start), + static_cast(targets.at(i).dest_noc0_x_end), + static_cast(targets.at(i).dest_noc0_y_end)}; + } + + size_t num_semaphores() const { return semaphore_ids.size(); } + + std::variant const& get_tensor_sync_semaphore(size_t slice_index) const { + TT_FATAL( + slice_index < semaphore_ids.size(), + "Internal error. Requested semaphore id for slice index that does not exist"); + return semaphore_ids.at(slice_index); + } +}; + +struct WorkerCoreBundle { + CoreRangeSet all_worker_cores; + CoreRangeSet final_reducers; + std::array partial_reducers; + + std::vector all_worker_cores_vec; + std::vector final_reducers_vec; + std::array, 2> partial_reducers_vec; +}; + +struct ProgramTensorsBundle { + Tensor const* input_tensor = nullptr; + std::optional input_tensor_sync; + Tensor* local_output_tensor = nullptr; + std::optional local_output_sync; + std::array input_tensor_from_remote = {nullptr, nullptr}; + std::array input_tensor_from_remote_sync; + std::array remote_output = {nullptr, nullptr}; + std::array remote_output_sync; + std::array local_output_partial = {nullptr, nullptr}; + std::array local_output_partial_sync; + + static Tensor* build_handle(Tensor& tensor) { return &tensor; } + static Tensor const* build_handle(Tensor const& tensor) { return &tensor; } + static Tensor* build_handle(std::optional& tensor) { + return tensor.has_value() ? &tensor.value() : nullptr; + } +}; + +static ReduceScatterCircularBuffers create_worker_circular_buffers( + tt::tt_metal::Program& program, + CoreRangeSet const& worker_core_range, + CircularBufferSpec const& shortcut_cb_spec, + CircularBufferSpec const& reader_to_math0_cb_spec, + CircularBufferSpec const& reader_to_math1_cb_spec, + CircularBufferSpec const& math_to_writer_cb_spec) { + TT_FATAL( + shortcut_cb_spec.cb_size % shortcut_cb_spec.page_size == 0, + "Shortcut circular buffer size must be a multiple of the page size"); + TT_FATAL( + reader_to_math0_cb_spec.cb_size % reader_to_math0_cb_spec.page_size == 0, + "Reader to math circular buffer size must be a multiple of the page size"); + TT_FATAL( + reader_to_math1_cb_spec.cb_size % reader_to_math1_cb_spec.page_size == 0, + "Reader to math circular buffer size must be a multiple of the page size"); + TT_FATAL( + math_to_writer_cb_spec.cb_size % math_to_writer_cb_spec.page_size == 0, + "Math to writer circular buffer size must be a multiple of the page size"); + + auto generate_circular_buffer = [&program, &worker_core_range](CircularBufferSpec const& cb_spec) -> CBHandle { + tt::tt_metal::CircularBufferConfig cb_config = + tt::tt_metal::CircularBufferConfig(cb_spec.cb_size, {{cb_spec.cb_index, cb_spec.df}}) + .set_page_size(cb_spec.cb_index, cb_spec.page_size); + CBHandle cb_handle = CreateCircularBuffer(program, worker_core_range, cb_config); + return cb_handle; + }; + + return ReduceScatterCircularBuffers{ + shortcut_cb_spec.cb_index, + reader_to_math0_cb_spec.cb_index, + reader_to_math1_cb_spec.cb_index, + math_to_writer_cb_spec.cb_index, + generate_circular_buffer(shortcut_cb_spec), + generate_circular_buffer(reader_to_math0_cb_spec), + generate_circular_buffer(reader_to_math1_cb_spec), + generate_circular_buffer(math_to_writer_cb_spec)}; +} + +static ReduceScatterCircularBuffers create_worker_circular_buffers( + tt::tt_metal::Program& program, + CoreRangeSet const& worker_core_range, + tt::DataFormat df, + + const tt::CBIndex math_in0_cb, + const tt::CBIndex math_in1_cb, + const tt::CBIndex math_out_cb, + const tt::CBIndex pass_through_cb, + size_t fabric_buffer_size_pages, + size_t page_size) { + size_t buffer_depth_multiplier = 3; + auto cb_handles = create_worker_circular_buffers( + program, + worker_core_range, + CircularBufferSpec{ + buffer_depth_multiplier * fabric_buffer_size_pages * page_size, + page_size, + pass_through_cb, + df}, + CircularBufferSpec{ + buffer_depth_multiplier * fabric_buffer_size_pages * page_size, + page_size, + math_in0_cb, + df}, + CircularBufferSpec{ + buffer_depth_multiplier * fabric_buffer_size_pages * page_size, + page_size, + math_in1_cb, + df}, + CircularBufferSpec{ + buffer_depth_multiplier * fabric_buffer_size_pages * page_size, + page_size, + math_out_cb, + df}); + + TT_FATAL(cb_handles.math_to_writer_cb != -1, "Math to writer circular buffer handle is invalid"); + TT_FATAL(cb_handles.reader_to_math_operand0_cb != -1, "Reader to math0 circular buffer handle is invalid"); + TT_FATAL(cb_handles.reader_to_math_operand1_cb != -1, "Reader to math1 circular buffer handle is invalid"); + TT_FATAL( + cb_handles.reader_to_writer_shortcut_cb != -1, "Reader to writer shortcut circular buffer handle is invalid"); + return cb_handles; +} + +template +static std::vector vslice(std::vector const& vec, std::size_t start, std::size_t end_inclusive) { + TT_FATAL(end_inclusive < vec.size(), "Out of bounds access in vslice for vector of size {}. Requested end_inclusive index {}.", vec.size(), end_inclusive); + TT_FATAL(start < vec.size(), "Out of bounds access in vslice for vector of size {}. Requested start index {}.", vec.size(), start); + std::vector output; + if (start > end_inclusive) { + size_t n_elem = start - end_inclusive + 1; + output.reserve(n_elem); + std::copy( + vec.rbegin() + (vec.size() - 1 - start), + vec.rbegin() + (vec.size() - 1 - start + n_elem), + std::back_inserter(output)); + + } else { + output.reserve(end_inclusive - start + 1); + std::copy(vec.begin() + start, vec.begin() + end_inclusive + 1, std::back_inserter(output)); + } + return output; +} + +class LineTopology { +public: + LineTopology(size_t line_size, size_t line_index) : _line_size(line_size), _line_index(line_index) {} + + bool is_first_device_in_line(LineDirection direction) const { + if (direction == LineDirection::FORWARD) { + return _line_index == 0; + } else { + TT_ASSERT(direction == LineDirection::BACKWARD); + return _line_index == _line_size - 1; + } + } + bool is_last_device_in_line(LineDirection direction) const { + if (direction == LineDirection::BACKWARD) { + return _line_index == 0; + } else { + TT_ASSERT(direction == LineDirection::FORWARD); + return _line_index == _line_size - 1; + } + } + + bool is_at_end_of_line() const { return _line_index == 0 || _line_index == _line_size - 1; } + + size_t line_size() const { return _line_size; } + + size_t line_index() const { return _line_index; } + + ttnn::ccl::Topology topology() const { return ttnn::ccl::Topology::Linear; } + +private: + size_t _line_size; + size_t _line_index; +}; + +struct TensorSyncBundle { + const Tensor* tensor; + std::optional sync_spec; +}; + +struct ReaderCircularBufferIds { + uint32_t pass_through; + uint32_t math_in0; + uint32_t math_in1; +}; + +struct WriterCircularBufferIds { + uint32_t pass_through; + uint32_t math_out; +}; +struct FinalReducerReaderCircularBufferIds { + uint32_t math_in0; + uint32_t math_in1; +}; +struct FinalReducerWriterCircularBufferIds { + uint32_t math_out; +}; +struct LineStartReaderCircularBufferIds { + uint32_t pass_through; +}; +struct LineStartWriterCircularBufferIds { + uint32_t pass_through; +}; +struct LineEndReaderCircularBufferIds { + uint32_t math_in0; + uint32_t math_in1; +}; +struct LineEndWriterCircularBufferIds { + uint32_t math_out; +}; + +struct AllReduceScatterCircularBufferIds { + ReaderCircularBufferIds partial_reducer_reader; + WriterCircularBufferIds partial_reducer_writer; + FinalReducerReaderCircularBufferIds final_reducer_reader; + FinalReducerWriterCircularBufferIds final_reducer_writer; + LineStartReaderCircularBufferIds line_start_reader; + LineStartWriterCircularBufferIds line_start_writer; + LineEndReaderCircularBufferIds line_end_reader; + LineEndWriterCircularBufferIds line_end_writer; +}; + +struct WorkerCommandStreams { + std::unordered_map reader_cmds0; + std::unordered_map reader_cmds1; + std::unordered_map writer_cmds0; + std::unordered_map writer_cmds1; +}; + +struct ReduceScatterBuilderConfig { + std::reference_wrapper program; + Device* device; + Device* forward_device; + Device* backward_device; + std::reference_wrapper fabric; + std::reference_wrapper all_tensors; + std::reference_wrapper kernel_ids; + std::reference_wrapper all_cbs; + std::reference_wrapper topology_config; + std::reference_wrapper worker_cores; + size_t page_size = std::numeric_limits::max(); + size_t pages_per_cb_packet = std::numeric_limits::max(); + size_t dim = std::numeric_limits::max(); +}; + + +static WorkerCoreBundle select_worker_cores_for_line_topology(size_t num_links, Device *device) { + + auto build_all_workers_list = [](CoreRangeSet const& available_cores, size_t total_cores_needed, std::vector &all_cores_out) { + + for (const auto& cr : available_cores.ranges()) { + auto start = cr.start_coord; + auto end = cr.end_coord; + for (size_t y = start.y; y <= end.y; y++) { + for (size_t x = start.x; x <= end.x; x++) { + all_cores_out.push_back(CoreCoord(x, y)); + if (all_cores_out.size() == total_cores_needed) { + return; + } + } + } + } + }; + + static constexpr std::size_t num_directions_per_line = 2; + WorkerCoreBundle worker_cores; + size_t current_chunk = 0; + + constexpr size_t num_final_reducers_per_link = 1; + constexpr size_t per_link_num_workers_needed = num_directions_per_line + num_final_reducers_per_link; + const size_t total_cores_needed = per_link_num_workers_needed * num_links; + const auto available_cores = + device->worker_cores(HalProgrammableCoreType::TENSIX, device->get_sub_device_ids().at(0)); + if (available_cores.num_cores() < total_cores_needed) { + log_warning( + tt::LogOp, + "AllGather is being launched on a subdevice with fewer worker cores available than ideal. Ideally {} " + "cores are available ({} per link and {} links) are made available but only {} are available. This may " + "lead to performance loss.", + total_cores_needed, + per_link_num_workers_needed, + num_links, + available_cores.num_cores()); + TT_THROW("Reduce scatter async currently doesn't support running with fewer than preferred number of workers"); + } + std::vector all_cores; + all_cores.reserve(total_cores_needed); + build_all_workers_list(available_cores, total_cores_needed, all_cores); + + auto add_workers = [&num_links](std::vector::iterator &worker_iter, CoreRangeSet& cores) { + for (size_t l = 0; l < num_links; l++) { + cores = cores.merge(CoreRangeSet(CoreRange(*worker_iter))); + worker_iter++; + } + }; + + auto worker_coord_iter = all_cores.begin(); + for (size_t d = 0; d < num_directions_per_line; d++) { + add_workers(worker_coord_iter, worker_cores.partial_reducers[d]); + } + add_workers(worker_coord_iter, worker_cores.final_reducers); + + // Merge them all into the global set for convenience anywhere we want to access all worker cores easily + for (size_t d = 0; d < num_directions_per_line; d++) { + worker_cores.all_worker_cores = worker_cores.all_worker_cores.merge(worker_cores.partial_reducers[d]); + } + worker_cores.all_worker_cores = worker_cores.all_worker_cores.merge(worker_cores.final_reducers); + log_trace(tt::LogOp, "Worker cores: ", worker_cores.all_worker_cores); + + worker_cores.all_worker_cores_vec = corerange_to_cores(worker_cores.all_worker_cores, std::nullopt, true); + worker_cores.final_reducers_vec = corerange_to_cores(worker_cores.final_reducers, std::nullopt, true); + for (size_t d = 0; d < num_directions_per_line; d++) { + worker_cores.partial_reducers_vec[d] = corerange_to_cores(worker_cores.partial_reducers[d], std::nullopt, true); + } + + return worker_cores; +} + + + +/* + * Core range sets for line topology + * BORROWED FROM REDUCE SCATTER but modified a fair bit + * TODO: COMMONIZE + */ +/* +static WorkerCoreBundle select_worker_cores_for_line_topology(size_t num_links) { + static constexpr std::size_t num_directions_per_line = 2; + WorkerCoreBundle worker_cores; + size_t current_chunk = 0; + for (size_t d = 0; d < num_directions_per_line; d++) { + worker_cores.partial_reducers[d] = + CoreRangeSet(CoreRange(CoreCoord(0, current_chunk), CoreCoord(num_links - 1, current_chunk))); + current_chunk++; + } + worker_cores.final_reducers = + CoreRangeSet(CoreRange(CoreCoord(0, current_chunk), CoreCoord(num_links - 1, current_chunk))); + current_chunk++; + + // Merge them all into the global set for convenience anywhere we want to access all worker cores easily + for (size_t d = 0; d < num_directions_per_line; d++) { + worker_cores.all_worker_cores = worker_cores.all_worker_cores.merge(worker_cores.partial_reducers[d]); + } + worker_cores.all_worker_cores = worker_cores.all_worker_cores.merge(worker_cores.final_reducers); + log_trace(tt::LogOp, "Worker cores: ", worker_cores.all_worker_cores); + + worker_cores.all_worker_cores_vec = corerange_to_cores(worker_cores.all_worker_cores, std::nullopt, true); + worker_cores.final_reducers_vec = corerange_to_cores(worker_cores.final_reducers, std::nullopt, true); + for (size_t d = 0; d < num_directions_per_line; d++) { + worker_cores.partial_reducers_vec[d] = corerange_to_cores(worker_cores.partial_reducers[d], std::nullopt, true); + } + + return worker_cores; +} +*/ + +/* + * Returns 1 or 2 core range sets. Typically returns only one but in the case of a line reduce scatter where we are at + * the end of the line, then we must split the core range in half (and return 2), one for each direction where half the + * cores will invoke the ccl::send kernel to implement the start of the line and the others will invoke the typical + * reduce scatter worker kernels. BORROWED FROM REDUCE SCATTER + * TODO: COMMONIZE + */ +static WorkerCoreBundle select_worker_cores(ttnn::ccl::Topology const topology, size_t num_links, Device *device) { + switch (topology) { + case ttnn::ccl::Topology::Linear: return select_worker_cores_for_line_topology(num_links, device); + + case ttnn::ccl::Topology::Ring: + TT_THROW("Ring topology support not yet added to async reduce scatter"); + return WorkerCoreBundle{}; + + default: TT_ASSERT(false, "Unsupported topology"); return WorkerCoreBundle{}; + }; +} + +static size_t compute_math_pages_from_tensor_slices( + std::vector const& tensor_slices, size_t pages_per_cb_packet) { + using namespace ttnn::ccl::cmd; + + auto get_slice_vol = [pages_per_cb_packet](ttnn::ccl::v2::TensorSlice const& slice) { + return round_up(slice.worker_slice_shape.volume(), pages_per_cb_packet); + }; + + size_t total_num_pages = 0; + for (auto const& s : tensor_slices) { + total_num_pages += get_slice_vol(s); + } + + return total_num_pages; +} + +/* + * Returns the reader, math, and writer kernels, respectively + */ +static ReduceScatterKernelHandles build_line_reduce_scatter_worker_ct( + Program& program, + ProgramTensorsBundle const& all_tensors, + ReduceScatterCircularBuffers const& cb_handles, + CoreRangeSet const& worker_core_range, + ttnn::operations::binary::BinaryOpType reduce_op) { + using namespace ttnn::ccl::worker_detail; + + // Summary: + // == READER == + // - First CB: shortcut to writer + // - Second CB: to math (local input) + // - Third CB: to math (remote input) + + static std::string const& receiver_kernel_path = + "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp"; + static std::string const& forward_sender_kernel_path = receiver_kernel_path; + static std::string const& reduce_kernel_path = + "ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp"; + + // Generate the reader kernel + auto input_tensor_ptrs = std::vector{ + all_tensors.input_tensor, + all_tensors.input_tensor_from_remote[0] != nullptr ? all_tensors.input_tensor_from_remote[0] + : all_tensors.input_tensor_from_remote[1]}; + TT_FATAL( + input_tensor_ptrs[0] != nullptr && input_tensor_ptrs[1] != nullptr, + "Internal error. Input tensor pointers are null"); + auto reader_kernel_id = generate_multi_command_stream_kernel_ct_args( + program, + // the CBs don't actuall matter for CT args - they will be removed as CT args in the near future + {cb_handles.reader_to_math_operand1_cb /*cb_handles.reader_to_writer_shortcut_cb*/, + cb_handles.reader_to_math_operand0_cb}, + input_tensor_ptrs, + worker_core_range, + ReaderDataMovementConfig{}); + + // Generate the math/reducer kernel + std::vector compute_kernel_args = {}; + constexpr bool fp32_dest_acc_en = false; + constexpr bool math_approx_mode = false; + std::map eltwise_defines = ttnn::operations::binary::utils::get_defines(reduce_op); + auto math_kernel_id = tt::tt_metal::CreateKernel( + program, + reduce_kernel_path, + worker_core_range, + tt::tt_metal::ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_kernel_args, + .defines = eltwise_defines}); + + // Generate the sender kernel + auto const output_tensor_ptrs = std::vector{ + all_tensors.remote_output[0] != nullptr ? all_tensors.remote_output[0] : all_tensors.remote_output[1], + all_tensors.local_output_tensor}; + auto sender_kernel_id = generate_multi_command_stream_kernel_ct_args( + program, + {cb_handles.reader_to_writer_shortcut_cb, cb_handles.math_to_writer_cb}, + output_tensor_ptrs, + worker_core_range, + WriterDataMovementConfig{}); + + return ReduceScatterKernelHandles{reader_kernel_id, math_kernel_id, sender_kernel_id}; +} + +static size_t get_page_size(const Tensor& tensor) { + if (tensor.get_layout() == Layout::TILE) { + auto dtype = tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); + return tensor.get_tensor_spec().tile().get_tile_size(dtype); + } else { + return tensor.buffer()->page_size(); + } +} + +static void validate_final_reducer_reader_worker_slices( + std::vector> const& in0_worker_slices, + std::vector> const& in1_worker_slices, + std::optional const& in0_sync, + std::optional const& in1_sync, + size_t num_workers) { + TT_FATAL(in0_sync.has_value(), "Internal error. Final reducer saw that in0 had not tensor synchronization info"); + TT_FATAL(in1_sync.has_value(), "Internal error. Final reducer saw that in1 had not tensor synchronization info"); + TT_FATAL( + in0_worker_slices.size() == num_workers, + "Internal error. Expected number of worker slices to match number of workers"); + TT_FATAL( + in1_worker_slices.size() == num_workers, + "Internal error. Expected number of worker slices to match number of workers"); + for (size_t w = 0; w < num_workers; w++) { + TT_FATAL(in0_worker_slices[w].size() == 1, "Internal error. Expected only one slice per worker"); + TT_FATAL(in1_worker_slices[w].size() == 1, "Internal error. Expected only one slice per worker"); + } +} + +static void generate_final_reducer_reader_worker_command_streams( + ReduceScatterBuilderConfig& builder_config, + TensorSyncBundle const& partial_output0_tensor_sync_bundle, + TensorSyncBundle const& partial_output1_tensor_sync_bundle, + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd::builder; + + auto const& all_tensors = builder_config.all_tensors.get(); + auto const& reader_cbs = builder_config.all_cbs.get().final_reducer_reader; + size_t num_partial_reducer_workers = + builder_config.worker_cores.get().partial_reducers[LineDirection::FORWARD].size(); + auto const& worker_cores = builder_config.worker_cores.get().final_reducers_vec; + size_t num_workers = worker_cores.size(); + size_t pages_per_cb_packet = builder_config.pages_per_cb_packet; + + auto const in0_tensor_slice = generate_tensor_slices(1, *partial_output0_tensor_sync_bundle.tensor, 0).at(0); + auto in0_worker_slices = split_tensor_slices_across_workers_page_aligned(num_workers, {in0_tensor_slice}); + auto const in1_tensor_slice = generate_tensor_slices(1, *partial_output1_tensor_sync_bundle.tensor, 0).at(0); + auto in1_worker_slices = split_tensor_slices_across_workers_page_aligned(num_workers, {in1_tensor_slice}); + + auto const& in0_sync = partial_output0_tensor_sync_bundle.sync_spec; + auto const& in1_sync = partial_output1_tensor_sync_bundle.sync_spec; + + validate_final_reducer_reader_worker_slices(in0_worker_slices, in1_worker_slices, in0_sync, in1_sync, num_workers); + for (size_t w = 0; w < num_workers; w++) { + auto const& w_logical = worker_cores[w]; + auto& worker_command_stream0 = worker_command_streams_out.reader_cmds0[w_logical]; + // TODO: Semaphore inc/wait optimization + worker_command_stream0 = { + local_semaphore_wait(in0_sync.value().get_tensor_sync_semaphore(0), num_partial_reducer_workers), + read_tensor_slice_to_cb(in0_worker_slices[w][0], reader_cbs.math_in0)}; + + auto& worker_command_stream1 = worker_command_streams_out.reader_cmds1[w_logical]; + worker_command_stream1 = { + local_semaphore_wait(in1_sync.value().get_tensor_sync_semaphore(0), num_partial_reducer_workers), + read_tensor_slice_to_cb(in1_worker_slices[w][0], reader_cbs.math_in1)}; + + math_page_counts_out[w_logical] = + compute_math_pages_from_tensor_slices(in0_worker_slices[w], pages_per_cb_packet); + } +} + +static void generate_final_reducer_writer_worker_command_streams( + ReduceScatterBuilderConfig& builder_config, + // Should only have populated sync info if fused + TensorSyncBundle const& output_tensor_sync_bundle, + WorkerCommandStreams& worker_command_streams_out) { + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd::builder; + + auto from_math_cb = builder_config.all_cbs.get().final_reducer_writer.math_out; + auto const& worker_cores = builder_config.worker_cores.get().final_reducers_vec; + size_t num_workers = worker_cores.size(); + + auto const tensor_slice = generate_tensor_slices(1, *output_tensor_sync_bundle.tensor, 0).at(0); + auto worker_slices = split_tensor_slices_across_workers_page_aligned(num_workers, {tensor_slice}); + + auto const& sync = output_tensor_sync_bundle.sync_spec; + TT_FATAL( + worker_slices.size() == num_workers, + "Internal error. Expected number of worker slices to match number of workers"); + auto& writer_cmds = worker_command_streams_out.writer_cmds0; + for (size_t w = 0; w < num_workers; w++) { + auto const& w_logical = worker_cores[w]; + TT_FATAL(worker_slices[w].size() == 1, "Internal error. Expected only one slice per worker"); + writer_cmds[w_logical].push_back({local_write_cb_to_tensor_slice(worker_slices[w][0], from_math_cb)}); + } +} + +static void compute_math_pages_from_per_worker_tensor_slices( + std::vector> const& worker_slices, + size_t pages_per_cb_packet, + std::vector const& worker_cores, + std::unordered_map& math_page_counts_out) { + for (size_t w = 0; w < worker_slices.size(); w++) { + auto const& w_logical = worker_cores[w]; + auto const& slices = worker_slices[w]; + math_page_counts_out[w_logical] = compute_math_pages_from_tensor_slices(slices, pages_per_cb_packet); + } +} + +// More efficient implementation is to do the splitting outside but we'll do that after we have something working +// Outer index is per worker, inner is each command stream (0 and 1 respectively for that worker) +// second result is total number of pages cycled through the CBs +static void generate_partial_reducer_reader_worker_command_streams( + ReduceScatterBuilderConfig& builder_config, + std::optional const& in0_tensor_sync, + std::optional const& in1_tensor_sync, + // Same for both operands + std::vector> const& worker_tensor_slices, + std::vector const& worker_cores, + WorkerCommandStreams& worker_command_streams_out, + bool skip_math_for_last_slice) { + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd::builder; + + auto const& reader_cbs = builder_config.all_cbs.get().partial_reducer_reader; + auto const& topology_config = builder_config.topology_config.get(); + + const size_t num_workers = worker_cores.size(); + log_trace( + tt::LogOp, "generate_partial_reducer_reader_worker_command_streams. topologyu: {}", topology_config.topology()); + + bool in0_async_mode_specified = in0_tensor_sync.has_value(); + bool in1_async_mode_specified = in1_tensor_sync.has_value(); + TT_FATAL(in1_async_mode_specified, "Internal error. Expected input tensor sync to be populated"); + auto const& from_remote_input_tensor_sync = in1_tensor_sync; + TT_FATAL( + worker_tensor_slices.size() == num_workers, + "Internal error. Expected number of worker slices to match number of workers"); + auto get_cb_base = [](size_t slice_index, ttnn::ccl::Topology topology, uint32_t idx0_cb, uint32_t default_cb) { + if (topology == ttnn::ccl::Topology::Linear) { + return default_cb; + } else { + return slice_index == 0 ? idx0_cb : default_cb; + } + }; + auto get_cb = std::bind( + get_cb_base, std::placeholders::_1, topology_config.topology(), reader_cbs.pass_through, reader_cbs.math_in0); + + for (size_t w = 0; w < num_workers; w++) { + auto const& w_logical = worker_cores[w]; + { + auto& worker_command_stream0 = worker_command_streams_out.reader_cmds0[w_logical]; + for (size_t i = 0; i < worker_tensor_slices[w].size(); i++) { + bool last_slice = i == worker_tensor_slices[w].size() - 1; + auto const& s = worker_tensor_slices[w][i]; + if (in0_tensor_sync.has_value()) { + // NOTE: per-worker sync + worker_command_stream0.push_back( + local_semaphore_wait(in0_tensor_sync.value().get_tensor_sync_semaphore(w), i + 1)); + } + if (last_slice) { + // Make sure not to add the space at the beginning of the CB chunk for packet header + // so when we write out from the other side, we maintain proper alignment + if (!skip_math_for_last_slice) { + // for linear topology, one of the direction mustn't do a partial reduce for it's last + // input chunk with the `input_tensor` otherwise `input_tensor` for that chunk will be accumulated + // twice. We arbitrarily choose the forward direction as the one that will not partial reduce + // for the last input chunk + worker_command_stream0.push_back(read_tensor_slice_to_cb(s, get_cb(i))); + } + } else { + worker_command_stream0.push_back(read_tensor_slice_to_cb_for_eventual_fabric_write(s, get_cb(i))); + } + } + } + { + auto& worker_command_stream1 = worker_command_streams_out.reader_cmds1[w_logical]; + for (size_t i = 0; i < worker_tensor_slices[w].size(); i++) { + bool last_slice = i == worker_tensor_slices[w].size() - 1; + auto const& s = worker_tensor_slices[w][i]; + worker_command_stream1.push_back( + local_semaphore_wait(from_remote_input_tensor_sync.value().get_tensor_sync_semaphore(w), i + 1)); + if (last_slice) { + worker_command_stream1.push_back(read_tensor_slice_to_cb(s, skip_math_for_last_slice ? reader_cbs.pass_through : reader_cbs.math_in1)); + } else { + worker_command_stream1.push_back( + read_tensor_slice_to_cb_for_eventual_fabric_write(s, reader_cbs.math_in1)); + } + } + } + } +} + +static void generate_partial_reducer_writer_worker_command_streams( + ReduceScatterBuilderConfig& builder_config, + TensorSyncBundle const& remote_output_tensor_sync_bundle, + TensorSyncBundle const& local_partial_output_tensor_sync_bundle, + std::vector> const& remote_out_worker_tensor_slices, + LineDirection direction, + WorkerCommandStreams& worker_command_streams, + bool skip_math_for_last_slice) { + auto const& topology_config = builder_config.topology_config.get(); + auto const& worker_cores = builder_config.worker_cores.get().partial_reducers[direction]; + auto const& worker_cores_vec = builder_config.worker_cores.get().partial_reducers_vec[direction]; + size_t num_devices = topology_config.line_size(); + bool is_forward_direction = direction == LineDirection::FORWARD; + + log_trace( + tt::LogOp, + "generate_partial_reducer_writer_worker_command_streams. topologyu: {}, num_devices: {}", + topology_config.topology(), + num_devices); + + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd::builder; + + auto const& writer_cbs = builder_config.all_cbs.get().partial_reducer_writer; + TT_FATAL( + local_partial_output_tensor_sync_bundle.sync_spec.has_value(), + "Internal error. Expected local partial output tensor to have synchronization info"); + // Since Command processor currently doesn't support switching between tensors within a single command stream + // (future work), we split into two command streams, with each one assigned to one of the two output tensors: + // 0. Remote output tensor + // 1. Local output tensor + // + // After all slices have been forwarded to the remote chip, then the command streams synchronize with each other + // to indicate that the "from math" CB can be read from + + const size_t num_workers = worker_cores.num_cores(); + + auto const local_partial_output_tensor_slice = + convert_to_whole_tensor_slice(*local_partial_output_tensor_sync_bundle.tensor); + auto const local_output_tensor_slices_per_worker = + split_tensor_slices_across_workers_page_aligned(num_workers, {local_partial_output_tensor_slice}); + TT_FATAL( + local_output_tensor_slices_per_worker.size() == num_workers, + "Local output tensor slices per worker size mismatch"); + TT_FATAL( + remote_out_worker_tensor_slices.size() == num_workers, "Remote output tensor slices per worker size mismatch"); + + auto get_cb_base = [](size_t slice_index, ttnn::ccl::Topology topology, uint32_t idx0_cb, uint32_t default_cb) { + if (topology == ttnn::ccl::Topology::Linear) { + return default_cb; + } else { + return slice_index == 0 ? idx0_cb : default_cb; + } + }; + log_trace( + tt::LogOp, + "\t\t\twriter_cbs.pass_through: {}, writer_cbs.math_out: {}", + writer_cbs.pass_through, + writer_cbs.math_out); + auto get_cb = std::bind( + get_cb_base, std::placeholders::_1, topology_config.topology(), writer_cbs.pass_through, writer_cbs.math_out); + + TT_FATAL( + remote_output_tensor_sync_bundle.sync_spec.has_value(), + "Internal error. Expected remote output tensor to have synchronization info"); + auto const& remote_out_tensor_sync = remote_output_tensor_sync_bundle.sync_spec.value(); + + std::vector> writer_command_streams_per_worker; + auto const next_chip_fabric_unicast = UnicastCommandDestArgs{1, is_forward_direction}; + auto internal_command_stream_sync_sem_id = CreateSemaphore(builder_config.program.get(), worker_cores, 0); + for (size_t w = 0; w < num_workers; w++) { + { // Command stream 0 + const size_t operand_index = 0; + auto& worker_command_stream0 = worker_command_streams.writer_cmds0[worker_cores_vec[w]]; + for (size_t i = 0; i < remote_out_worker_tensor_slices[w].size(); i++) { + auto const& s = remote_out_worker_tensor_slices[w][i]; + log_debug( + tt::LogOp, + "Worker {} Writer Kernel cmds0[{}]: tensor_slice: (.shape=(w={},z={},y={},x={}), " + ".slice_shape=(w={},z={},y={},x={})), .slice_offset=(w={},z={},y={},x={}), " + ".worker_slice_shape=(w={},z={},y={},x={}), .worker_slice_offset=(w={},z={},y={},x={}), cb_id={}", + w, + 2 * i, + s.tensor_slice_shape.w, + s.tensor_slice_shape.z, + s.tensor_slice_shape.y, + s.tensor_slice_shape.x, + s.tensor_slice_shape.w, + s.tensor_slice_shape.z, + s.tensor_slice_shape.y, + s.tensor_slice_shape.x, + s.tensor_slice_offset.w, + s.tensor_slice_offset.z, + s.tensor_slice_offset.y, + s.tensor_slice_offset.x, + s.worker_slice_shape.w, + s.worker_slice_shape.z, + s.worker_slice_shape.y, + s.worker_slice_shape.x, + s.worker_slice_offset.w, + s.worker_slice_offset.z, + s.worker_slice_offset.y, + s.worker_slice_offset.x, + get_cb(i)); + + worker_command_stream0.push_back( + fabric_write_cb_to_tensor_slice(s, get_cb(i), next_chip_fabric_unicast)); + + // remote_out_tensor_sync + worker_command_stream0.push_back(fabric_unicast_semaphore_inc_mcast( + // For now we assume the semaphores are consistent across chips + // though this may not be generally safe - it should be for the initial + // cases we care about + // NOTE: per worker semaphore + remote_out_tensor_sync.get_tensor_sync_semaphore(w), + CclCommandAtomicInc{1}, + remote_out_tensor_sync.get_target(w), + next_chip_fabric_unicast) + + ); + } + // Finish off by notifying the other command stream that it's safe for it to pull from the + // "from math" CB + worker_command_stream0.push_back(local_core_semaphore_inc(internal_command_stream_sync_sem_id, 1)); + } + { // Command stream 1 + const size_t operand_index = 1; + auto& worker_command_stream1 = worker_command_streams.writer_cmds1[worker_cores_vec[w]]; + + TT_FATAL( + local_output_tensor_slices_per_worker[w].size() == 1, + "Local output tensor expected only to have a single tensor slice"); + // Wait for all-clear from first command stream that "from math" CB is no longer being pulled from + // Then it's safe to forward to fabric from CB + + std::ranges::copy( + CclHostLowLevelCommandSequence{ + local_semaphore_wait(internal_command_stream_sync_sem_id, 1), + local_write_cb_to_tensor_slice(local_output_tensor_slices_per_worker[w][0], skip_math_for_last_slice ? writer_cbs.pass_through : writer_cbs.math_out), + local_chip_semaphore_inc_mcast( + // NOTE: Per worker semaphores + local_partial_output_tensor_sync_bundle.sync_spec.value().get_tensor_sync_semaphore(w), + CclCommandAtomicInc{1}, + local_partial_output_tensor_sync_bundle.sync_spec.value().get_target(w))}, + std::back_inserter(worker_command_stream1)); + } + } +} + +// TODO: optimize to have set block_size == packet_size +static std::vector generate_reduce_op_kernel_rt_args(size_t total_num_math_pages) { + auto const& args = std::vector{total_num_math_pages, 1}; + + std::size_t i = 0; + log_trace(tt::LogOp, "\tReduce Scatter Worker RT Args:"); + log_trace(tt::LogOp, "\t\tblock_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\t\ttotal_num_math_pages: {}", args.at(i++)); + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; +} + +static void set_math_runtime_args( + Program& program, KernelHandle math_kernel_id, CoreCoord const& worker_logical, size_t total_num_math_pages) { + log_trace(tt::LogOp, "Setting math kernel RT args"); + auto rt_args = generate_reduce_op_kernel_rt_args(total_num_math_pages); + tt::tt_metal::SetRuntimeArgs(program, math_kernel_id, worker_logical, rt_args); +} + + +static void create_non_end_of_line_final_reducer_worker_commands( + ReduceScatterBuilderConfig& builder_config, + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + auto const& final_reducer_worker_cores = builder_config.worker_cores.get().final_reducers_vec; + auto const& all_program_tensors = builder_config.all_tensors.get(); + auto const& all_cbs = builder_config.all_cbs.get(); + log_trace(tt::LogOp, "--------------------------------------"); + log_trace(tt::LogOp, "CREATE WORKER (final reducer - not end. Device={})", builder_config.device->id()); + + size_t const num_partial_reducer_workers_per_direction = + builder_config.worker_cores.get().partial_reducers[LineDirection::FORWARD].size(); + + std::array const& partial_output_tensor_sync_bundles = { + TensorSyncBundle{ + all_program_tensors.local_output_partial[LineDirection::FORWARD], + all_program_tensors.local_output_partial_sync[LineDirection::FORWARD]}, + TensorSyncBundle{ + all_program_tensors.local_output_partial[LineDirection::BACKWARD], + all_program_tensors.local_output_partial_sync[LineDirection::BACKWARD]}, + }; + + generate_final_reducer_reader_worker_command_streams( + builder_config, + partial_output_tensor_sync_bundles[LineDirection::FORWARD], + partial_output_tensor_sync_bundles[LineDirection::BACKWARD], + worker_command_streams_out, + math_page_counts_out); + + generate_final_reducer_writer_worker_command_streams( + builder_config, + TensorSyncBundle{all_program_tensors.local_output_tensor, all_program_tensors.local_output_sync}, + worker_command_streams_out); + + TT_FATAL(final_reducer_worker_cores.size() > 0, "Internal error. No final reducer cores were created"); +} + +static void populate_partial_reduce_worker_commands( + ReduceScatterBuilderConfig& builder_config, + + std::array>, 2> const& reader_worker_slices_by_direction, + std::array>, 2> const& writer_worker_slices_by_direction, + + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + auto const& partial_reducer_worker_cores = builder_config.worker_cores.get().partial_reducers_vec; + auto const& all_tensors = builder_config.all_tensors.get(); + auto const& all_cbs = builder_config.all_cbs.get(); + auto const& topology_config = builder_config.topology_config.get(); + auto const& kernel_ids = builder_config.kernel_ids.get(); + log_trace(tt::LogOp, "--------------------------------------"); + log_trace(tt::LogOp, "CREATE WORKER (partial reducer - not end. Device={})", builder_config.device->id()); + + std::array, 2> partial_reducer_worker_cores_vec = { + partial_reducer_worker_cores[LineDirection::FORWARD], partial_reducer_worker_cores[LineDirection::BACKWARD]}; + + std::vector> slices_through_math_forward_direction; + slices_through_math_forward_direction.reserve(reader_worker_slices_by_direction[LineDirection::FORWARD].size()); + // For line topology, we don't want to partial reduce for input_tensor for the last input chunk, otherwise we will + // end up reducing with that chunk of `input_tensor` twice. + for (auto const& slices : reader_worker_slices_by_direction[LineDirection::FORWARD]) { + TT_FATAL(slices.size() > 1, "Internal error. Expected at least two slices"); + slices_through_math_forward_direction.push_back(vslice(slices, 0, slices.size() - 2)); + } + + compute_math_pages_from_per_worker_tensor_slices( + slices_through_math_forward_direction,//reader_worker_slices_by_direction[LineDirection::FORWARD], + builder_config.pages_per_cb_packet, + partial_reducer_worker_cores_vec[LineDirection::FORWARD], + math_page_counts_out); + compute_math_pages_from_per_worker_tensor_slices( + reader_worker_slices_by_direction[LineDirection::BACKWARD], + builder_config.pages_per_cb_packet, + partial_reducer_worker_cores_vec[LineDirection::BACKWARD], + math_page_counts_out); + + for (auto line_direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + // Logic for any chip in the "middle" of the line + bool is_forward_direction = line_direction == LineDirection::FORWARD; + generate_partial_reducer_reader_worker_command_streams( + builder_config, + all_tensors.input_tensor_sync, + all_tensors.input_tensor_from_remote_sync[line_direction], + reader_worker_slices_by_direction[line_direction], + partial_reducer_worker_cores_vec[line_direction], + worker_command_streams_out, + is_forward_direction); + + generate_partial_reducer_writer_worker_command_streams( + builder_config, + TensorSyncBundle{all_tensors.remote_output[line_direction], all_tensors.remote_output_sync[line_direction]}, + TensorSyncBundle{ + all_tensors.local_output_partial[line_direction], + all_tensors.local_output_partial_sync[line_direction]}, + writer_worker_slices_by_direction[line_direction], + line_direction, + worker_command_streams_out, + is_forward_direction); + } +} + +static void create_final_reducer_worker_rt_args_not_end_of_line( + ReduceScatterBuilderConfig& builder_config, + fabric_lifetime_mode fabric_mode, + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + using namespace ttnn::ccl::worker_detail; + + auto const& final_reducer_worker_cores = builder_config.worker_cores.get().final_reducers_vec; + auto const& all_program_tensors = builder_config.all_tensors.get(); + + for (size_t i = 0; i < final_reducer_worker_cores.size(); i++) { + auto const& w_logical = final_reducer_worker_cores[i]; + generate_multi_input_command_stream_kernel_rt_args( + builder_config.program, + builder_config.kernel_ids.get().reader, + {all_program_tensors.local_output_partial[LineDirection::FORWARD], + all_program_tensors.local_output_partial[LineDirection::BACKWARD]}, + {builder_config.page_size, builder_config.page_size}, + builder_config.device, + builder_config.pages_per_cb_packet, + {w_logical}, + worker_command_streams_out.reader_cmds0.at(w_logical), + worker_command_streams_out.reader_cmds1.at(w_logical), + std::nullopt, + std::nullopt); + set_math_runtime_args( + builder_config.program, + builder_config.kernel_ids.get().math, + w_logical, + math_page_counts_out.at(w_logical)); + generate_multi_input_command_stream_kernel_rt_args( + builder_config.program, + builder_config.kernel_ids.get().writer, + {all_program_tensors.local_output_tensor, nullptr}, + {builder_config.page_size, builder_config.page_size}, + builder_config.device, + builder_config.pages_per_cb_packet, + {w_logical}, + worker_command_streams_out.writer_cmds0.at(w_logical), + ttnn::ccl::cmd::CclHostLowLevelCommandSequence{}, + std::nullopt, + std::nullopt); + } +} + +static void populate_partial_reduce_rt_args( + ReduceScatterBuilderConfig& builder_config, + + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + using namespace ttnn::ccl::worker_detail; + using Direction = ttnn::ccl::EdmLineFabricOpInterface::Direction; + + auto& fabric = builder_config.fabric.get(); + auto const& all_tensors = builder_config.all_tensors.get(); + auto const& kernel_ids = builder_config.kernel_ids.get(); + auto device = builder_config.device; + + auto get_fabric_connection = [&device, &fabric](bool is_connected, Direction dir) { + return is_connected + ? std::make_optional(fabric.uniquely_connect_worker(device, dir)) + : std::nullopt; + }; + + auto const& partial_reducer_worker_cores = builder_config.worker_cores.get().partial_reducers_vec; + std::array, 2> partial_reducer_worker_cores_vec = { + partial_reducer_worker_cores[LineDirection::FORWARD], partial_reducer_worker_cores[LineDirection::BACKWARD]}; + + for (auto line_direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + bool is_forward_direction = line_direction == LineDirection::FORWARD; + auto fwd_fabric_connection = get_fabric_connection(is_forward_direction, Direction::FORWARD); + auto bwd_fabric_connection = get_fabric_connection(!is_forward_direction, Direction::BACKWARD); + + for (size_t i = 0; i < partial_reducer_worker_cores_vec[line_direction].size(); i++) { + auto const& w_logical = partial_reducer_worker_cores_vec[line_direction][i]; + // Reader kernel RT args + generate_multi_input_command_stream_kernel_rt_args( + builder_config.program.get(), + kernel_ids.reader, + std::vector{ + all_tensors.input_tensor, all_tensors.input_tensor_from_remote[line_direction]}, + {builder_config.page_size, builder_config.page_size}, + builder_config.device, + builder_config.pages_per_cb_packet, // TODO: get from fabric + {w_logical}, + worker_command_streams_out.reader_cmds0.at(w_logical), + worker_command_streams_out.reader_cmds1.at(w_logical), + std::nullopt, + std::nullopt); + set_math_runtime_args( + builder_config.program.get(), kernel_ids.math, w_logical, math_page_counts_out[w_logical]); + auto output_tensor_ptrs = std::vector{ + all_tensors.remote_output[line_direction], all_tensors.local_output_partial[line_direction]}; + generate_multi_input_command_stream_kernel_rt_args( + builder_config.program.get(), + kernel_ids.writer, + output_tensor_ptrs, + {builder_config.page_size, builder_config.page_size}, + builder_config.device, + builder_config.pages_per_cb_packet, // TODO: get from fabric + {w_logical}, + worker_command_streams_out.writer_cmds0.at(w_logical), + worker_command_streams_out.writer_cmds1.at(w_logical), + fwd_fabric_connection, + bwd_fabric_connection, + std::unordered_map{ + {output_tensor_ptrs[0], + line_direction == LineDirection::FORWARD ? builder_config.forward_device + : builder_config.backward_device}}); + } + ////////////// + } +} + +static void create_worker_runtime_args_for_inactive_workers(ReduceScatterBuilderConfig& builder_config) { + auto const& inactive_cores = builder_config.worker_cores.get().final_reducers; + using namespace ttnn::ccl::worker_detail; + log_trace(tt::LogOp, "--------------------------------------"); + log_trace(tt::LogOp, "CREATE WORKER (inactive - not end. Device={})", builder_config.device->id()); + + generate_multi_input_command_stream_kernel_rt_args( + builder_config.program.get(), + builder_config.kernel_ids.get().reader, + {nullptr, nullptr}, + {0, 0}, + builder_config.device, + 0, // TODO: get from fabric + inactive_cores, + ttnn::ccl::cmd::CclHostLowLevelCommandSequence{}, + ttnn::ccl::cmd::CclHostLowLevelCommandSequence{}, + std::nullopt, + std::nullopt); + + tt::tt_metal::SetRuntimeArgs( + builder_config.program.get(), + builder_config.kernel_ids.get().math, + inactive_cores, + generate_reduce_op_kernel_rt_args(0)); + + generate_multi_input_command_stream_kernel_rt_args( + builder_config.program.get(), + builder_config.kernel_ids.get().writer, + {nullptr, nullptr}, + {0, 0}, + builder_config.device, + 0, // TODO: get from fabric + inactive_cores, + ttnn::ccl::cmd::CclHostLowLevelCommandSequence{}, + ttnn::ccl::cmd::CclHostLowLevelCommandSequence{}, + std::nullopt, + std::nullopt); +} + +static void validate_end_of_line_worker_tensors( + ReduceScatterBuilderConfig& builder_config, fabric_lifetime_mode fabric_mode) { + ProgramTensorsBundle const& all_tensors = builder_config.all_tensors.get(); + LineTopology const& line_topology = builder_config.topology_config.get(); + bool teardown_fabric = fabric_mode == fabric_lifetime_mode::TRANSIENT; + + TT_FATAL(all_tensors.input_tensor != nullptr, "Input tensor must be populated"); + TT_FATAL(all_tensors.local_output_tensor != nullptr, "Output tensor must be populated"); + if (line_topology.is_first_device_in_line(LineDirection::FORWARD)) { + TT_FATAL( + all_tensors.input_tensor_from_remote[LineDirection::FORWARD] == nullptr, + "Input tensor from remote must be populated"); + TT_FATAL( + all_tensors.input_tensor_from_remote[LineDirection::BACKWARD] != nullptr, + "Input tensor from remote must be populated"); + TT_FATAL( + all_tensors.input_tensor->shape() == all_tensors.input_tensor_from_remote[LineDirection::BACKWARD]->shape(), + "Input tensor and input from remote tensor must have the same shape"); + } + if (line_topology.is_first_device_in_line(LineDirection::BACKWARD)) { + TT_FATAL( + all_tensors.input_tensor_from_remote[LineDirection::BACKWARD] == nullptr, + "Input tensor from remote must be populated"); + TT_FATAL( + all_tensors.input_tensor_from_remote[LineDirection::FORWARD] != nullptr, + "Input tensor from remote must be populated"); + TT_FATAL( + all_tensors.input_tensor->shape() == all_tensors.input_tensor_from_remote[LineDirection::FORWARD]->shape(), + "Input tensor and input from remote tensor must have the same shape"); + } +} + +static void create_end_of_line_worker_commands( + ReduceScatterBuilderConfig& builder_config, + std::unordered_map& worker_math_page_counts_out, + WorkerCommandStreams& worker_command_streams_out) { + // using namespace ttnn::ccl::worker_detail; + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd::builder; + auto const& topology_config = builder_config.topology_config.get(); + auto const& worker_cores = builder_config.worker_cores.get(); + auto const& all_tensors = builder_config.all_tensors.get(); + auto const& all_cbs = builder_config.all_cbs.get(); + + size_t nchips = builder_config.topology_config.get().line_size(); + size_t curr_chip = builder_config.topology_config.get().line_index(); + auto num_workers = worker_cores.partial_reducers_vec[LineDirection::FORWARD].size(); + + TT_FATAL( + worker_cores.partial_reducers_vec[LineDirection::BACKWARD].size() == num_workers, + "Internal error. Expected number of workers to match"); + // out_slices = partial_out_tensor.chunk(n=line_size,dim=dim) + // out_slices_fwd = reverse(out_slices[line_topology.line_index() + 1:]) + // worker_out_slices_fwd = distribute_across_workers(out_slices_fwd) + // out_slices_bwd = out_slices[:line_topology.line_index() + 1] // assuming exclusive end + // worker_out_slices_bwd = distribute_across_workers(out_slices_bwd, n_workers) + auto const reader_in_slices = + generate_tensor_slices(nchips, *builder_config.all_tensors.get().input_tensor, builder_config.dim); + + auto reader_slices_fwd = + vslice(reader_in_slices, reader_in_slices.size() - 1, std::min(curr_chip + 1, reader_in_slices.size() - 1)); + auto reader_slices_bwd = + vslice(reader_in_slices, 0, curr_chip - !topology_config.is_first_device_in_line(LineDirection::FORWARD)); + auto remote_writer_slices_fwd = + vslice(reader_in_slices, reader_in_slices.size() - 1, std::min(curr_chip + 1, reader_in_slices.size() - 1)); + auto remote_writer_slices_bwd = + vslice(reader_in_slices, 0, curr_chip - !topology_config.is_first_device_in_line(LineDirection::FORWARD)); + + auto reader_worker_sliced_fwd = split_tensor_slices_across_workers_page_aligned(num_workers, reader_slices_fwd); + auto reader_worker_sliced_bwd = split_tensor_slices_across_workers_page_aligned(num_workers, reader_slices_bwd); + auto remote_writer_worker_sliced_fwd = + split_tensor_slices_across_workers_page_aligned(num_workers, remote_writer_slices_fwd); + auto remote_writer_worker_sliced_bwd = + split_tensor_slices_across_workers_page_aligned(num_workers, remote_writer_slices_bwd); + + std::array reader_worker_slices = { + reader_worker_sliced_fwd, reader_worker_sliced_bwd}; + std::array remote_writer_worker_slices = { + remote_writer_worker_sliced_fwd, remote_writer_worker_sliced_bwd}; + + std::array, 2> const reader_worker_cores_per_direction = worker_cores.partial_reducers_vec; + std::array, 2> const& writer_worker_cores_per_direction = reader_worker_cores_per_direction; + + auto const local_partial_output_tensor_slice = convert_to_whole_tensor_slice(*all_tensors.local_output_tensor); + auto writer_end_of_line_output_worker_slices = + split_tensor_slices_across_workers_page_aligned(num_workers, {local_partial_output_tensor_slice}); + TT_FATAL( + writer_end_of_line_output_worker_slices.size() == num_workers, + "Internal error. Expected number of end of line worker slices to match number of workers. Got {} but expected " + "{}", + writer_end_of_line_output_worker_slices.size(), + num_workers); + + for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + bool is_forward_direction = direction == LineDirection::FORWARD; + bool is_start_of_line = topology_config.is_first_device_in_line(direction); + + auto const& reader_worker_cores = reader_worker_cores_per_direction[direction]; + TT_FATAL( + reader_worker_cores.size() == num_workers, + "Internal error. Expected number of reader worker cores to match number of workers. Got {} but expected {}", + reader_worker_cores.size(), + num_workers); + + std::vector> worker_in0_cmd_stream(num_workers); + std::optional>> worker_in1_cmd_stream; + std::vector> worker_out0_cmd_stream(num_workers); + TT_FATAL( + reader_worker_slices[direction].size() == num_workers, + "Internal error. Expected number of reader worker slices to match number of workers. Got {} but expected " + "{}", + reader_worker_slices[direction].size(), + num_workers); + TT_FATAL( + reader_worker_slices[direction].size() == num_workers, + "Internal error. Expected number of writer worker slices to match number of workers. Got {} but expected " + "{}", + reader_worker_slices[direction].size(), + num_workers); + if (!is_start_of_line) { + worker_in1_cmd_stream = std::vector>(num_workers); + } + for (size_t i = 0; i < num_workers; i++) { + auto const& w_logical = reader_worker_cores[i]; + auto& in0_cmd_stream = worker_command_streams_out.reader_cmds0[w_logical]; // worker_in0_cmd_stream[i]; + auto& out0_cmd_stream = worker_command_streams_out.writer_cmds0[w_logical]; // worker_out0_cmd_stream[i]; + auto& in1_cmd_stream = worker_command_streams_out.reader_cmds1[w_logical]; + + size_t num_math_pages = 0; + if (is_start_of_line) { + for (auto const& slice : reader_worker_slices[direction][i]) { + in0_cmd_stream.push_back(read_tensor_slice_to_cb_for_eventual_fabric_write( + slice, all_cbs.line_start_reader.pass_through)); + } + + for (size_t s = 0; s < remote_writer_worker_slices[direction][i].size(); s++) { + auto const& slice = remote_writer_worker_slices[direction][i][s]; + out0_cmd_stream.push_back(fabric_write_cb_to_tensor_slice( + slice, + all_cbs.line_start_writer.pass_through, + UnicastCommandDestArgs{1, direction == LineDirection::FORWARD})); + out0_cmd_stream.push_back(fabric_unicast_semaphore_inc_mcast( + // NOTE: per worker semaphores + all_tensors.remote_output_sync.at(direction).get_tensor_sync_semaphore(i), + CclCommandAtomicInc{1}, + all_tensors.remote_output_sync.at(direction).get_target(i), + UnicastCommandDestArgs{1, direction == LineDirection::FORWARD})); + } + } else { + auto const& worker_in_slices = reader_worker_slices.at(direction).at(i); + // READER COMMANDS + auto const& from_remote_sync = direction == LineDirection::FORWARD + ? all_tensors.input_tensor_from_remote_sync[LineDirection::FORWARD] + : all_tensors.input_tensor_from_remote_sync[LineDirection::BACKWARD]; + TT_FATAL(worker_in_slices.size() == 1, "Internal error. Expected only one slice per worker"); + in0_cmd_stream.push_back( + read_tensor_slice_to_cb(worker_in_slices[0], all_cbs.line_end_reader.math_in0)); + // NOTE: per worker semaphore + in1_cmd_stream.push_back(local_semaphore_wait(from_remote_sync.get_tensor_sync_semaphore(0), 1)); + in1_cmd_stream.push_back( + read_tensor_slice_to_cb(worker_in_slices.at(0), all_cbs.line_end_reader.math_in1)); + + // MATH PAGE COUNTS + num_math_pages = + compute_math_pages_from_tensor_slices(worker_in_slices, builder_config.pages_per_cb_packet); + + // WRITER COMMANDS + TT_FATAL( + writer_end_of_line_output_worker_slices[i].size() == 1, + "Internal error. Expected only one slice per worker"); + out0_cmd_stream.push_back(local_write_cb_to_tensor_slice( + writer_end_of_line_output_worker_slices[i][0], all_cbs.line_end_writer.math_out)); + } + + worker_math_page_counts_out[w_logical] = num_math_pages; + } + } +} + +// Maybe reusable for all configurations +static void create_end_of_line_worker_runtime_args( + ReduceScatterBuilderConfig& builder_config, + WorkerCommandStreams& worker_command_streams, + std::unordered_map const& worker_math_page_counts) { + using namespace ttnn::ccl::worker_detail; + using namespace ttnn::ccl::cmd; + using Direction = ttnn::ccl::EdmLineFabricOpInterface::Direction; + Program& program = builder_config.program.get(); + Device* device = builder_config.device; + ttnn::ccl::EdmLineFabricOpInterface& fabric = builder_config.fabric.get(); + ProgramTensorsBundle const& all_tensors = builder_config.all_tensors.get(); + ReduceScatterKernelHandles const& kernel_ids = builder_config.kernel_ids.get(); + WorkerCoreBundle const& worker_cores = builder_config.worker_cores.get(); + + auto get_fabric_connection = [&device, &fabric](bool is_connected, Direction dir) { + return is_connected + ? std::make_optional(fabric.uniquely_connect_worker(device, dir)) + : std::nullopt; + }; + + std::array, 2> const reader_worker_cores_per_direction = worker_cores.partial_reducers_vec; + std::array, 2> const& writer_worker_cores_per_direction = reader_worker_cores_per_direction; + auto num_workers = worker_cores.partial_reducers_vec[LineDirection::FORWARD].size(); + + // Generate the kernels themselves + for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + bool is_start_of_line = builder_config.topology_config.get().is_first_device_in_line(direction); + auto const& reader_worker_cores = reader_worker_cores_per_direction[direction]; + bool is_forward_direction = direction == LineDirection::FORWARD; + + auto fwd_fabric_connection = + get_fabric_connection(is_forward_direction && is_start_of_line, Direction::FORWARD); + auto bwd_fabric_connection = + get_fabric_connection(!is_forward_direction && is_start_of_line, Direction::BACKWARD); + + Tensor* output_tensor_ptr = nullptr; + auto input_tensor_ptrs = std::vector{nullptr, nullptr}; + input_tensor_ptrs[0] = all_tensors.input_tensor; + + if (is_start_of_line) { + output_tensor_ptr = all_tensors.remote_output[direction]; + } else { + output_tensor_ptr = all_tensors.local_output_tensor; + input_tensor_ptrs[1] = all_tensors.input_tensor_from_remote.at(direction); + TT_FATAL(input_tensor_ptrs[1] != nullptr, "Internal error. Expected input tensor to be populated"); + } + + for (size_t i = 0; i < num_workers; i++) { + auto const& w_logical = reader_worker_cores[i]; + size_t num_math_pages = is_start_of_line ? 0 : worker_math_page_counts.at(w_logical); + + TT_FATAL(output_tensor_ptr != nullptr, "Internal error. Expected output tensor to be populated"); + TT_FATAL(input_tensor_ptrs[0] != nullptr, "Internal error. Expected input tensor to be populated"); + TT_FATAL( + worker_command_streams.reader_cmds0.find(w_logical) != worker_command_streams.reader_cmds0.end(), + "Internal error. Expected reader command stream to be populated"); + bool has_in1_commands = + worker_command_streams.reader_cmds1.find(w_logical) != worker_command_streams.reader_cmds1.end(); + generate_multi_input_command_stream_kernel_rt_args( + program, + kernel_ids.reader, + input_tensor_ptrs, + {builder_config.page_size, builder_config.page_size}, + device, + builder_config.pages_per_cb_packet, + {w_logical}, + worker_command_streams.reader_cmds0.at(w_logical), + has_in1_commands ? worker_command_streams.reader_cmds1.at(w_logical) + : std::vector{}, + std::nullopt, + std::nullopt); + set_math_runtime_args(program, kernel_ids.math, w_logical, num_math_pages); + generate_multi_input_command_stream_kernel_rt_args( + program, + kernel_ids.writer, + {output_tensor_ptr, nullptr}, + {builder_config.page_size, builder_config.page_size}, + device, + builder_config.pages_per_cb_packet, + {w_logical}, + worker_command_streams.writer_cmds0.at(w_logical), + std::vector{}, + fwd_fabric_connection, + bwd_fabric_connection); + } + } +} + + + +static void create_end_of_line_worker_commands( + ReduceScatterBuilderConfig& builder_config, + fabric_lifetime_mode fabric_mode, + WorkerCommandStreams& worker_command_streams, + std::unordered_map& worker_math_page_counts) { + using namespace ttnn::ccl::worker_detail; + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::uops; + using namespace ttnn::ccl::cmd::builder; + + validate_end_of_line_worker_tensors(builder_config, fabric_mode); + + log_trace(tt::LogOp, "--------------------------------------"); + log_trace(tt::LogOp, "CREATE WORKER (end of line Device={})", builder_config.device->id()); + + create_end_of_line_worker_commands(builder_config, worker_math_page_counts, worker_command_streams); +} + +static void validate_non_end_of_line_tensors(ReduceScatterBuilderConfig& builder_config) { + auto const& all_program_tensors = builder_config.all_tensors.get(); + auto const& partial_reducer_worker_cores_per_direction = builder_config.worker_cores.get().partial_reducers; + for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + TT_FATAL( + all_program_tensors.remote_output[direction] != nullptr, + "Internal error. Expected remote output tensor from direction {} to be populated", + direction); + TT_FATAL( + all_program_tensors.input_tensor_from_remote[direction] != nullptr, + "Internal error. Expected input tensor from remote direction {} to be populated", + direction); + TT_ASSERT( + all_program_tensors.input_tensor->shape() == all_program_tensors.remote_output[direction]->shape(), + "Input tensor and remote output tensor - direction {} must have the same shape", + direction); + TT_ASSERT( + all_program_tensors.input_tensor->shape() == + all_program_tensors.input_tensor_from_remote[direction]->shape(), + "Input tensor and input from remote tensor from direction {} must have the same shape", + direction); + } + TT_FATAL( + partial_reducer_worker_cores_per_direction[LineDirection::FORWARD].num_cores() == + partial_reducer_worker_cores_per_direction[LineDirection::BACKWARD].num_cores(), + "Internal error. Expected number of partial reducer workers to be the same for both directions"); +} + +static void create_non_end_of_line_worker_commands( + ReduceScatterBuilderConfig& builder_config, + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + validate_non_end_of_line_tensors(builder_config); + + auto const& all_program_tensors = builder_config.all_tensors.get(); + auto const& partial_reducer_worker_cores_per_direction = builder_config.worker_cores.get().partial_reducers; + auto const& topology_config = builder_config.topology_config.get(); + + using namespace ttnn::ccl::worker_detail; + using namespace ttnn::ccl::cmd; + using namespace ttnn::ccl::cmd::builder; + + auto const num_workers = partial_reducer_worker_cores_per_direction[LineDirection::FORWARD].num_cores(); + auto const nchips = topology_config.line_size(); + auto const last_chip = topology_config.line_size() - 1; + // in_tensor_slices = input_tensor.shape.chunk(n=line_size, dim=dim) + // in_slices_fwd = reverse(in_tensor_slices[topology_config.line_index():]) --> For chip 1, of 4 chip line we want + // slices 3, 2, 1 in_slices_bwd = in_tensor_slices[:line_toptopology_configology.line_index() + 1] // assuming + // exclusive end --> For chip 1, of 4 chip line we want slices 0, 1 out_remote_slices_fwd = + // reverse(in_tensor_slices[topology_config.line_index() + 1:]) --> For chip 1, of 4 chip line we want slices 3, 2 + // out_remote_slices_bwd = in_tensor_slices[topology_config.line_index():]) --> For chip 1, of 4 chip line we want + // slices 0 (we are only forwarding one slice) Note those that vslice uses inclusive ends so the end values below + // are off-by-one from the examples above + auto const input_tensor_slices = + generate_tensor_slices(nchips, *all_program_tensors.input_tensor, builder_config.dim); + TT_FATAL(input_tensor_slices.size() == nchips, "Internal error. Expected number of slices to match line size"); + + auto const in_slices_fwd = vslice(input_tensor_slices, last_chip, topology_config.line_index()); + auto const in_slices_bwd = vslice(input_tensor_slices, 0, topology_config.line_index()); + auto const out_remote_slices_fwd = vslice(input_tensor_slices, last_chip, topology_config.line_index() + 1); + auto const out_remote_slices_bwd = vslice(input_tensor_slices, 0, topology_config.line_index() - 1); + + std::array>, 2> reader_worker_slices_by_direction = { + split_tensor_slices_across_workers_page_aligned(num_workers, in_slices_fwd), + split_tensor_slices_across_workers_page_aligned(num_workers, in_slices_bwd)}; + std::array>, 2> writer_worker_slices_by_direction = { + split_tensor_slices_across_workers_page_aligned(num_workers, out_remote_slices_fwd), + split_tensor_slices_across_workers_page_aligned(num_workers, out_remote_slices_bwd)}; + + // Command stream creation + populate_partial_reduce_worker_commands( + builder_config, + reader_worker_slices_by_direction, + writer_worker_slices_by_direction, + worker_command_streams_out, + math_page_counts_out); + + create_non_end_of_line_final_reducer_worker_commands( + builder_config, worker_command_streams_out, math_page_counts_out); +} + +static void create_worker_runtime_args_not_end_of_line( + ReduceScatterBuilderConfig& builder_config, + fabric_lifetime_mode fabric_mode, + WorkerCommandStreams& worker_command_streams_out, + std::unordered_map& math_page_counts_out) { + // Kernel Creation + create_final_reducer_worker_rt_args_not_end_of_line( + builder_config, fabric_mode, worker_command_streams_out, math_page_counts_out); + + populate_partial_reduce_rt_args(builder_config, worker_command_streams_out, math_page_counts_out); +} + +static void validate_tensors(ProgramTensorsBundle const& all_tensors, LineTopology topology_config) { + if (topology_config.topology() == ttnn::ccl::Topology::Linear) { + const size_t page_size = get_page_size(*all_tensors.input_tensor); + for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + if (!topology_config.is_at_end_of_line()) { + TT_FATAL(all_tensors.remote_output[direction] != nullptr, "Remote output tensor must be populated"); + TT_FATAL( + page_size == get_page_size(*all_tensors.remote_output[direction]), + "Remote output tensor must have the same page size as input tensor"); + } + if (topology_config.is_first_device_in_line(direction)) { + TT_FATAL( + all_tensors.local_output_partial[direction] != nullptr, + "Local output partial tensor must be populated"); + TT_FATAL( + all_tensors.input_tensor_from_remote[direction] == nullptr, + "Input tensor from remote must be populated"); + TT_FATAL(all_tensors.remote_output[direction] != nullptr, "Remote output tensor must be populated"); + TT_FATAL( + page_size == get_page_size(*all_tensors.remote_output[direction]), + "Remote output tensor must have the same page size as input tensor"); + } else if (topology_config.is_last_device_in_line(direction)) { + TT_FATAL( + all_tensors.input_tensor_from_remote[direction] != nullptr, + "Input tensor from remote must be populated"); + TT_FATAL(all_tensors.remote_output[direction] == nullptr, "Remote output tensor must be populated"); + TT_FATAL( + page_size == get_page_size(*all_tensors.input_tensor_from_remote[direction]), + "Input tensor from remote must have the same page size as input tensor"); + } + if (all_tensors.local_output_partial[direction] != nullptr) { + TT_FATAL( + all_tensors.local_output_partial[direction]->shape() == all_tensors.local_output_tensor->shape(), + "Partial output tensor and local output tensor must have the same shape"); + } + if (all_tensors.input_tensor_from_remote[direction] != nullptr) { + TT_FATAL( + all_tensors.input_tensor_from_remote[direction]->shape() == all_tensors.input_tensor->shape(), + "Input tensor from remote and input tensor must have the same shape"); + } + if (all_tensors.remote_output[direction] != nullptr) { + TT_FATAL( + all_tensors.remote_output[direction]->shape() == all_tensors.input_tensor->shape(), + "Remote output tensor and input tensor must have the same shape"); + } + } + } else { + return; + } +} + +static void initialize_op_internal_tensor_syncs( + Program& program, + Device* device, + std::array const& neighbour_devices, + ProgramTensorsBundle& all_tensors, + WorkerCoreBundle const& worker_cores, + std::shared_ptr const& from_remote_sem, + std::shared_ptr const& to_remote_sem) { + auto core_coord_lt = [](CoreCoord const& a, CoreCoord const& b) { return a.y < b.y || (a.y == b.y && a.x < b.x); }; + + TT_FATAL( + worker_cores.partial_reducers_vec[LineDirection::BACKWARD].size() > 0, + "Internal error. Expected at least one partial reducer worker"); + std::array, 2> partial_reducer_cores = { + worker_cores.partial_reducers_vec[LineDirection::FORWARD], + worker_cores.partial_reducers_vec[LineDirection::BACKWARD]}; + auto all_partial_reducer_cores = worker_cores.partial_reducers[LineDirection::FORWARD]; + all_partial_reducer_cores = all_partial_reducer_cores.merge(worker_cores.partial_reducers[LineDirection::BACKWARD]); + + auto partial_reducers_in1_sem_id = CreateSemaphore(program, all_partial_reducer_cores, 0, CoreType::WORKER); + for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + all_tensors.input_tensor_from_remote_sync[direction] = TensorSyncSpec{}; + for (auto const& worker_core : partial_reducer_cores[direction]) { + all_tensors.input_tensor_from_remote_sync[direction].targets.push_back(TensorSyncSpec::target_rect{ + device->worker_core_from_logical_core(worker_core).x, + device->worker_core_from_logical_core(worker_core).y, + device->worker_core_from_logical_core(worker_core).x, + device->worker_core_from_logical_core(worker_core).y, + }); + all_tensors.input_tensor_from_remote_sync[direction].semaphore_ids.push_back(from_remote_sem.get()); + all_tensors.input_tensor_from_remote_sync[direction].completion_target_value_per_semaphore.push_back(1); + + // remote output sync + if (neighbour_devices[direction] != nullptr) { + all_tensors.remote_output_sync[direction].semaphore_ids.push_back(to_remote_sem.get());// = all_tensors.input_tensor_from_remote_sync[direction]; + all_tensors.remote_output_sync[direction].completion_target_value_per_semaphore.push_back(1);// = all_tensors.input_tensor_from_remote_sync[direction]; + all_tensors.remote_output_sync[direction] = all_tensors.input_tensor_from_remote_sync[direction]; + all_tensors.remote_output_sync[direction].targets.back() = TensorSyncSpec::target_rect{ + neighbour_devices[direction]->worker_core_from_logical_core(worker_core).x, + neighbour_devices[direction]->worker_core_from_logical_core(worker_core).y, + neighbour_devices[direction]->worker_core_from_logical_core(worker_core).x, + neighbour_devices[direction]->worker_core_from_logical_core(worker_core).y, + }; + } + } + } + + auto final_reducer_cores = corerange_to_cores(worker_cores.final_reducers, std::nullopt, true); + std::array final_reducer_partial_input_sem_ids = { + CreateSemaphore(program, worker_cores.final_reducers, 0, CoreType::WORKER), + CreateSemaphore(program, worker_cores.final_reducers, 0, CoreType::WORKER)}; + for (auto const& worker_core : final_reducer_cores) { + auto worker_target = TensorSyncSpec::target_rect{ + device->worker_core_from_logical_core(worker_core).x, + device->worker_core_from_logical_core(worker_core).y, + device->worker_core_from_logical_core(worker_core).x, + device->worker_core_from_logical_core(worker_core).y, + }; + all_tensors.local_output_partial_sync[LineDirection::FORWARD].targets.push_back(worker_target); + all_tensors.local_output_partial_sync[LineDirection::FORWARD].completion_target_value_per_semaphore.push_back( + 1); + all_tensors.local_output_partial_sync[LineDirection::FORWARD].semaphore_ids.push_back( + final_reducer_partial_input_sem_ids[LineDirection::FORWARD]); + all_tensors.local_output_partial_sync[LineDirection::BACKWARD].targets.push_back(worker_target); + all_tensors.local_output_partial_sync[LineDirection::BACKWARD].completion_target_value_per_semaphore.push_back( + 1); + all_tensors.local_output_partial_sync[LineDirection::BACKWARD].semaphore_ids.push_back( + final_reducer_partial_input_sem_ids[LineDirection::BACKWARD]); + } + + for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) { + TT_FATAL( + all_tensors.input_tensor_from_remote_sync[direction].targets.size() > 0, + "Input tensor from remote sync must be populated"); + TT_FATAL( + all_tensors.input_tensor_from_remote_sync[direction].semaphore_ids.size() > 0, + "Input tensor from remote sync must be populated"); + TT_FATAL( + all_tensors.input_tensor_from_remote_sync[direction].completion_target_value_per_semaphore.size() > 0, + "Input tensor from remote sync must be populated"); + TT_FATAL( + all_tensors.input_tensor_from_remote_sync[direction].completion_target_value_per_semaphore.size() == + all_tensors.input_tensor_from_remote_sync[direction].semaphore_ids.size(), + "Input tensor from remote sync must be populated"); + + TT_FATAL( + all_tensors.remote_output_sync[direction].completion_target_value_per_semaphore.size() == + all_tensors.remote_output_sync[direction].semaphore_ids.size(), + "Remote output sync must be populated"); + + TT_FATAL( + all_tensors.local_output_partial_sync[direction].targets.size() > 0, + "Local output partial sync must be populated"); + TT_FATAL( + all_tensors.local_output_partial_sync[direction].semaphore_ids.size() > 0, + "Local output partial sync must be populated"); + TT_FATAL( + all_tensors.local_output_partial_sync[direction].completion_target_value_per_semaphore.size() > 0, + "Local output partial sync must be populated"); + TT_FATAL( + all_tensors.local_output_partial_sync[direction].completion_target_value_per_semaphore.size() == + all_tensors.local_output_partial_sync[direction].semaphore_ids.size(), + "Local output partial sync must be populated"); + } + TT_FATAL( + all_tensors.remote_output_sync[LineDirection::FORWARD].targets.size() > 0 || + all_tensors.remote_output_sync[LineDirection::BACKWARD].targets.size() > 0, + "Remote output sync must be populated"); + TT_FATAL( + all_tensors.remote_output_sync[LineDirection::FORWARD].semaphore_ids.size() > 0 || + all_tensors.remote_output_sync[LineDirection::BACKWARD].semaphore_ids.size() > 0, + "Remote output sync must be populated"); + TT_FATAL( + all_tensors.remote_output_sync[LineDirection::FORWARD].completion_target_value_per_semaphore.size() > 0 || + all_tensors.remote_output_sync[LineDirection::BACKWARD].completion_target_value_per_semaphore.size() > 0, + "Remote output sync must be populated"); +} + +static void generate_worker_command_streams( + ReduceScatterBuilderConfig& builder_config, + fabric_lifetime_mode fabric_mode, + WorkerCommandStreams& command_streams, + std::unordered_map& math_page_counts) { + bool is_end_of_line = builder_config.topology_config.get().is_at_end_of_line(); + if (is_end_of_line) { + create_end_of_line_worker_commands(builder_config, fabric_mode, command_streams, math_page_counts); + } else { + create_non_end_of_line_worker_commands(builder_config, command_streams, math_page_counts); + } +} + + + +static void populate_worker_runtime_args( + ReduceScatterBuilderConfig& builder_config, + fabric_lifetime_mode fabric_mode, + WorkerCommandStreams& command_streams, + std::unordered_map& math_page_counts) { + bool is_end_of_line = builder_config.topology_config.get().is_at_end_of_line(); + if (is_end_of_line) { + create_worker_runtime_args_for_inactive_workers(builder_config); + create_end_of_line_worker_runtime_args(builder_config, command_streams, math_page_counts); + } else { + create_worker_runtime_args_not_end_of_line(builder_config, fabric_mode, command_streams, math_page_counts); + } +} + + +static void log_worker_command_streams(WorkerCommandStreams const& command_streams, Device *device) { + std::set cores; + for (auto const&[core, cmd_stream] : command_streams.reader_cmds0) { cores.insert(core); } + for (auto const&[core, cmd_stream] : command_streams.reader_cmds1) { cores.insert(core); } + for (auto const&[core, cmd_stream] : command_streams.writer_cmds0) { cores.insert(core); } + for (auto const&[core, cmd_stream] : command_streams.writer_cmds1) { cores.insert(core); } + + auto get_cmd_str = [device](ttnn::ccl::cmd::CclHostLowLevelWorkerCommand const& cmd) -> std::string { + auto print_core = [](ttnn::ccl::cmd::CclCommandCoreDescriptorArgs const& core) { + return std::visit( + tt::stl::overloaded{ + [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeAddrgen const& core) { return fmt::format("addrgen"); }, + [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeLocal const& core) { return fmt::format("local"); }, + [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeNocXY const& core) { return fmt::format("(x={},y={})", core.x, core.y); }, + [](ttnn::ccl::cmd::CclCommandCoreDescriptorTypeMcast const& core) { return fmt::format("mcast"); }, + }, + core); + }; + + auto print_addr = [](ttnn::ccl::cmd::CclCommandAddrArgs const& addr) { + return std::visit( + tt::stl::overloaded{ + [](ttnn::ccl::cmd::CclCommandAddrSemaphoreId const& addr) { return fmt::format("sem: {}", addr.semaphore_id); }, + [](ttnn::ccl::cmd::CclCommandAddrCircularBufferId const& addr) { return fmt::format("cb: {}", addr.circular_buffer_id); }, + [](ttnn::ccl::cmd::CclCommandAddrAbsoluteAddress const& addr) { return fmt::format("abs_addr: {}", addr.absolute_address); }, + [](ttnn::ccl::cmd::CclCommandAddrRelativeAddress const& addr) { return fmt::format("rel_addr: {}", addr.relative_address); }, + [](ttnn::ccl::cmd::CclCommandAddrNone const& addr) { return fmt::format("NONE"); }, + }, + addr); + }; + + auto tslice_str = [](ttnn::ccl::cmd::CclCommandStreamTensorSlice const& slice) { + return fmt::format("t({},{},{},{})s({},{},{},{})o({},{},{},{})w({},{},{},{})", + slice.tensor_shape.w, slice.tensor_shape.z, slice.tensor_shape.y, slice.tensor_shape.x, + slice.tensor_slice_shape.w, slice.tensor_slice_shape.z, slice.tensor_slice_shape.y, slice.tensor_slice_shape.x, + slice.tensor_slice_offset.w, slice.tensor_slice_offset.z, slice.tensor_slice_offset.y, slice.tensor_slice_offset.x, + slice.worker_slice_offset.w, slice.worker_slice_offset.z, slice.worker_slice_offset.y, slice.worker_slice_offset.x); + }; + + switch (cmd.command_code) { + case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: + return fmt::format("T->CB {}", tslice_str(std::get(cmd.command_args))); + case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR: + return fmt::format("CB->T {}", tslice_str(std::get(cmd.command_args))); + + case ttnn::ccl::cmd::CclCommandCode::WAIT_VALUE: + return fmt::format("WAIT val: {}, {}", std::get(cmd.command_args).target_value, print_addr(cmd.source_addr_args)); + + case ttnn::ccl::cmd::CclCommandCode::ATOMIC_INC: + return fmt::format("AT_INC val: {}, {}, {}", std::get(cmd.command_args).value, print_addr(cmd.dest_addr_args), print_core(cmd.core_desc_args)); + + case ttnn::ccl::cmd::CclCommandCode::RAW_INLINE_WRITE_BYTES: + return "RAW_INL_WR"; + + case ttnn::ccl::cmd::CclCommandCode::RAW_READ_BYTES: + return "RAW_RD"; + case ttnn::ccl::cmd::CclCommandCode::RAW_WRITE_BYTES: + return "RAW_WR"; + + case ttnn::ccl::cmd::CclCommandCode::STREAM_EDM_TO_TENSOR: + TT_THROW("Got an unsupported command in a command stream (STREAM_EDM_TO_TENSOR). This command is deprecated and unsupported by this infrastructure. This will lead to undefined and invalid behaviour"); + case ttnn::ccl::cmd::CclCommandCode::INVALID: + default: + TT_THROW("Got an invalid command in a command stream. This will lead to undefined and invalid behaviour"); + return ""; + } + }; + + for (auto const& core : cores) { + std::stringstream ss; + ss << fmt::format("\n____________________________________________________________________ CORE(chip={},x={},y={}) ____________________________________________________________________\n", device->id(), core.x, core.y); + ss << fmt::format("{:<50} {:<50} {:<50} {:<50}\n", "READER STREAM 0","READER STREAM 1","WRITER STREAM 0","WRITER STREAM 1"); + size_t max_seq_len = 0; + bool reader0_populated = command_streams.reader_cmds0.find(core) != command_streams.reader_cmds0.end(); + bool reader1_populated = command_streams.reader_cmds1.find(core) != command_streams.reader_cmds1.end(); + bool writer0_populated = command_streams.writer_cmds0.find(core) != command_streams.writer_cmds0.end(); + bool writer1_populated = command_streams.writer_cmds1.find(core) != command_streams.writer_cmds1.end(); + + if (reader0_populated) { + max_seq_len = std::max(max_seq_len, command_streams.reader_cmds0.at(core).size()); + } + if (reader1_populated) { + max_seq_len = std::max(max_seq_len, command_streams.reader_cmds1.at(core).size()); + } + if (writer0_populated) { + max_seq_len = std::max(max_seq_len, command_streams.writer_cmds0.at(core).size()); + } + if (writer1_populated) { + max_seq_len = std::max(max_seq_len, command_streams.writer_cmds1.at(core).size()); + } + + for (size_t i = 0; i < max_seq_len; i++) { + auto reader0_has = reader0_populated && i < command_streams.reader_cmds0.at(core).size(); + auto reader1_has = reader1_populated && i < command_streams.reader_cmds1.at(core).size(); + auto writer0_has = writer0_populated && i < command_streams.writer_cmds0.at(core).size(); + auto writer1_has = writer1_populated && i < command_streams.writer_cmds1.at(core).size(); + + ss << fmt::format("{:<50} {:<50} {:<50} {:<50}\n", + reader0_has ? get_cmd_str(command_streams.reader_cmds0.at(core)[i]) : "", + reader1_has ? get_cmd_str(command_streams.reader_cmds1.at(core)[i]) : "", + writer0_has ? get_cmd_str(command_streams.writer_cmds0.at(core)[i]) : "", + writer1_has ? get_cmd_str(command_streams.writer_cmds1.at(core)[i]) : ""); + } + log_debug(tt::LogOp, "{}", ss.str()); + } +} + + +operation::ProgramWithCallbacks reduce_scatter_async_on_instantiated_edm_fabric( + Program& program, + ttnn::ccl::EdmLineFabricOpInterface& fabric, + std::optional forward_device, + std::optional backward_device, + Tensor const& input_tensor, + Tensor& local_output_tensor, + Tensor& input_tensor_from_remote_forward_direction, + Tensor& input_tensor_from_remote_backward_direction, + Tensor& local_partial_output_tensor_from_forward_direction, + Tensor& local_partial_output_tensor_from_backward_direction, + std::optional& foreward_direction_remote_output_tensor, + std::optional& backward_direction_remote_output_tensor, + ttnn::operations::binary::BinaryOpType reduce_op, + size_t line_size, + size_t line_index, + const uint32_t dim, + const size_t num_links, + ttnn::ccl::Topology topology, + + fabric_lifetime_mode fabric_mode, + std::shared_ptr const& from_remote_sems, + std::shared_ptr const& to_remote_sem) { + using namespace ttnn::ccl::worker_detail; + bool do_dynamic_fabric_bringup_and_teardown = fabric_mode == fabric_lifetime_mode::TRANSIENT; + + // Constants/ "Globals" + constexpr auto math_in0_cb = tt::CBIndex::c_0; + constexpr auto math_in1_cb = tt::CBIndex::c_1; + constexpr auto math_out_cb = tt::CBIndex::c_2; + constexpr auto pass_through_cb = tt::CBIndex::c_3; + AllReduceScatterCircularBufferIds all_cbs = { + {pass_through_cb, math_in0_cb, math_in1_cb}, + {pass_through_cb, math_out_cb}, + {math_in0_cb, math_in1_cb}, + {math_out_cb}, + {pass_through_cb}, + {pass_through_cb}, + {math_in0_cb, math_in1_cb}, + {math_out_cb}}; + + const size_t page_size = get_page_size(input_tensor); + Device* device = input_tensor.device(); + std::array neighbour_devices = {forward_device.value_or(nullptr), backward_device.value_or(nullptr)}; + size_t fabric_buffer_size_pages = fabric.get_edm_buffer_size_bytes() / get_page_size(input_tensor); + auto const& topology_config = LineTopology(line_size, line_index); + + auto const& worker_cores = select_worker_cores(topology, num_links, device); + ProgramTensorsBundle all_tensors = { + // local input tensor + ProgramTensorsBundle::build_handle(input_tensor), + {}, + + // local output tensor + ProgramTensorsBundle::build_handle(local_output_tensor), + {}, + + // input tensor from remote + {topology_config.is_first_device_in_line(LineDirection::FORWARD) + ? nullptr + : ProgramTensorsBundle::build_handle(input_tensor_from_remote_forward_direction), + topology_config.is_first_device_in_line(LineDirection::BACKWARD) + ? nullptr + : ProgramTensorsBundle::build_handle(input_tensor_from_remote_backward_direction)}, + {}, + + // output partial tensor on remote chip + {topology_config.is_last_device_in_line(LineDirection::FORWARD) + ? nullptr + : ProgramTensorsBundle::build_handle(input_tensor_from_remote_forward_direction), + topology_config.is_last_device_in_line(LineDirection::BACKWARD) + ? nullptr + : ProgramTensorsBundle::build_handle(input_tensor_from_remote_backward_direction)}, + {}, + + // local partial output tensor for final reducer + {ProgramTensorsBundle::build_handle(local_partial_output_tensor_from_forward_direction), + ProgramTensorsBundle::build_handle(local_partial_output_tensor_from_backward_direction)}, + {}}; + + log_debug(tt::LogOp, + "input_tensor.addr: {}, \n" + "local_output_tensor.addr: {}, \n" + "input_tensor_from_remote_forward_direction.addr: {}, \n" + "input_tensor_from_remote_backward_direction.addr: {}, \n" + "output_tensor_on_remote_chip_forward_direction.addr: {}, \n" + "output_tensor_on_remote_chip_backward_direction.addr: {}, \n" + "local_partial_output_tensor_from_forward_direction.addr: {}, \n" + "local_partial_output_tensor_from_backward_direction.addr: {} \n", + all_tensors.input_tensor != nullptr ? (void*)all_tensors.input_tensor->buffer()->address() : nullptr, + all_tensors.local_output_tensor != nullptr ? (void*)all_tensors.local_output_tensor->buffer()->address() : nullptr, + all_tensors.input_tensor_from_remote[LineDirection::FORWARD] != nullptr ? (void*)all_tensors.input_tensor_from_remote[LineDirection::FORWARD]->buffer()->address() : nullptr, + all_tensors.input_tensor_from_remote[LineDirection::BACKWARD] != nullptr ? (void*)all_tensors.input_tensor_from_remote[LineDirection::BACKWARD]->buffer()->address() : nullptr, + all_tensors.remote_output[LineDirection::FORWARD] != nullptr ? (void*)all_tensors.remote_output[LineDirection::FORWARD]->buffer()->address() : nullptr, + all_tensors.remote_output[LineDirection::BACKWARD] != nullptr ? (void*)all_tensors.remote_output[LineDirection::BACKWARD]->buffer()->address() : nullptr, + all_tensors.local_output_partial[LineDirection::FORWARD] != nullptr ? (void*)all_tensors.local_output_partial[LineDirection::FORWARD]->buffer()->address() : nullptr, + all_tensors.local_output_partial[LineDirection::BACKWARD] != nullptr ? (void*)all_tensors.local_output_partial[LineDirection::BACKWARD]->buffer()->address() : nullptr); + + + initialize_op_internal_tensor_syncs( + program, device, neighbour_devices, all_tensors, worker_cores, from_remote_sems, to_remote_sem); + + validate_tensors(all_tensors, topology_config); + + // Circular Buffer Creation + size_t const cb_page_size = page_size; + auto const cb_handles = create_worker_circular_buffers( + program, + worker_cores.all_worker_cores, + tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()), + math_in0_cb, + math_in1_cb, + math_out_cb, + pass_through_cb, + fabric_buffer_size_pages, + // TODO: Move packet headers to side buffer and don't force it through + page_size); + + auto kernel_ids = + build_line_reduce_scatter_worker_ct(program, all_tensors, cb_handles, worker_cores.all_worker_cores, reduce_op); + + const size_t pages_per_cb_packet = fabric.get_edm_buffer_size_bytes() / cb_page_size; + auto builder_config = ReduceScatterBuilderConfig{ + program, + device, + forward_device.value_or(nullptr), + backward_device.value_or(nullptr), + fabric, + all_tensors, + kernel_ids, + all_cbs, + topology_config, + worker_cores, + page_size, + pages_per_cb_packet, + dim}; + bool is_end_of_line = topology_config.is_at_end_of_line(); + + log_trace(tt::LogOp, "Pages per CB packet: {}", pages_per_cb_packet); + WorkerCommandStreams command_streams; + std::unordered_map math_page_counts; + generate_worker_command_streams(builder_config, fabric_mode, command_streams, math_page_counts); + + log_worker_command_streams(command_streams, device); + + populate_worker_runtime_args(builder_config, fabric_mode, command_streams, math_page_counts); + + // Synchronous mode kernel invocation + auto override_runtime_arguments_callback = + [topology_config, from_remote_sems, to_remote_sem, kernel_ids, worker_cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors.at(0); + const auto& output = output_tensors.at(0); + auto& worker_reader_runtime_args_by_core = GetRuntimeArgs(program, kernel_ids.reader); + auto& worker_writer_runtime_args_by_core = GetRuntimeArgs(program, kernel_ids.writer); + }; + + log_trace(tt::LogOp, "Done program factory"); + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +operation::ProgramWithCallbacks build_reduce_scatter_async_program( + Tensor const& input_tensor, + Tensor& local_output_tensor, + Tensor& input_tensor_from_remote_forward_direction, + Tensor& input_tensor_from_remote_backward_direction, + Tensor& local_partial_output_tensor_from_forward_direction, + Tensor& local_partial_output_tensor_from_backward_direction, + std::optional& foreward_direction_remote_output_tensor, + std::optional& backward_direction_remote_output_tensor, + std::optional forward_device, + std::optional backward_device, + ttnn::operations::binary::BinaryOpType reduce_op, + const uint32_t dim, + const uint32_t line_size, + const uint32_t line_index, + ttnn::ccl::Topology topology, + std::optional num_links_preferred, + std::optional> const& from_remote_sem_opt, + std::optional> const& to_remote_sem_opt, + std::optional& fabric_handle_) { + auto program = tt::tt_metal::Program(); + + TT_FATAL(from_remote_sem_opt.has_value(), "Semaphore handle is required for compile time"); + TT_FATAL(to_remote_sem_opt.has_value(), "Semaphore handle is required for compile time"); + + auto from_remote_sem = from_remote_sem_opt.value(); + auto to_remote_sem = to_remote_sem_opt.value(); + + bool persistent_fabric = true; + Device* device = input_tensor.device(); + + std::optional fabric_handle = fabric_handle_; + fabric_lifetime_mode fabric_mode = fabric_lifetime_mode::PERSISTENT; + // fabric_handle.has_value() ? fabric_lifetime_mode::PERSISTENT : fabric_lifetime_mode::TRANSIENT; + // We only build the local chip's part of the fabric + if (!fabric_handle.has_value()) { + fabric_handle = ttnn::ccl::EdmLineFabricOpInterface( + device, + forward_device, + backward_device, + &program, + persistent_fabric, + num_links_preferred.value_or(line_size), + true); + } + + TT_FATAL(fabric_mode == fabric_lifetime_mode::PERSISTENT, "Reduce scatter doesn't support transient fabric mode"); + return reduce_scatter_async_on_instantiated_edm_fabric( + program, + fabric_handle.value(), + forward_device, + backward_device, + input_tensor, + local_output_tensor, + input_tensor_from_remote_forward_direction, + input_tensor_from_remote_backward_direction, + local_partial_output_tensor_from_forward_direction, + local_partial_output_tensor_from_backward_direction, + foreward_direction_remote_output_tensor, + backward_direction_remote_output_tensor, + reduce_op, + line_size, + line_index, + dim, + fabric_handle.value().get_num_links(), + ttnn::ccl::Topology::Linear, + fabric_mode, + from_remote_sem, + to_remote_sem); +} + +} // namespace ttnn::ccl::reduce_scatter_detail diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp new file mode 100644 index 00000000000..39953e0d8a1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "reduce_scatter.hpp" + +#include "ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp" + +namespace ttnn::operations::experimental::ccl { + +ttnn::Tensor ExecuteReduceScatter::invoke( + const ttnn::Tensor& input_tensor, + const int32_t dim, + ttnn::operations::reduction::ReduceType math_op, + const std::optional& memory_config, + ttnn::ccl::Topology topology, + const std::optional num_preferred_links, + std::optional worker_subdevice_id_opt, + bool create_semaphore_handles) { + MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config()); + return ttnn::operations::experimental::ccl::reduce_scatter( + input_tensor, + dim, + math_op, + out_memory_config, + topology, + num_preferred_links, + worker_subdevice_id_opt, + create_semaphore_handles); +} + +} // namespace ttnn::operations::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp new file mode 100644 index 00000000000..8c42952c8b9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "ttnn/decorators.hpp" + +#include "ttnn/operations/reduction/generic/generic_reductions.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_types.hpp" + +namespace ttnn { +namespace operations { +namespace experimental { +namespace ccl { + +struct ExecuteReduceScatter { + static ttnn::Tensor invoke( + const ttnn::Tensor& input_tensor, + const int32_t dim, + ttnn::operations::reduction::ReduceType math_op, + const std::optional& memory_config = std::nullopt, + ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear, + const std::optional num_links = std::nullopt, + std::optional worker_subdevice_id_opt = std::nullopt, + bool create_semaphore_handles = true); +}; + +} // namespace ccl +} // namespace experimental +} // namespace operations + +constexpr auto reduce_scatter_async = + ttnn::register_operation<"ttnn::reduce_scatter_async", ttnn::operations::experimental::ccl::ExecuteReduceScatter>(); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp new file mode 100644 index 00000000000..e87c6738f49 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "reduce_scatter_pybind.hpp" + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp" +#include "ttnn/types.hpp" + +#include "ttnn/operations/reduction/generic/generic_reductions.hpp" + +namespace ttnn::operations::experimental::ccl { + +namespace detail { + +template +void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operation, const char* doc) { + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const int32_t dim, + ttnn::operations::reduction::ReduceType math_op, + const ttnn::MemoryConfig& memory_config, + ttnn::ccl::Topology topology, + const std::optional num_links, + std::optional worker_subdevice_id_opt, + bool create_semaphore_handles) -> ttnn::Tensor { + return self( + input_tensor, + dim, + math_op, + memory_config, + topology, + num_links, + worker_subdevice_id_opt, + create_semaphore_handles); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::arg("math_op"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("topology") = ttnn::ccl::Topology::Linear, + py::arg("num_links") = std::nullopt, + py::arg("subdevice_id") = std::nullopt, + py::arg("create_semaphore_handles") = true}); +} + +} // namespace detail + +void py_bind_reduce_scatter_async(pybind11::module& module) { + detail::bind_reduce_scatter( + module, + ttnn::reduce_scatter_async, + R"doc( + + Performs an reduce_scatter operation on multi-device :attr:`input_tensor` across all devices. This operation requires a persistent + fabric to be enabled in order to function. + + Args: + input_tensor (ttnn.Tensor): multi-device tensor + dim (int): Dimension to perform operation + cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-reduce-scatter operation on. + mesh_device (MeshDevice): Device mesh to perform the line-reduce-scatter operation on. + * cluster_axis and mesh_device parameters are applicable only for Linear Topology. + + Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md + + Keyword Args: + memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `input tensor memory config`. + num_links (int, optional): Number of links to use for the reduce_scatter operation. Defaults to `None`, which indicates to the operation that it should choose. Note that this value will be ignored if there are fewer links available than requested. + topology (ttnn.Topology, optional): The topology configuration to run the operation in. Valid options are Ring and Linear. Defaults to `ttnn.Topology.Ring`. + + Returns: + ttnn.Tensor: the output tensor. + + Example: + + >>> full_tensor = torch.randn([1, 1, 256, 256], dtype=torch.bfloat16) + >>> num_devices = 8 + >>> dim = 3 + >>> input_tensors = torch.chunk(full_tensor, num_devices, dim) + >>> physical_device_ids = ttnn.get_t3k_physical_device_ids_ring() + >>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8), physical_device_ids=physical_device_ids[:8]) + >>> tt_input_tensors = [] + >>> for i, t in enumerate(input_tensors): + tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], mem_config)) + >>> input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + + >>> output = ttnn.reduce_scatter(input_tensor_mesh, dim=0, topology=ttnn.Topology.Linear) + + )doc"); +} + +} // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.hpp new file mode 100644 index 00000000000..8a29e0c4e39 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace ttnn::operations::experimental::ccl { + +void py_bind_reduce_scatter_async(pybind11::module& module); + +} // namespace ttnn::operations::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp index d6f9431947f..355c0f0c0f9 100644 --- a/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp @@ -33,8 +33,7 @@ #include "ttnn/cpp/ttnn/operations/experimental/copy/typecast/typecast_pybind.hpp" #include "ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/attn_matmul_pybind.hpp" #include "ttnn/cpp/ttnn/operations/experimental/matmul/group_attn_matmul/group_attn_matmul_pybind.hpp" -#include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.hpp" -#include "ttnn/operations/experimental/ccl/all_reduce/all_reduce_pybind.hpp" +#include "ttnn/operations/experimental/ccl/ccl_experimental_pybind.hpp" #include "ttnn/operations/experimental/plusone/plusone_pybind.hpp" namespace ttnn::operations::experimental { @@ -77,9 +76,9 @@ void py_module(py::module& module) { plusone::detail::bind_experimental_plusone_operation(module); // CCL ops - auto m_experimental_ccl = module.def_submodule("ccl", "experiemental collective communication operations"); - ccl::py_bind_all_gather_matmul(m_experimental_ccl); - ccl::py_bind_all_reduce(m_experimental_ccl); + auto m_experimental_ccl = + module.def_submodule("ccl_experimental", "experimental collective communication operations"); + ccl::py_module(m_experimental_ccl); } } // namespace ttnn::operations::experimental diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 4f613ca11ef..c5aece10e0a 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -309,6 +309,8 @@ def auto_register_ttnn_cpp_operations(module): from ttnn.operations.ccl import ( Topology, + teardown_edm_fabric, + initialize_edm_fabric, ) from ttnn.operations.conv2d import ( diff --git a/ttnn/ttnn/operations/ccl.py b/ttnn/ttnn/operations/ccl.py index 0056f161327..5162a36cba8 100644 --- a/ttnn/ttnn/operations/ccl.py +++ b/ttnn/ttnn/operations/ccl.py @@ -7,6 +7,8 @@ __all__ = [] Topology = ttnn._ttnn.operations.ccl.Topology +initialize_edm_fabric = ttnn._ttnn.operations.ccl.initialize_edm_fabric +teardown_edm_fabric = ttnn._ttnn.operations.ccl.teardown_edm_fabric # TODO: Add golden functions (#12747)