Skip to content

Commit

Permalink
#0: Convert write_linear to use new dispatch NOC apis, which allows u…
Browse files Browse the repository at this point in the history
…s to switch to using NOC_DISPATCH_MULTICAST_WRITE_VC for all dispatch mcasts

Remove some minor riscv overhead since we don't need to create full noc addrs if we're not reprogramming the noc coords
  • Loading branch information
tt-aho committed Jul 31, 2024
1 parent 91ef8ee commit 96162aa
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/kernels/cq_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void cq_noc_async_write_init_state(uint32_t src_addr, uint64_t dst_addr, uint32_
constexpr bool multicast_path_reserve = mcast;
constexpr bool posted = false;
constexpr bool linked = false;
constexpr uint32_t vc = mcast ? NOC_MULTICAST_WRITE_VC : NOC_UNICAST_WRITE_VC;
constexpr uint32_t vc = mcast ? NOC_DISPATCH_MULTICAST_WRITE_VC : NOC_UNICAST_WRITE_VC;

constexpr uint32_t noc_cmd_field =
NOC_CMD_CPY | NOC_CMD_WR |
Expand Down
48 changes: 29 additions & 19 deletions tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,27 +381,33 @@ void process_write_linear(uint32_t num_mcast_dests) {
uint32_t dst_addr = cmd->write_linear.addr + write_offset[write_offset_index];
uint32_t length = cmd->write_linear.length;
uint32_t data_ptr = cmd_ptr + sizeof(CQDispatchCmd);
cq_noc_async_write_init_state<CQ_NOC_sNdl, multicast>(0, get_noc_addr_helper(dst_noc, dst_addr));

uint32_t writes = 0;
uint32_t mcasts = noc_nonposted_writes_acked[noc_index];

while (length != 0) {
uint32_t xfer_size = (length > dispatch_cb_page_size) ? dispatch_cb_page_size : length;
uint64_t dst = get_noc_addr_helper(dst_noc, dst_addr);
// "Reserve" pages for the next write from this block
block_noc_writes_to_clear[rd_block_idx]++;
writes++;
mcasts += num_mcast_dests;
// Get a page if needed
if (data_ptr + xfer_size > cb_fence) {
// Check for block completion
if (cb_fence == block_next_start_addr[rd_block_idx]) {
uint32_t orphan_size = cb_fence - data_ptr;
// No more writes from this block. Decrement the number of writes
// since they were all accounted for.
block_noc_writes_to_clear[rd_block_idx] -= (orphan_size == 0);
writes -= (orphan_size == 0);
mcasts -= (orphan_size == 0) * num_mcast_dests;
// Check for dispatch_cb wrap
if (rd_block_idx == dispatch_cb_blocks - 1) {
if (orphan_size != 0) {
if constexpr (multicast) {
noc_async_write_multicast<dispatch_cb_page_size>(
data_ptr, dst, orphan_size, num_mcast_dests);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(
data_ptr, dst_addr, orphan_size, num_mcast_dests);
} else {
noc_async_write<dispatch_cb_page_size>(data_ptr, dst, orphan_size);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst_addr, orphan_size);
}
length -= orphan_size;
xfer_size -= orphan_size;
Expand All @@ -411,11 +417,14 @@ void process_write_linear(uint32_t num_mcast_dests) {
}
cb_fence = dispatch_cb_base;
data_ptr = dispatch_cb_base;
dst = get_noc_addr_helper(dst_noc, dst_addr);
}
block_noc_writes_to_clear[rd_block_idx] += writes;
noc_nonposted_writes_num_issued[noc_index] += writes;
writes = 0;
move_rd_to_next_block<dispatch_cb_blocks>(block_noc_writes_to_clear, rd_block_idx);
// Next write will be from next block. "Reserve" pages for it.
block_noc_writes_to_clear[rd_block_idx] += (orphan_size == 0);
writes += (orphan_size == 0);
mcasts += (orphan_size == 0) * num_mcast_dests;
}
// Wait for dispatcher to supply a page (this won't go beyond the buffer end)
uint32_t n_pages = cb_acquire_pages<my_noc_xy, my_dispatch_cb_sem_id, dispatch_cb_log_page_size>(
Expand All @@ -432,22 +441,26 @@ void process_write_linear(uint32_t num_mcast_dests) {
}

if constexpr (multicast) {
noc_async_write_multicast<dispatch_cb_page_size>(data_ptr, dst, xfer_size, num_mcast_dests);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst_addr, xfer_size, num_mcast_dests);
} else {
noc_async_write<dispatch_cb_page_size>(data_ptr, dst, xfer_size);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst_addr, xfer_size);
}
length -= xfer_size;
data_ptr += xfer_size;
dst_addr += xfer_size;
}
block_noc_writes_to_clear[rd_block_idx] += writes;
noc_nonposted_writes_num_issued[noc_index] += writes;
noc_nonposted_writes_acked[noc_index] = mcasts;

cmd_ptr = data_ptr;
}

void process_write() {
volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr;
uint32_t num_mcast_dests = cmd->write_linear.num_mcast_dests;
if (num_mcast_dests == 0) {
process_write_linear<false>(0);
process_write_linear<false>(1);
} else {
process_write_linear<true>(num_mcast_dests);
}
Expand Down Expand Up @@ -624,9 +637,9 @@ void process_write_packed(uint32_t flags) {
// This is done here so the common case doesn't have to restore the pointers
if (orphan_size != 0) {
uint32_t remainder_xfer_size = xfer_size - orphan_size;
// Creating full NOC addr not needed as we are not programming the noc coords
uint32_t remainder_dst_addr = dst_addr + orphan_size;
uint64_t remainder_dst = get_noc_addr_helper(dst_noc, remainder_dst_addr);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, remainder_dst, remainder_xfer_size, num_dests);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, remainder_dst_addr, remainder_xfer_size, num_dests);
// Reset values expected below
cq_noc_async_write_with_state<CQ_NOC_snDL, CQ_NOC_WAIT, CQ_NOC_send>(0, dst, xfer_size);
writes++;
Expand Down Expand Up @@ -704,9 +717,8 @@ void process_write_packed_large() {

sub_cmd_ptr++;

uint64_t dst = get_noc_addr_helper(dst_noc, dst_addr);
// Note: expect to only have 1 or a few pages, so this doesn't optimize writing length
cq_noc_async_write_with_state<CQ_NOC_sNdl, CQ_NOC_WAIT, CQ_NOC_send>(0, dst);
cq_noc_async_write_with_state<CQ_NOC_sNdl, CQ_NOC_WAIT, CQ_NOC_send>(0, get_noc_addr_helper(dst_noc, dst_addr));

while (length != 0) {
uint32_t xfer_size = (length > dispatch_cb_page_size) ? dispatch_cb_page_size : length;
Expand All @@ -726,7 +738,7 @@ void process_write_packed_large() {
if (rd_block_idx == dispatch_cb_blocks - 1) {
ASSERT(cb_fence == dispatch_cb_end);
if (orphan_size != 0) {
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst, orphan_size, num_dests);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst_addr, orphan_size, num_dests);
length -= orphan_size;
xfer_size -= orphan_size;
dst_addr += orphan_size;
Expand All @@ -735,7 +747,6 @@ void process_write_packed_large() {
}
cb_fence = dispatch_cb_base;
data_ptr = dispatch_cb_base;
dst = get_noc_addr_helper(dst_noc, dst_addr);
}

block_noc_writes_to_clear[rd_block_idx] += writes;
Expand All @@ -753,12 +764,11 @@ void process_write_packed_large() {
cb_fence += n_pages * dispatch_cb_page_size;
}

cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst, xfer_size, num_dests);
cq_noc_async_write_with_state<CQ_NOC_SnDL>(data_ptr, dst_addr, xfer_size, num_dests);

length -= xfer_size;
data_ptr += xfer_size;
dst_addr += xfer_size;
dst = get_noc_addr_helper(dst_noc, dst_addr);
}

// Release pages for prefetcher
Expand Down

0 comments on commit 96162aa

Please sign in to comment.