Skip to content

Commit

Permalink
Add missing BH cache invalidation in FW and dispatch kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
abhullar-tt committed Dec 21, 2024
1 parent 09a7751 commit 380df9a
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 9 deletions.
5 changes: 4 additions & 1 deletion tt_metal/hw/firmware/src/brisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ inline void finish_ncrisc_copy_and_run(dispatch_core_processor_masks enables) {

inline void wait_ncrisc_trisc() {
WAYPOINT("NTW");
while (mailboxes->slave_sync.all != RUN_SYNC_MSG_ALL_SLAVES_DONE);
while (mailboxes->slave_sync.all != RUN_SYNC_MSG_ALL_SLAVES_DONE) {
invalidate_l1_cache();
}
WAYPOINT("NTD");
}

Expand Down Expand Up @@ -384,6 +386,7 @@ int main() {
WAYPOINT("GW");
uint8_t go_message_signal = RUN_MSG_DONE;
while ((go_message_signal = mailboxes->go_message.signal) != RUN_MSG_GO) {
invalidate_l1_cache();
// While the go signal for kernel execution is not sent, check if the worker was signalled
// to reset its launch message read pointer.
if (go_message_signal == RUN_MSG_RESET_READ_PTR) {
Expand Down
5 changes: 3 additions & 2 deletions tt_metal/hw/firmware/src/idle_erisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ inline void run_slave_eriscs(dispatch_core_processor_masks enables) {
inline void wait_slave_eriscs(uint32_t &heartbeat) {
WAYPOINT("SEW");
while (mailboxes->slave_sync.all != RUN_SYNC_MSG_ALL_SLAVES_DONE) {
invalidate_l1_cache();
RISC_POST_HEARTBEAT(heartbeat);
}
WAYPOINT("SED");
Expand Down Expand Up @@ -128,8 +129,8 @@ int main() {
init_sync_registers();
// Wait...
WAYPOINT("GW");
while (mailboxes->go_message.signal != RUN_MSG_GO)
{
while (mailboxes->go_message.signal != RUN_MSG_GO) {
invalidate_l1_cache();
RISC_POST_HEARTBEAT(heartbeat);
};
WAYPOINT("GD");
Expand Down
4 changes: 3 additions & 1 deletion tt_metal/hw/firmware/src/ncrisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ inline __attribute__((always_inline)) void notify_brisc_and_wait() {
#ifdef NCRISC_HAS_IRAM
notify_brisc_and_halt(RUN_SYNC_MSG_DONE);
#else
while (*ncrisc_run != RUN_SYNC_MSG_GO);
while (*ncrisc_run != RUN_SYNC_MSG_GO) {
invalidate_l1_cache();
}
#endif
}

Expand Down
4 changes: 3 additions & 1 deletion tt_metal/hw/firmware/src/slave_idle_erisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ int main(int argc, char *argv[]) {
// Cleanup profiler buffer incase we never get the go message
while (1) {
WAYPOINT("W");
while (*slave_idle_erisc_run != RUN_SYNC_MSG_GO);
while (*slave_idle_erisc_run != RUN_SYNC_MSG_GO) {
invalidate_l1_cache();
}
DeviceZoneScopedMainN("SLAVE-IDLE-ERISC-FW");

flush_erisc_icache();
Expand Down
4 changes: 3 additions & 1 deletion tt_metal/hw/firmware/src/trisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ int main(int argc, char *argv[]) {
// Cleanup profiler buffer incase we never get the go message
while (1) {
WAYPOINT("W");
while (*trisc_run != RUN_SYNC_MSG_GO);
while (*trisc_run != RUN_SYNC_MSG_GO) {
invalidate_l1_cache();
}
DeviceZoneScopedMainN("TRISC-FW");

uint32_t launch_msg_rd_ptr = mailboxes->launch_msg_rd_ptr;
Expand Down
4 changes: 3 additions & 1 deletion tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,9 @@ void process_go_signal_mcast_cmd() {
volatile uint32_t tt_l1_ptr* aligned_go_signal_storage = (volatile uint32_t tt_l1_ptr*)cmd_ptr;
*aligned_go_signal_storage = cmd->mcast.go_signal;

while (*worker_sem_addr < cmd->mcast.wait_count);
while (*worker_sem_addr < cmd->mcast.wait_count) {
invalidate_l1_cache();
}
uint8_t go_signal_noc_data_idx = cmd->mcast.noc_data_start_index;
// send go signal update here
for (uint32_t i = 0, num_mcasts = cmd->mcast.num_mcast_txns; i < num_mcasts; ++i) {
Expand Down
10 changes: 8 additions & 2 deletions tt_metal/impl/dispatch/kernels/cq_dispatch_slave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ void wait_for_workers(volatile CQDispatchCmd tt_l1_ptr* cmd) {
uint8_t dispatch_message_offset = *((uint8_t*)&cmd->mcast.go_signal + offsetof(go_msg_t, dispatch_message_offset));
volatile tt_l1_ptr uint32_t* worker_sem =
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(worker_sem_base_addr + dispatch_message_offset);
while (wrap_gt(cmd->mcast.wait_count, *worker_sem));
while (wrap_gt(cmd->mcast.wait_count, *worker_sem)) {
invalidate_l1_cache();
}
}

template <bool flush_write = false>
Expand Down Expand Up @@ -160,6 +162,7 @@ FORCE_INLINE void cb_acquire_pages_dispatch_s(uint32_t n) {
// Stall until the number of pages already acquired + the number that need to be acquired is greater
// than the number available
while (wrap_gt(num_pages_acquired + n, *sem_addr)) {
invalidate_l1_cache();
update_worker_completion_count_on_dispatch_d();
IDLE_ERISC_HEARTBEAT_AND_RETURN(heartbeat);
}
Expand All @@ -182,6 +185,7 @@ void process_go_signal_mcast_cmd() {
// Wait for notification from dispatch_d, signalling that it's safe to send the go signal
uint32_t& mcasts_sent = num_mcasts_sent[(cmd->mcast.wait_addr - worker_sem_base_addr) / L1_ALIGNMENT];
while (wrap_ge(mcasts_sent, *sync_sem_addr)) {
invalidate_l1_cache();
// Update dispatch_d with the latest num_workers
update_worker_completion_count_on_dispatch_d();
}
Expand Down Expand Up @@ -226,7 +230,9 @@ void process_dispatch_s_wait_cmd() {
uint32_t index = (worker_sem_addr - worker_sem_base_addr) / L1_ALIGNMENT;
volatile tt_l1_ptr uint32_t* worker_sem = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(worker_sem_addr);
// Wait for workers to complete
while (wrap_gt(cmd->wait.count, *worker_sem));
while (wrap_gt(cmd->wait.count, *worker_sem)) {
invalidate_l1_cache();
}
// Send updated worker count to dispatch_d and wait for updated count to get picked up by NOC before clearing the
// counter. dispatch_d will clear it's own counter
update_worker_completion_count_on_dispatch_d<true>();
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/dispatch/kernels/cq_prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ void paged_read_into_cmddat_q(uint32_t read_ptr, PrefetchExecBufState& exec_buf_
InterleavedAddrGen<true> addr_gen{.bank_base_address = base_addr, .page_size = page_size};

while (pages_at_once != 0) {
invalidate_l1_cache();
uint64_t noc_addr = addr_gen.get_noc_addr(page_id);
noc_async_read(noc_addr, read_ptr, page_size);
read_ptr += page_size;
Expand Down

0 comments on commit 380df9a

Please sign in to comment.