diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp
index a649640a5b76..657a6e4c14fa 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp
@@ -134,16 +134,16 @@ void kernel_main() {
         uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
         l1_read_addr_ex_par += all_to_all_tile_offset_bytes;
         for (uint32_t i = 0; i < num_tiles_to_read; i++) {
+            cb_reserve_back(cb_external, num_blocks_first_stage);
+            uint32_t l1_write_addr_external = get_write_ptr(cb_external);
             for(uint32_t block = 0; block < num_blocks_first_stage; block++) {
-                cb_reserve_back(cb_external, 1);
-                uint32_t l1_write_addr_external = get_write_ptr(cb_external);
                 uint64_t noc_addr_ex_par = remote_noc_addrs_first_stage[block] | l1_read_addr_ex_par;
                 noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);
-
-                noc_async_read_barrier();
-                cb_push_back(cb_external, 1);
+                l1_write_addr_external+=single_tile_size_bytes;
             }
             l1_read_addr_ex_par += single_tile_size_bytes;
+            noc_async_read_barrier();
+            cb_push_back(cb_external, num_blocks_first_stage);
         }
 
         // read data from other cores - reduce first stage
@@ -163,16 +163,16 @@ void kernel_main() {
         // read data from other cores - second stage reduce
         uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
         for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
+            cb_reserve_back(cb_external, num_blocks_second_stage - 1);
+            uint32_t l1_write_addr_external = get_write_ptr(cb_external);
             for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
-                cb_reserve_back(cb_external, 1);
-                uint32_t l1_write_addr_external = get_write_ptr(cb_external);
                 uint64_t noc_addr_ex = remote_noc_addrs_second_stage[block + 1] | l1_read_addr_ex;
                 noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
-                noc_async_read_barrier();
-
-                cb_push_back(cb_external, 1);
+                l1_write_addr_external += single_tile_size_bytes;
             }
             l1_read_addr_ex += single_tile_size_bytes;
+            noc_async_read_barrier();
+            cb_push_back(cb_external, num_blocks_second_stage - 1);
         }
 
         // sync with the mcast sender
diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp
index 3e0ad4950d10..6e821734f345 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp
@@ -105,15 +105,16 @@ void kernel_main() {
         // read data from other cores - first stage reduce
         uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
         for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
+            cb_reserve_back(cb_external, num_blocks_first_stage);
+            uint32_t l1_write_addr_external = get_write_ptr(cb_external);
             for(uint32_t block = 0; block < num_blocks_first_stage; ++block) {
-                cb_reserve_back(cb_external, 1);
-                uint32_t l1_write_addr_external = get_write_ptr(cb_external);
                 uint64_t noc_addr_ex_par = remote_noc_addrs[block] | l1_read_addr_ex_par;
                 noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);
-                noc_async_read_barrier();
-                cb_push_back(cb_external, 1);
+                l1_write_addr_external += single_tile_size_bytes;
             }
             l1_read_addr_ex_par += single_tile_size_bytes;
+            noc_async_read_barrier();
+            cb_push_back(cb_external, num_blocks_first_stage);
         }
 
         // sync with second-stage all-to-all workers
@@ -133,16 +134,17 @@ void kernel_main() {
             uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
             for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
                 uint32_t curr_block_index = block_index_stride;
+                cb_reserve_back(cb_external, num_blocks_second_stage - 1);
+                uint32_t l1_write_addr_external = get_write_ptr(cb_external);
                 for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
-                    cb_reserve_back(cb_external, 1);
-                    uint32_t l1_write_addr_external = get_write_ptr(cb_external);
                     uint64_t noc_addr_ex = remote_noc_addrs[curr_block_index] | l1_read_addr_ex;
                     noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
                     curr_block_index += block_index_stride;
-                    noc_async_read_barrier();
-                    cb_push_back(cb_external, 1);
+                    l1_write_addr_external += single_tile_size_bytes;
                 }
                 l1_read_addr_ex += single_tile_size_bytes;
+                noc_async_read_barrier();
+                cb_push_back(cb_external, num_blocks_second_stage - 1);
             }
         }
 
@@ -171,6 +173,7 @@ void kernel_main() {
 
         // mcast
         uint32_t l1_read_addr_ex_global = get_read_ptr(cb_ex_global);
+        cb_push_back(cb_ex_global, block_h);
         if constexpr(num_blocks > 1) {
             for (uint32_t block = 0; block < num_all_to_all_workers_first_stage; ++block) {
                 *reduce_sender_semaphore_addr_ptr = block + 2;
@@ -184,7 +187,6 @@ void kernel_main() {
                 noc_async_write_barrier();
             }
         }
-        cb_push_back(cb_ex_global, block_h);
     };
 #ifndef RMSNORM
     global_reduce_sender(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global, cb_ex);
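
Note: the hunks above replace the per-tile reserve/read/barrier/push sequence with a batched one: reserve all pages for the loop up front, issue every noc_async_read_one_packet back to back while advancing the local write pointer, then wait on a single noc_async_read_barrier() and publish the whole group with one cb_push_back. The sketch below illustrates that pattern in isolation; it is not code from these kernels. cb_src, cb_dst, num_pages, page_size_bytes, and the core coordinates (1, 1) are hypothetical placeholders, while the dataflow API calls are the same ones used in the diff.

#include <cstdint>
#include "dataflow_api.h"

void kernel_main() {
    // Hypothetical circular buffers and runtime args, for illustration only.
    constexpr uint32_t cb_src = 0;   // CB holding the source data (same L1 address on the remote core)
    constexpr uint32_t cb_dst = 1;   // local CB the batch is gathered into
    const uint32_t num_pages       = get_arg_val<uint32_t>(0);
    const uint32_t page_size_bytes = get_arg_val<uint32_t>(1);
    // Placeholder source: page-strided data in the L1 of core (1, 1).
    const uint64_t base_noc_addr = get_noc_addr(1, 1, get_read_ptr(cb_src));

    // Reserve the whole batch once instead of one page per iteration.
    cb_reserve_back(cb_dst, num_pages);
    uint32_t l1_write_addr = get_write_ptr(cb_dst);
    for (uint32_t page = 0; page < num_pages; ++page) {
        // Issue the read without waiting; the NoC can pipeline these requests.
        noc_async_read_one_packet(base_noc_addr + page * page_size_bytes,
                                  l1_write_addr, page_size_bytes);
        l1_write_addr += page_size_bytes;  // advance within the reserved region
    }
    noc_async_read_barrier();          // one wait for all outstanding reads
    cb_push_back(cb_dst, num_pages);   // publish the batch to the consumer in one step
}

Batching this way lets the reads overlap in flight and amortizes one barrier and one circular-buffer update over the whole group instead of paying both costs per tile.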