Initial CCL V2 infra push - add cmd interpreter and reduce scatter, all-gather

Without going too deep into the weeds, there were numerous reasons why
CCLs needed to be fundamentally rewritten. To summarize some of them:
- Writing CCLs did not scale in terms of development effort
- Even within a single op (e.g. all-gather) we need to be able to
support many topologies (ring, line, mesh, tree, tree of mesh, etc.) and
use cases (BW bound, latency bound, high reliability vs lower
reliability with potentially better perf)
- CCLs need to be able to be fused with just about any ops without it
being a Herculean effort
- New concepts like "async tensor" need to be supported to account for
performance artifacts like (dispatch) skew between chips and to
effectively hide latency of various operations
- (minor) support the new fabric projects with CCLs

  
### Initial test coverage  
- Gtests that provide basic coverage for the CCL Command interpreter
running on the transitionary EDM fabric (both in persistent and
non-persistent modes)
  - Gtests for reduce scatter and all-gather also added
- Basic all gather pytests

Future work will expand test coverage

### What's changed
Lots to discuss here:
- What's the command interpreter?
- How does it work?
- How do we build ops with it?
- What's new with these CCLs?

The bulk of this information is or will be included in a much larger doc
that will be circulated more widely in the coming weeks, so only a summary
is provided below. If you want more details before the doc is available,
ask and I will point you to what's in progress.

A new "command interpreter" kernel is provided which executes various
different command types. Some commands map nearly directly to the low
level noc API but others map to higher level operations.
High Level Operation Example:
- Stream Tensor Slice (from: CB/addrgen) (to:raw addr, CB, (fabric)
addrgen)

Low Level Command:
- Wait for semaphore value
- Send semaphore update
- Raw Read/Write
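
To make the low-level command idea concrete, here is a minimal host-side sketch of an interpreter loop. It is a toy model only: `CmdType`, `Cmd`, `ToyCore` and `run_command_stream` are hypothetical names invented for this illustration, and memory is simulated with a vector rather than L1 and the noc API. The real kernel's command encodings and dispatch differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical command set for illustration; not the real CCL encoding.
enum class CmdType : uint8_t { WaitSemaphoreValue, SendSemaphoreInc, RawWrite };

struct Cmd {
    CmdType type;
    uint32_t addr;   // index into the simulated memory (semaphore or data word)
    uint32_t value;  // value to wait for, increment amount, or payload
};

// Simulated local memory standing in for a core's L1.
struct ToyCore {
    std::vector<uint32_t> mem;
    explicit ToyCore(std::size_t words) : mem(words, 0) {}
};

// Executes commands in order and returns how many completed.
// A wait whose condition is not met stops execution early
// (a device kernel would spin on the semaphore instead).
std::size_t run_command_stream(ToyCore& core, const std::vector<Cmd>& stream) {
    std::size_t completed = 0;
    for (const Cmd& cmd : stream) {
        assert(cmd.addr < core.mem.size());
        switch (cmd.type) {
            case CmdType::WaitSemaphoreValue:
                if (core.mem[cmd.addr] != cmd.value) {
                    return completed;
                }
                break;
            case CmdType::SendSemaphoreInc:
                core.mem[cmd.addr] += cmd.value;
                break;
            case CmdType::RawWrite:
                core.mem[cmd.addr] = cmd.value;
                break;
        }
        ++completed;
    }
    return completed;
}
```

The point of the sketch is only the shape: the host builds a flat list of commands, and one generic kernel walks it, so a new op is mostly a new command list rather than a new kernel.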

These commands are specifiable on host, and there is a whole optimization
story for performance. To provide the general idea, here's the primary
functional code needed for all-gather as an example (code reorganized for
the purpose of this PR example - not 1:1 with
`all_gather_async_program.cpp`):

```
// Create a "reader kernel" command stream
std::vector<ttnn::ccl::cmd::CclHostLowLevelWorkerCommand> reader_cmd_stream;
reader_cmd_stream.push_back(ttnn::ccl::cmd::uops::read_tensor_slice_to_cb(input_worker_slice_v2, src0_cb_index));


// Create a "writer kernel" command stream
std::vector<ttnn::ccl::cmd::CclHostLowLevelWorkerCommand> writer_cmd_stream;
// 1, do mcast of the tensor slice to all the destinations
writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice(
        output_worker_slice_v2, src0_cb_index, mcast_dest_args));
        
// Really, for all-gather, that's basically it - the rest of the code is to choose core placement and get info
// (like which core(s) are fabric endpoints to connect to fabric, etc.)

// Now pass the commands to the kernel(s)
ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args(
            program,
            worker_sender_reader_kernel_id,
            ...,
            reader_cmd_stream,
            std::nullopt,
            std::nullopt,
            std::nullopt);
ttnn::ccl::worker_detail::generate_multi_input_command_stream_kernel_rt_args(
            program,
            worker_sender_writer_kernel_id,
            ...,
            writer_cmd_stream,
            std::nullopt,
            {forward_fabric_connection},
            {backward_fabric_connection});

```

With the above, operations such as fusion become far simpler (in some
cases, trivial).

For example, consider fusing an all-reduce with a split-qkv heads
operation (note that the output side of all-reduce is basically
all-gather in an optimized ring implementation). The basic fusion
operation is to first identify the split/slice boundaries of split-qkv
(these could potentially be obtained from the op directly), then
propagate those cut lines to all of the tensor slices of the producer
(like the tensor slices in the commands shown above), and finally split
those slices and set the correct output tensor for each piece
accordingly.
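
The slice-splitting step can be sketched in one dimension as follows. This is an illustrative model under stated assumptions, not the fusion code itself: `Slice1D`, `RoutedSlice` and `split_at_cut_lines` are hypothetical names, slices are half-open ranges along a single axis, and the cut lines are assumed to be sorted ascending, with output `i` owning the region between cut `i-1` and cut `i`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Half-open range [begin, end) along one tensor dimension.
struct Slice1D {
    std::size_t begin;
    std::size_t end;
};

// One piece of a producer slice, routed to a specific consumer output.
struct RoutedSlice {
    Slice1D src;               // range in the producer's coordinate space
    std::size_t output_index;  // which output tensor receives this piece
    std::size_t dst_begin;     // offset of the piece inside that output
};

// Splits `slice` wherever it crosses a cut line. `cuts` must be sorted ascending.
std::vector<RoutedSlice> split_at_cut_lines(Slice1D slice, const std::vector<std::size_t>& cuts) {
    assert(slice.begin <= slice.end);
    std::vector<RoutedSlice> pieces;
    std::size_t pos = slice.begin;
    std::size_t region_begin = 0;  // start of the current output's region
    for (std::size_t i = 0; i <= cuts.size() && pos < slice.end; ++i) {
        // The last region is unbounded, so clamp it to the slice end.
        std::size_t region_end = (i < cuts.size()) ? cuts[i] : slice.end;
        if (region_end > slice.end) {
            region_end = slice.end;
        }
        if (region_end > pos) {
            pieces.push_back(RoutedSlice{Slice1D{pos, region_end}, i, pos - region_begin});
            pos = region_end;
        }
        if (i < cuts.size()) {
            region_begin = cuts[i];
        }
    }
    return pieces;
}
```

For instance, with cuts at 4 and 8, the producer slice `[2, 10)` becomes three pieces, one per output, while a slice that lies entirely inside one region, such as `[5, 7)`, stays whole and is only retargeted.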

Note that many commands can be added to each given command stream -
all-gather is just very simple. Reduce scatter is an example of one that
is more complicated.

### Expanding to other operations:
Here are some simple examples

#### Send/receive
- Take the all-gather as example, and rather than specifying an mcast on
the tensor write command:
```
writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice(
        output_worker_slice_v2, src0_cb_index, mcast_dest_args));
```
you would unicast it to the desired destination (replace
`mcast_dest_args` with unicast destination arguments).

If running in synchronous tensor mode, add a command interpreter kernel
at the destination chip with a wait_val command to wait on a semaphore
increment, and append a semaphore increment to the sender command stream.

#### Broadcast
Invoke all-gather above but just from one chip. 

If running in synchronous tensor mode, add a command interpreter kernel
at all the destination chips with a wait_val command to wait on a sem
inc. Append a fabric multicast seminc to the sender command stream.

#### Reduce
- Build a tree on the cluster
- Each producer chip unicasts its data to the next node toward the root
of the tree, then sends a sync signal downstream
- If a node is not a leaf, it performs a partial reduction of the
received data with its local data and forwards the result to the next
node toward the root
    - Add a wait val before accepting the input data
- The root node can do any number of reductions to reduce the incoming
data streams (making sure to first sync on each input stream before
consuming it)
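
The tree reduction described above can be simulated on host as a sanity check of the data flow. This is a sketch under assumptions, not the op implementation: `tree_reduce` is a hypothetical helper, the reduction is a plain integer sum, `parent[i]` names the next node toward the root (the root is its own parent), and a per-node counter stands in for the wait-val sync on each input stream.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simulates a tree reduce: each node forwards its partial result to its parent
// only after all of its own children have delivered theirs.
int64_t tree_reduce(const std::vector<std::size_t>& parent,
                    const std::vector<int64_t>& local,
                    std::size_t root) {
    const std::size_t n = parent.size();
    assert(local.size() == n && root < n && parent[root] == root);

    // Number of inputs each node must still sync on (models the wait val).
    std::vector<std::size_t> pending_inputs(n, 0);
    for (std::size_t i = 0; i < n; ++i) {
        assert(parent[i] < n);
        if (i != root) {
            ++pending_inputs[parent[i]];
        }
    }

    std::vector<int64_t> partial = local;
    std::vector<std::size_t> ready;  // nodes whose inputs have all arrived
    for (std::size_t i = 0; i < n; ++i) {
        if (i != root && pending_inputs[i] == 0) {
            ready.push_back(i);  // leaves can send immediately
        }
    }

    while (!ready.empty()) {
        const std::size_t node = ready.back();
        ready.pop_back();
        const std::size_t up = parent[node];
        partial[up] += partial[node];  // unicast send plus partial reduction at the receiver
        if (--pending_inputs[up] == 0 && up != root) {
            ready.push_back(up);
        }
    }
    return partial[root];
}
```

With `parent = {0, 0, 0, 1, 1}` and local values `{1, 2, 3, 4, 5}`, node 1 first reduces its own value with those of nodes 3 and 4, and the root then combines that partial result with node 2's value and its own.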

We do something similar to the above for reduce scatter


### Note on APIs
These APIs are expected to be refined over time. In the meantime, I
have introduced the named "micro-ops" as commands to grant us some
flexibility in changing the underlying command encodings (both on host and
device). This will let us optimize and improve the "IR" over time
without requiring constant op implementation updates.
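
As a rough illustration of why named micro-ops buy this flexibility, consider the sketch below. Everything in it is hypothetical (the `toy_uops` namespace, the opcode values, the word layout); it only shows the pattern of op code calling named builders while the encoding stays private to one place.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

namespace toy_uops {

// Opaque encoded command; op implementations never inspect the words directly.
struct EncodedCmd {
    std::vector<uint32_t> words;
};

// Hypothetical opcodes. Changing these (or the word layout) only touches this namespace.
constexpr uint32_t kOpcodeWaitValue = 0x1;
constexpr uint32_t kOpcodeSemInc = 0x2;

inline EncodedCmd wait_value(uint32_t sem_id, uint32_t value) {
    return EncodedCmd{std::vector<uint32_t>{kOpcodeWaitValue, sem_id, value}};
}

inline EncodedCmd semaphore_inc(uint32_t sem_id, uint32_t increment) {
    return EncodedCmd{std::vector<uint32_t>{kOpcodeSemInc, sem_id, increment}};
}

// Flattens a command stream into the word list that would be passed as runtime args.
inline std::vector<uint32_t> flatten(const std::vector<EncodedCmd>& stream) {
    std::vector<uint32_t> out;
    for (const EncodedCmd& cmd : stream) {
        out.insert(out.end(), cmd.words.begin(), cmd.words.end());
    }
    return out;
}

}  // namespace toy_uops
```

An op written against `wait_value` and `semaphore_inc` keeps compiling and behaving the same if the packing inside the builders changes, which is the property the named micro-ops are meant to provide.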

---------

Co-authored-by: Jack Cai <[email protected]>
SeanNijjar and caixunshiren authored Dec 21, 2024
1 parent 4a4aa10 commit 4f5f417
Showing 69 changed files with 15,490 additions and 1,287 deletions.
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/gtests/CMakeLists.txt
@@ -16,6 +16,7 @@ set(TTNN_CCL_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_commands.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_helpers.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_tensor_slicers.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_sharded_address_generators.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_reduce_scatter_host_helpers.cpp
)

@@ -7,6 +7,7 @@
#include "dataflow_api.h"
#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"
#include "tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp"

struct unicast_mode {
uint8_t distance;
@@ -53,19 +54,20 @@ void kernel_main() {
const uint32_t eth_sender_noc_x = get_arg_val<uint32_t>(arg_idx++);
const uint32_t eth_sender_noc_y = get_arg_val<uint32_t>(arg_idx++);
const uint32_t num_buffers_per_edm_channel = get_arg_val<uint32_t>(arg_idx++);
size_t edm_connection_handshake_addr =
get_semaphore<ProgrammableCoreType::ACTIVE_ETH>(get_arg_val<uint32_t>(arg_idx++));
size_t edm_connection_handshake_id = get_arg_val<uint32_t>(arg_idx++);
size_t edm_worker_location_info_addr = get_arg_val<uint32_t>(arg_idx++);
size_t edm_buffer_size_bytes = get_arg_val<uint32_t>(arg_idx++);
size_t dest_addr = get_arg_val<uint32_t>(arg_idx++);
volatile uint32_t* const last_message_semaphore_address =
reinterpret_cast<volatile uint32_t* const>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
*last_message_semaphore_address = 0;
auto worker_buffer_index_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(arg_idx++));
bool connected_to_persistent_fabric = get_arg_val<uint32_t>(arg_idx++) != 0;

// TODO: move to semaphore
auto edm_buffer_index_sem_id = get_arg_val<uint32_t>(arg_idx++);
ASSERT(edm_buffer_index_sem_id < 8);
auto edm_buffer_index_address = get_semaphore<ProgrammableCoreType::ACTIVE_ETH>(edm_buffer_index_sem_id);
auto edm_buffer_index_id = edm_buffer_index_sem_id;
ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(writer_send_sem_addr));
ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast<size_t>(last_message_semaphore_address));

@@ -77,20 +79,22 @@ void kernel_main() {
config.unicast.distance = static_cast<uint8_t>(get_arg_val<uint32_t>(arg_idx++));
}

const InterleavedAddrGen<dest_is_dram> dest_addr_gen = {.bank_base_address = dest_addr, .page_size = page_size};
const InterleavedAddrGen<dest_is_dram> dest_addr_gen = {
.bank_base_address = dest_addr, .page_size = page_size};

ASSERT(num_buffers_per_channel > 0);
auto sender = tt::fabric::WorkerToFabricEdmSender(
connected_to_persistent_fabric,
eth_sender_noc_x,
eth_sender_noc_y,
eth_l1_base_addr,
num_buffers_per_channel,
eth_sender_l1_sem_id,

edm_connection_handshake_addr,
edm_connection_handshake_id,
edm_worker_location_info_addr,
edm_buffer_size_bytes,
edm_buffer_index_address,
edm_buffer_index_id,
writer_send_sem_addr,
worker_buffer_index_semaphore_addr);

@@ -154,10 +158,8 @@ void kernel_main() {

auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(a_packet_header_addr);
ASSERT(*last_message_semaphore_address == 0);
packet_header.reserved = 0xE;
packet_header.reserved2 = 0xFFFF;
packet_header.to_atomic_inc();
packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{1});
packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{2});
packet_header.to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader(
reinterpret_cast<size_t>(last_message_semaphore_address), 1, 32, my_x[0], my_y[0]));

@@ -167,40 +169,9 @@ void kernel_main() {
noc_semaphore_wait(last_message_semaphore_address, 1);
}

bool closed = false;
size_t num_endpoints_to_terminate = get_arg_val<uint32_t>(arg_idx++);
for (size_t i = 0; i < num_endpoints_to_terminate; i++) {
size_t edm_noc_x = get_arg_val<uint32_t>(arg_idx++);
size_t edm_noc_y = get_arg_val<uint32_t>(arg_idx++);
size_t distance = get_arg_val<uint32_t>(arg_idx++);
size_t termination_addr = get_arg_val<uint32_t>(arg_idx++);

if (!closed && distance == 0) {
closed = true;
sender.close();
}
if (distance == 0) {
noc_inline_dw_write(
get_noc_addr(edm_noc_x, edm_noc_y, termination_addr),
tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE);
} else {
auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(a_packet_header_addr);
reinterpret_cast<volatile uint32_t*>(a_packet_header_addr)[sizeof(tt::fabric::PacketHeader) >> 2] =
tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE;
sender.wait_for_empty_write_slot();
packet_header.to_write()
.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{static_cast<uint8_t>(distance - 1)})
.to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
termination_addr,
sizeof(tt::fabric::PacketHeader) + sizeof(uint32_t),
static_cast<uint8_t>(edm_noc_x),
static_cast<uint8_t>(edm_noc_y)});
sender.send_payload_blocking_from_address(
a_packet_header_addr, packet_header.get_payload_size_including_header());
noc_async_writes_flushed();
}
}
if (!closed) {
bool closed_fabric_connection = terminate_fabric_endpoints_farthest_to_nearest(sender, a_packet_header_addr, arg_idx);

if (!closed_fabric_connection) {
sender.close();
}
}
@@ -0,0 +1,218 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "dataflow_api.h"
#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"
#include "tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp"

struct unicast_mode {
uint8_t distance;
};
struct mcast_mode {
uint8_t distance;
uint8_t range;
};

union transmit_config {
unicast_mode unicast;
mcast_mode mcast;
};

enum class ReadMode {
// Fully drain one input CB before advancing to the next
FULLY_ORDERED,
// Drain both inputs in at some specified ratio (e.g. if 5:3 then 3 pages from CB0 then 5 from CB1)
RATIOD_FORWARDING,
// Read pages from either CB as soon as they are available
ARBITRARILY_ORDERED
};

static constexpr size_t NORMALIZED_NOC_INDEX = 0;

template <typename AddrGen>
auto forward_to_fabric_from_cb(
size_t total_pages_to_send,
tt::fabric::WorkerToFabricEdmSender &sender,
uint32_t cb_id,
transmit_config const& config,
size_t page_size,
const AddrGen &dest_addr_gen,
size_t num_pages_per_send,
size_t current_page
) {

// for (uint32_t p = 0; p < total_pages_to_send; p += num_pages_per_send) {
uint32_t pages_to_send = std::min<uint32_t>(num_pages_per_send, total_pages_to_send - current_page);
sender.wait_for_empty_write_slot();

// bit of a hack to extract X/Y
const auto dest_noc_address = get_noc_addr(current_page, dest_addr_gen, 0, NORMALIZED_NOC_INDEX);
const auto [dest_worker_noc, dest_addr] = get_noc_address_components(dest_noc_address);
const size_t packet_size = page_size + sizeof(tt::fabric::PacketHeader);

auto packet_addr = get_read_ptr(cb_id);
auto &packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(packet_addr);
if constexpr (mcast_mode) {
packet_header.to_write()
.to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{config.mcast.distance, config.mcast.range})
.to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
dest_addr,
(pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader),
static_cast<uint8_t>(dest_worker_noc.x),
static_cast<uint8_t>(dest_worker_noc.y)});
} else {
packet_header.to_write()
.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{config.unicast.distance})
.to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
dest_addr,
(pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader),
static_cast<uint8_t>(dest_worker_noc.x),
static_cast<uint8_t>(dest_worker_noc.y)});
}

uint64_t buffer_address = sender.edm_buffer_addr + (*sender.buffer_index_ptr * (sender.buffer_size_bytes + sizeof(eth_channel_sync_t)));
sender.send_payload_blocking_from_address(packet_addr, packet_size);
noc_async_writes_flushed();
// }
}

template <typename AddrGen>
void non_blocking_read_and_forward(size_t &current_page_in, uint32_t cb_id, const AddrGen &dest_addr_gen, tt::fabric::WorkerToFabricEdmSender &sender, transmit_config const& config, uint32_t page_size, uint32_t total_pages_to_send, uint32_t num_pages_per_send) {
uint32_t pages_to_send = std::min<uint32_t>(num_pages_per_send, total_pages_to_send - current_page_in);
if (!cb_pages_available_at_front(cb_id, pages_to_send)) {
return;
}

current_page_in += num_pages_per_send;
cb_wait_front(cb_id, pages_to_send);
forward_to_fabric_from_cb(
total_pages_to_send,
sender,
cb_id,
config,
page_size,
dest_addr_gen,
num_pages_per_send,
current_page_in
);
cb_pop_front(cb_id, pages_to_send);
}

// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side).
// -> takes input from local cb and pushes to erisc L1
void kernel_main() {

// Test doesn't support multiple pages per send yet since we are writing
// to interleaved which will never have subsequent pages on the same core
// (and hence, able to share a packet header)
constexpr uint32_t num_pages_per_send = 1;//get_compile_time_arg_val(0);
constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1);
constexpr uint32_t page_size = get_compile_time_arg_val(2);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3);
constexpr bool dest0_is_dram = get_compile_time_arg_val(4) != 0;
constexpr bool dest1_is_dram = get_compile_time_arg_val(5) != 0;
constexpr ReadMode read_mode = static_cast<ReadMode>(get_compile_time_arg_val(6));

transmit_config config;
size_t arg_idx = 0;
auto sender = tt::fabric::WorkerToFabricEdmSender::build_from_args<ProgrammableCoreType::TENSIX>(arg_idx);
volatile uint32_t* const last_message_semaphore_address = reinterpret_cast<volatile uint32_t* const >(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
size_t output_buffer0_addr = get_arg_val<uint32_t>(arg_idx++);
size_t output_buffer1_addr = get_arg_val<uint32_t>(arg_idx++);
config.unicast.distance = static_cast<uint8_t>(get_arg_val<uint32_t>(arg_idx++));

size_t read_ratio0 = (read_mode == ReadMode::ARBITRARILY_ORDERED) ? 0 :
(read_mode == ReadMode::FULLY_ORDERED) ? total_pages_to_send :
get_arg_val<uint32_t>(arg_idx++);
size_t read_ratio1 = (read_mode == ReadMode::ARBITRARILY_ORDERED) ? 0 :
(read_mode == ReadMode::FULLY_ORDERED) ? total_pages_to_send :
get_arg_val<uint32_t>(arg_idx++);


*last_message_semaphore_address = 0;
const InterleavedAddrGen<dest0_is_dram> dest_addr_gen0 = {
.bank_base_address = output_buffer0_addr, .page_size = page_size};
const InterleavedAddrGen<dest1_is_dram> dest_addr_gen1 = {
.bank_base_address = output_buffer1_addr, .page_size = page_size};

ASSERT(num_buffers_per_channel > 0);

sender.open();

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
constexpr uint32_t cb_id_in1 = tt::CB::c_in0;

// We need to normalize all noc addresses to be for a consistent noc ID
// so the remote sender core can correctly send the packet. In the future
// we can decide if it's better for the noc index to be embedded in the packet
// header (for now we don't do that)
constexpr size_t NORMALIZED_NOC_INDEX = 0;

cb_wait_front(cb_id_in0, 1);
auto a_packet_header_addr = get_read_ptr(cb_id_in0);

if constexpr (read_mode == ReadMode::FULLY_ORDERED || read_mode == ReadMode::RATIOD_FORWARDING) {

size_t current_page_in0 = 0;
size_t current_page_in1 = 0;
while (current_page_in0 < total_pages_to_send || current_page_in1 < total_pages_to_send) {
for (size_t read = 0; read < read_ratio0 && current_page_in0 < total_pages_to_send; current_page_in0 += num_pages_per_send, read++) {
uint32_t pages_to_send = std::min<uint32_t>(num_pages_per_send, total_pages_to_send - current_page_in0);
cb_wait_front(cb_id_in0, pages_to_send);
non_blocking_read_and_forward(current_page_in0, cb_id_in0, dest_addr_gen0, sender, config, page_size, total_pages_to_send, num_pages_per_send);
cb_pop_front(cb_id_in0, pages_to_send);
}

for (size_t read = 0; read < read_ratio1 && current_page_in1 < total_pages_to_send; current_page_in1 += num_pages_per_send, read++) {
uint32_t pages_to_send = std::min<uint32_t>(num_pages_per_send, total_pages_to_send - current_page_in1);
cb_wait_front(cb_id_in1, pages_to_send);
non_blocking_read_and_forward(current_page_in1, cb_id_in1, dest_addr_gen1, sender, config, page_size, total_pages_to_send, num_pages_per_send);
cb_pop_front(cb_id_in1, pages_to_send);
}
}

} else if constexpr (read_mode == ReadMode::ARBITRARILY_ORDERED) {
size_t current_page_in0 = 0;
size_t current_page_in1 = 0;
while (current_page_in0 < total_pages_to_send || current_page_in1 < total_pages_to_send) {
if (current_page_in0 < total_pages_to_send) {
non_blocking_read_and_forward(current_page_in0, cb_id_in0, dest_addr_gen0, sender, config, page_size, total_pages_to_send, num_pages_per_send);
}
if (current_page_in1 < total_pages_to_send) {
non_blocking_read_and_forward(current_page_in1, cb_id_in1, dest_addr_gen1, sender, config, page_size, total_pages_to_send, num_pages_per_send);
}
}
}

sender.wait_for_empty_write_slot();

constexpr size_t kLoopbackNumHopsToMyChip = 2;
auto &packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(a_packet_header_addr);
ASSERT(*last_message_semaphore_address == 0);
packet_header.reserved = 0xE;
packet_header.reserved2 = 0xFFFF;
packet_header.to_atomic_inc();
packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{kLoopbackNumHopsToMyChip});
packet_header.to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader(
reinterpret_cast<size_t>(last_message_semaphore_address),
1,
32,
my_x[0],
my_y[0]
));

sender.send_payload_blocking_from_address(a_packet_header_addr, packet_header.get_payload_size_including_header());

noc_semaphore_wait(last_message_semaphore_address, 1);

bool closed_fabric_connection = terminate_fabric_endpoints_farthest_to_nearest(sender, a_packet_header_addr, arg_idx);

if (!closed_fabric_connection) {
sender.close();
}
}
46 changes: 46 additions & 0 deletions tests/ttnn/unit_tests/gtests/ccl/kernels/test_kernels.common.hpp
@@ -0,0 +1,46 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"

bool terminate_fabric_endpoints_farthest_to_nearest (
tt::fabric::WorkerToFabricEdmSender &sender,
size_t a_packet_header_addr,
size_t arg_idx) {

bool closed = false;
size_t num_endpoints_to_terminate = get_arg_val<uint32_t>(arg_idx++);
for (size_t i = 0; i < num_endpoints_to_terminate; i++) {
size_t edm_noc_x = get_arg_val<uint32_t>(arg_idx++);
size_t edm_noc_y = get_arg_val<uint32_t>(arg_idx++);
size_t distance = get_arg_val<uint32_t>(arg_idx++);
size_t termination_addr = get_arg_val<uint32_t>(arg_idx++);

if (!closed && distance == 0) {
closed = true;
sender.close();
}
if (distance == 0) {
noc_inline_dw_write(get_noc_addr(edm_noc_x, edm_noc_y, termination_addr), tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE);
} else {
auto &packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(a_packet_header_addr);
reinterpret_cast<volatile uint32_t*>(a_packet_header_addr)[sizeof(tt::fabric::PacketHeader) >> 2] = tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE;
sender.wait_for_empty_write_slot();
packet_header.to_write()
.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{static_cast<uint8_t>(distance)})
.to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
termination_addr,
sizeof(tt::fabric::PacketHeader) + sizeof(uint32_t),
static_cast<uint8_t>(edm_noc_x),
static_cast<uint8_t>(edm_noc_y)
});
sender.send_payload_blocking_from_address(a_packet_header_addr, packet_header.get_payload_size_including_header());
noc_async_writes_flushed();
}
}

return closed;
}