diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm_sharded.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm_sharded.cpp
index 6016d237c188..2aa6505680ae 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm_sharded.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm_sharded.cpp
@@ -24,7 +24,7 @@ void MAIN {
     constexpr uint32_t is_top_row = get_compile_time_arg_val(0);
     constexpr uint32_t do_gamma = get_compile_time_arg_val(1);
     constexpr uint32_t do_beta = get_compile_time_arg_val(2);
-    constexpr uint32_t num_blocks = get_compile_time_arg_val(3);
+    constexpr uint32_t num_blocks_first_stage = get_compile_time_arg_val(3);
    constexpr uint32_t block_w = get_compile_time_arg_val(5);
    constexpr uint32_t block_h_const = get_compile_time_arg_val(4);
    volatile uint32_t block_h_volatile = get_compile_time_arg_val(4);
@@ -33,9 +33,26 @@ void MAIN {
    constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(7);
    const bool is_allgather_worker = get_compile_time_arg_val(8) == 1;
    constexpr uint32_t num_tiles_per_block = get_compile_time_arg_val(9);
-    constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(10) == 1;
+    constexpr bool FLOAT32_DTYPE               = get_compile_time_arg_val(10) == 1;
+    constexpr uint32_t num_blocks_second_stage = get_compile_time_arg_val(11);

-    const uint32_t num_tiles_per_allgather_worker = is_allgather_worker ? get_arg_val<uint32_t>(0) : 0;
+    const uint32_t num_tiles_per_allgather_worker = is_allgather_worker ? get_arg_val<uint32_t>(0) : 0;
+    const bool use_two_stage_reduce               = is_allgather_worker ? get_arg_val<uint32_t>(1) == 1 : false;
+    const bool is_second_stage_reader             = is_allgather_worker ? get_arg_val<uint32_t>(2) == 1 : false;
+
+    uint32_t num_blocks_reduce;
+    if (is_second_stage_reader) {
+        num_blocks_reduce = num_blocks_first_stage + num_blocks_second_stage - 1;
+    } else {
+        num_blocks_reduce = num_blocks_first_stage;
+    }
+
+    bool enable_sqrt;
+    if (use_two_stage_reduce and not is_second_stage_reader) {
+        enable_sqrt = false;
+    } else {
+        enable_sqrt = true;
+    }

    constexpr uint32_t dst0 = 0;
    constexpr uint32_t scaler0 = 0;
@@ -153,7 +170,7 @@ void MAIN {
        for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
            cb_wait_front(cb_scaler_global, 1);
            tile_regs_acquire();
-            for (uint32_t w = 0; w < num_blocks; w++) {
+            for (uint32_t w = 0; w < num_blocks_reduce; w++) {
                cb_wait_front(cb_ex_external, 1);
                reduce_tile(cb_ex_external, cb_scaler_global, 0, scaler0, dst0);
                cb_pop_front(cb_ex_external, 1);
@@ -238,6 +255,9 @@ void MAIN {
    cb_wait_front(cb_xmm2, num_tiles_per_block);

    // Var(x)
+    #ifdef RMSNORM
+    cb_wait_front(cb_scaler, 1);
+    #endif
    cb_reserve_back(cb_ex_partial2, block_h);
    reduce_init_delta(REDUCE_OP, REDUCE_DIM);
    index_h_offset = 0;
@@ -265,7 +285,7 @@ void MAIN {
            cb_wait_front(cb_scaler_global, 1);
            tile_regs_acquire();
-            for (uint32_t w = 0; w < num_blocks; w++) {
+            for (uint32_t w = 0; w < num_blocks_reduce; w++) {
                cb_wait_front(cb_ex_external2, 1);
                reduce_tile(cb_ex_external2, cb_scaler_global, 0, scaler0, dst0);
                cb_pop_front(cb_ex_external2, 1);
@@ -278,28 +298,29 @@ void MAIN {
        reduce_revert_delta();
        cb_push_back(cb_ex2, num_tiles_per_allgather_worker);

-        for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
-            // 1/[sqrt(Var + eps)]
-            cb_wait_front(cb_ex2, 1);
-            cb_reserve_back(cb_ex2pe, 1);
-            tile_regs_acquire();
-            add_tiles_init();
-            add_tiles(cb_ex2, cb_eps, i, 0, dst0);
-            tile_regs_wait();
-            // sqrt(Var + eps)
-            sqrt_tile_init();
-            sqrt_tile(dst0);
-            tile_regs_wait();
-            // 1/[sqrt(Var + eps)]
-            recip_tile_init();
-            recip_tile(dst0);
-            tile_regs_commit();
-            tile_regs_wait();
-            pack_tile(dst0, cb_ex2pe);
-            cb_push_back(cb_ex2pe, 1);
-            tile_regs_release();
+        if (enable_sqrt) {
+            for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
+                // 1/[sqrt(Var + eps)]
+                cb_wait_front(cb_ex2, 1);
+                cb_reserve_back(cb_ex2pe, 1);
+                tile_regs_acquire();
+                add_tiles_init();
+                add_tiles(cb_ex2, cb_eps, i, 0, dst0);
+                tile_regs_wait();
+                // sqrt(Var + eps)
+                sqrt_tile_init();
+                sqrt_tile(dst0);
+                tile_regs_wait();
+                // 1/[sqrt(Var + eps)]
+                recip_tile_init();
+                recip_tile(dst0);
+                tile_regs_commit();
+                tile_regs_wait();
+                pack_tile(dst0, cb_ex2pe);
+                cb_push_back(cb_ex2pe, 1);
+                tile_regs_release();
+            }
        }
-        }
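Note on the compute-kernel change above: a first-stage worker reduces num_blocks_first_stage partial tiles (its own row or column), while a second-stage reader additionally folds in the results of the other num_blocks_second_stage - 1 first-stage reducers, and only second-stage readers run the 1/sqrt(Var + eps) epilogue. A standalone sketch of that tile-count arithmetic (plain host C++; all names are local to this example, not kernel API):

    #include <cassert>
    #include <cstdint>

    // Mirrors num_blocks_reduce in the kernel: how many tiles one
    // all-to-all worker consumes in its reduce loop.
    static uint32_t tiles_to_reduce(bool use_two_stage_reduce,
                                    bool is_second_stage_reader,
                                    uint32_t num_blocks_first_stage,
                                    uint32_t num_blocks_second_stage) {
        if (use_two_stage_reduce && is_second_stage_reader) {
            // own row/col partials plus the other first-stage reducers' results
            return num_blocks_first_stage + num_blocks_second_stage - 1;
        }
        return num_blocks_first_stage;
    }

    int main() {
        // e.g. an 8x4 grid reduced row-wise: 8 blocks in stage one, 4 in stage two
        assert(tiles_to_reduce(true, true, 8, 4) == 11);
        assert(tiles_to_reduce(true, false, 8, 4) == 8);
        // single-stage path: num_blocks_first_stage carries the full block count
        assert(tiles_to_reduce(false, false, 32, 0) == 32);
        return 0;
    }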
diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp
index 6a3205345849..a649640a5b76 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_ln.cpp
@@ -6,6 +6,7 @@
 #include "dataflow_api.h"
 #include "hostdevcommon/common_values.hpp"

+
 // split REDUCE across cores
 void kernel_main() {
@@ -20,13 +21,18 @@ void kernel_main() {
    constexpr bool row_major = (bool) get_compile_time_arg_val(8);
    constexpr uint32_t num_x = get_compile_time_arg_val(9);
    constexpr uint32_t num_y = get_compile_time_arg_val(10);
+    constexpr bool use_two_stage_reduce = (bool) get_compile_time_arg_val(11);
+    constexpr uint32_t num_blocks_first_stage = get_compile_time_arg_val(12);
+    constexpr uint32_t num_blocks_second_stage = get_compile_time_arg_val(13);
+    constexpr uint32_t reduce_second_stage_semaphore_addr = get_compile_time_arg_val(14);

    const bool is_last_all_to_all_worker = get_arg_val<uint32_t>(0);
    const uint32_t all_to_all_tile_offset_bytes = get_arg_val<uint32_t>(1);
-    const uint32_t start_x = get_arg_val<uint32_t>(2);
-    const uint32_t start_y = get_arg_val<uint32_t>(3);
-    tt_l1_ptr uint32_t * in0_remote_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(4));
-    tt_l1_ptr uint32_t * in0_remote_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(4 + num_x));
+    const bool is_second_stage_reader = get_arg_val<uint32_t>(2);
+    const uint32_t start_x = get_arg_val<uint32_t>(3);
+    const uint32_t start_y = get_arg_val<uint32_t>(4);
+    volatile tt_l1_ptr uint32_t * in0_remote_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(5));
+    volatile tt_l1_ptr uint32_t * in0_remote_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(5 + num_x));

    const uint32_t num_tiles_to_read = is_last_all_to_all_worker ? num_tiles_per_worker_last : num_tiles_per_worker;
@@ -42,62 +48,97 @@ void kernel_main() {
    const uint32_t single_tile_size_bytes = get_tile_size(cb_ex_partial2); // tile size
    const DataFormat data_format = get_dataformat(cb_ex_partial2); // data format

-    uint64_t remote_noc_addrs[is_all_to_all_worker ? num_blocks : 1];
+    uint64_t remote_noc_addrs_first_stage[is_all_to_all_worker ? num_blocks_first_stage : 1];
+    uint64_t remote_noc_addrs_second_stage[is_all_to_all_worker ? num_blocks_second_stage : 1];
    if constexpr (is_all_to_all_worker) {
-        uint32_t x = start_x, y = start_y;
-        for (uint32_t i = 0; i < num_blocks; ++i) {
-            remote_noc_addrs[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
-            if constexpr(row_major) {
-                ++x;
-                if (x == num_x) {
-                    x = 0;
+        if constexpr (use_two_stage_reduce) {
+            uint32_t x = start_x, y = start_y;
+            for (uint32_t i = 0; i < num_blocks_first_stage; ++i) {
+                remote_noc_addrs_first_stage[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
+                if constexpr(row_major) {
+                    ++x;
+                    if (x == num_x) {
+                        x = 0;
+                    }
+                } else {
+                    ++y;
+                    if (y == num_y) {
+                        y = 0;
+                    }
                }
+            }
+            if constexpr(row_major) {
+                x = start_x;
+                y = 0;
            } else {
-                ++y;
-                if (y == num_y) {
-                    y = 0;
+                x = 0;
+                y = start_y;
+            }
+            for (uint32_t i = 0; i < num_blocks_second_stage; ++i) {
+                remote_noc_addrs_second_stage[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
+                if constexpr(row_major) {
+                    ++y;
+                } else {
+                    ++x;
+                }
+            }
+        } else {
+            uint32_t x = start_x, y = start_y;
+            for (uint32_t i = 0; i < num_blocks; ++i) {
+                remote_noc_addrs_first_stage[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
+                if constexpr(row_major) {
                    ++x;
                    if (x == num_x) {
                        x = 0;
+                        ++y;
+                        if (y == num_y) {
+                            y = 0;
+                        }
+                    }
+                } else {
+                    ++y;
+                    if (y == num_y) {
+                        y = 0;
+                        ++x;
+                        if (x == num_x) {
+                            x = 0;
+                        }
                    }
                }
            }
        }
    } else {
-        remote_noc_addrs[0] = get_noc_addr(in0_remote_noc_x[0], in0_remote_noc_y[0], 0);
+        remote_noc_addrs_first_stage[0] = get_noc_addr(in0_remote_noc_x[0], in0_remote_noc_y[0], 0);
    }

    volatile tt_l1_ptr uint32_t* reduce_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_receiver_semaphore_addr);
    volatile tt_l1_ptr uint32_t* reduce_sender_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_sender_semaphore_addr);
+    volatile tt_l1_ptr uint32_t* reduce_second_stage_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_second_stage_semaphore_addr);

    const uint64_t reduce_receiver_semaphore_noc_addr = get_noc_addr(in0_remote_noc_x[0], in0_remote_noc_y[0], reduce_receiver_semaphore_addr);
+    const uint64_t reduce_second_stage_receiver_semaphore_noc_addr = remote_noc_addrs_second_stage[0] | reduce_second_stage_semaphore_addr;

-    const auto& global_reduce_receiver = [&](const uint32_t cb_partial, const uint32_t cb_external, const uint32_t cb_ex, const uint32_t cb_ex_global) __attribute__((always_inline))
+    const auto& global_reduce_receiver = [&](const uint32_t cb_partial, const uint32_t cb_external, const uint32_t cb_ex, const uint32_t cb_ex_global, const uint32_t cb_reduce_first_stage) __attribute__((always_inline))
    {
        // global reduce
        // wait for local data ready
        cb_wait_front(cb_partial, block_h);

-        // inc top core
+        // inc mcast sender
        noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID);
        noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
        noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, VALID);

        if constexpr (is_all_to_all_worker) {
-            // read data from other cores
+            // read data from other cores - first stage reduce
            uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
            l1_read_addr_ex_par += all_to_all_tile_offset_bytes;
            for (uint32_t i = 0; i < num_tiles_to_read; i++) {
-                uint32_t l1_write_addr_external = get_write_ptr(cb_external);
-                for(uint32_t block = 0; block < num_blocks; block++) {
+                for(uint32_t block = 0; block < num_blocks_first_stage; block++) {
                    cb_reserve_back(cb_external, 1);
-                    uint64_t noc_addr_ex_par = remote_noc_addrs[block] | l1_read_addr_ex_par;
+                    uint32_t l1_write_addr_external = get_write_ptr(cb_external);
+                    uint64_t noc_addr_ex_par = remote_noc_addrs_first_stage[block] | l1_read_addr_ex_par;
                    noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);
-                    l1_write_addr_external += single_tile_size_bytes;
                    noc_async_read_barrier();
                    cb_push_back(cb_external, 1);
@@ -105,26 +146,60 @@ void kernel_main() {
                l1_read_addr_ex_par += single_tile_size_bytes;
            }

-            // send result to other cores
-            cb_wait_front(cb_ex, num_tiles_to_read);
-        }
-
-        // sync with other workers
-        noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID);
-        noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
-        noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, VALID);
-        noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID);
+            // second stage reduce, or hand the first-stage result to the gather worker
+            if constexpr(use_two_stage_reduce) {
+                if (is_second_stage_reader) { // gather data from a column of cores (if row major)
+                    noc_semaphore_wait(reduce_second_stage_semaphore_addr_ptr, num_blocks_second_stage-1);
+                    noc_semaphore_set(reduce_second_stage_semaphore_addr_ptr, 0);
+
+                    uint32_t block_index_stride;
+                    if constexpr(row_major) {
+                        block_index_stride = num_x;
+                    } else {
+                        block_index_stride = num_y;
+                    }
+
+                    // read data from other cores - second stage reduce
+                    uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
+                    for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
+                        for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
+                            cb_reserve_back(cb_external, 1);
+                            uint32_t l1_write_addr_external = get_write_ptr(cb_external);
+                            uint64_t noc_addr_ex = remote_noc_addrs_second_stage[block + 1] | l1_read_addr_ex;
+                            noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
+                            noc_async_read_barrier();
+                            cb_push_back(cb_external, 1);
+                        }
+                        l1_read_addr_ex += single_tile_size_bytes;
+                    }
+
+                    // sync with the mcast sender
+                    cb_wait_front(cb_ex, num_tiles_to_read);
+                    noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
+                } else {
+                    // sync with the gather worker
+                    cb_wait_front(cb_reduce_first_stage, num_tiles_to_read);
+                    noc_semaphore_inc(reduce_second_stage_receiver_semaphore_noc_addr, 1);
+                }
+            } else {
+                // send result to other cores
+                cb_wait_front(cb_ex, num_tiles_to_read);
+                noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
+            }
+        }

        for (uint32_t block = 0; block < num_all_to_all_workers; ++block) {
            uint32_t num_tiles = block == num_all_to_all_workers - 1 ? num_tiles_per_worker_last : num_tiles_per_worker;
            cb_reserve_back(cb_ex_global, num_tiles);
-            noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, block+1);
+            noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, block+2);
            cb_push_back(cb_ex_global, num_tiles);
        }
    };

 #ifndef RMSNORM
-    global_reduce_receiver(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global);
+    global_reduce_receiver(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global, cb_ex);
 #endif
-    global_reduce_receiver(cb_ex_partial2, cb_ex_external2, cb_ex2pe, cb_ex_global);
+    global_reduce_receiver(cb_ex_partial2, cb_ex_external2, cb_ex2pe, cb_ex_global, cb_ex2);
 }
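Note on the receiver change above: with use_two_stage_reduce, each all-to-all worker still gathers first-stage partials along its own row (or column), and the designated second-stage reader then pulls the first-stage results from the other cores of its column (or row); remote_noc_addrs_second_stage[0] is the reader's own slot and also serves as the target of the second-stage semaphore increment. A standalone sketch of the coordinate walk the new second loop performs (plain C++; grid numbers are made up, and the real kernel builds NOC addresses with get_noc_addr rather than printing pairs):

    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
        const uint32_t num_x = 8, num_y = 4;      // assumed worker grid
        const uint32_t start_x = 3, start_y = 2;  // this worker's position
        const bool row_major = true;

        // second stage: for a row-major layout the partners sit in one column,
        // x fixed at start_x, y walking 0..num_y-1 (col-major swaps the roles)
        std::vector<std::pair<uint32_t, uint32_t>> second_stage;
        uint32_t x = row_major ? start_x : 0;
        uint32_t y = row_major ? 0 : start_y;
        const uint32_t n = row_major ? num_y : num_x;
        for (uint32_t i = 0; i < n; ++i) {
            second_stage.emplace_back(x, y);
            if (row_major) { ++y; } else { ++x; }
        }
        for (auto [cx, cy] : second_stage)
            std::printf("second-stage partner at (%u, %u)\n", cx, cy);
        return 0;
    }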
diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp
index 978f7effbbbc..3e0ad4950d10 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/reader_mcast_sender_unary_sharded_ln.cpp
@@ -6,22 +6,27 @@
 #include "dataflow_api.h"
 #include "hostdevcommon/common_values.hpp"

+
 // split REDUCE across cores
 void kernel_main() {
-    constexpr uint32_t reduce_receiver_semaphore_addr = get_compile_time_arg_val(0);
-    constexpr uint32_t reduce_sender_semaphore_addr = get_compile_time_arg_val(1);
-    constexpr uint32_t num_blocks = get_compile_time_arg_val(2);
-    constexpr uint32_t block_h = get_compile_time_arg_val(3);
-    constexpr uint32_t block_h_size_bytes = get_compile_time_arg_val(4);
-    constexpr uint32_t num_all_to_all_workers = get_compile_time_arg_val(5);
-    constexpr uint32_t num_tiles_per_worker = get_compile_time_arg_val(6);
-    constexpr uint32_t num_tiles_per_worker_bytes = get_compile_time_arg_val(7);
+    constexpr uint32_t reduce_receiver_semaphore_addr     = get_compile_time_arg_val(0);
+    constexpr uint32_t reduce_sender_semaphore_addr       = get_compile_time_arg_val(1);
+    constexpr uint32_t num_blocks                         = get_compile_time_arg_val(2);
+    constexpr uint32_t block_h                            = get_compile_time_arg_val(3);
+    constexpr uint32_t block_h_size_bytes                 = get_compile_time_arg_val(4);
+    constexpr uint32_t num_all_to_all_workers_first_stage = get_compile_time_arg_val(5);
+    constexpr uint32_t num_tiles_per_worker               = get_compile_time_arg_val(6);
+    constexpr uint32_t num_tiles_per_worker_bytes         = get_compile_time_arg_val(7);
    constexpr uint32_t num_tiles_per_worker_last = get_compile_time_arg_val(8);
    constexpr uint32_t num_tiles_per_worker_last_bytes = get_compile_time_arg_val(9);
    constexpr bool row_major = (bool) get_compile_time_arg_val(10);
    constexpr uint32_t num_x = get_compile_time_arg_val(11);
    constexpr uint32_t num_y = get_compile_time_arg_val(12);
+    constexpr bool use_two_stage_reduce = (bool) get_compile_time_arg_val(13);
+    constexpr uint32_t num_blocks_first_stage = get_compile_time_arg_val(14);
+    constexpr uint32_t num_blocks_second_stage = get_compile_time_arg_val(15);
+    constexpr uint32_t reduce_second_stage_semaphore_addr = get_compile_time_arg_val(16);

    const uint32_t mcast_dest_noc_start_x = get_arg_val<uint32_t>(0);
    const uint32_t mcast_dest_noc_start_y = get_arg_val<uint32_t>(1);
@@ -82,13 +87,14 @@ void kernel_main() {
    volatile tt_l1_ptr uint32_t* reduce_sender_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_sender_semaphore_addr);
    volatile tt_l1_ptr uint32_t* reduce_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_receiver_semaphore_addr);
+    volatile tt_l1_ptr uint32_t* reduce_second_stage_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_second_stage_semaphore_addr);

-    const auto& global_reduce_sender = [&](const uint32_t cb_partial, const uint32_t cb_external, const uint32_t cb_ex, const uint32_t cb_ex_global) __attribute__((always_inline))
+    const auto& global_reduce_sender = [&](const uint32_t cb_partial, const uint32_t cb_external, const uint32_t cb_ex, const uint32_t cb_ex_global, const uint32_t cb_reduce_first_stage) __attribute__((always_inline))
    {
        // global reduce
        // wait for local data ready
        cb_wait_front(cb_partial, block_h);

-        // inc semaphore of other cores
+        // inc semaphore of other cores, tell other all-to-all workers to start
        if constexpr(num_blocks > 1) {
            *reduce_sender_semaphore_addr_ptr = VALID;
            noc_semaphore_wait(reduce_receiver_semaphore_addr_ptr, num_blocks-1);
@@ -96,70 +102,92 @@ void kernel_main() {
            noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_semaphore_noc_addr, num_blocks-1);
        }

-        // read data from other cores
+        // read data from other cores - first stage reduce
        uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
        for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
-            uint32_t l1_write_addr_external = get_write_ptr(cb_external);
-            for(uint32_t block = 0; block < num_blocks; ++block) {
+            for(uint32_t block = 0; block < num_blocks_first_stage; ++block) {
                cb_reserve_back(cb_external, 1);
+                uint32_t l1_write_addr_external = get_write_ptr(cb_external);
                uint64_t noc_addr_ex_par = remote_noc_addrs[block] | l1_read_addr_ex_par;
                noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);
-                l1_write_addr_external += single_tile_size_bytes;
                noc_async_read_barrier();
                cb_push_back(cb_external, 1);
            }
            l1_read_addr_ex_par += single_tile_size_bytes;
        }

+        // sync with second-stage all-to-all workers
+        if constexpr(use_two_stage_reduce) {
+            noc_semaphore_wait(reduce_second_stage_semaphore_addr_ptr, num_blocks_second_stage-1);
+            noc_semaphore_set(reduce_second_stage_semaphore_addr_ptr, 0);
+
+            uint32_t block_index_stride;
+            if constexpr(row_major) {
+                block_index_stride = num_x;
+            } else {
+                block_index_stride = num_y;
+            }
+
+            // read data from other cores - second stage reduce
+            uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
+            for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
+                uint32_t curr_block_index = block_index_stride;
+                for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
+                    cb_reserve_back(cb_external, 1);
+                    uint32_t l1_write_addr_external = get_write_ptr(cb_external);
+                    uint64_t noc_addr_ex = remote_noc_addrs[curr_block_index] | l1_read_addr_ex;
+                    noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
+                    curr_block_index += block_index_stride;
+                    noc_async_read_barrier();
+                    cb_push_back(cb_external, 1);
+                }
+                l1_read_addr_ex += single_tile_size_bytes;
+            }
+        }
+
        uint32_t l1_read_addr_ex = get_read_ptr(cb_ex);
        uint32_t l1_write_addr_ex_global = get_write_ptr(cb_ex_global);
        cb_wait_front(cb_ex, num_tiles_per_worker);

-        // sync with other workers
-        if constexpr(num_blocks > 1) {
-            noc_semaphore_wait(reduce_receiver_semaphore_addr_ptr, num_blocks-1);
+        // sync with other all-to-all workers, on the same row
+        if constexpr(num_all_to_all_workers_first_stage > 1) {
+            noc_semaphore_wait(reduce_receiver_semaphore_addr_ptr, num_all_to_all_workers_first_stage-1);
            noc_semaphore_set(reduce_receiver_semaphore_addr_ptr, 0);
-            noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_semaphore_noc_addr, num_blocks-1);
        }

        // gather data to top row
        cb_reserve_back(cb_ex_global, block_h);
-        for (uint32_t block = 0; block < num_all_to_all_workers; ++block) {
+        for (uint32_t block = 0; block < num_all_to_all_workers_first_stage; ++block) {
            uint64_t noc_addr_ex = remote_noc_addrs[block] | l1_read_addr_ex;
-            uint32_t num_tiles = block == num_all_to_all_workers - 1 ? num_tiles_per_worker_last_bytes : num_tiles_per_worker_bytes;
+            uint32_t num_tiles_bytes = block == num_all_to_all_workers_first_stage - 1 ? num_tiles_per_worker_last_bytes : num_tiles_per_worker_bytes;
            if constexpr (num_tiles_per_worker_bytes <= NOC_MAX_BURST_SIZE)
-                noc_async_read_one_packet(noc_addr_ex, l1_write_addr_ex_global, num_tiles);
+                noc_async_read_one_packet(noc_addr_ex, l1_write_addr_ex_global, num_tiles_bytes);
            else
-                noc_async_read(noc_addr_ex, l1_write_addr_ex_global, num_tiles);
-            l1_write_addr_ex_global += num_tiles;
+                noc_async_read(noc_addr_ex, l1_write_addr_ex_global, num_tiles_bytes);
+            l1_write_addr_ex_global += num_tiles_bytes;
        }
        noc_async_read_barrier();

        // mcast
        uint32_t l1_read_addr_ex_global = get_read_ptr(cb_ex_global);
-        if constexpr(num_blocks > 1) {
-            for (uint32_t block = 0; block < num_all_to_all_workers; ++block) {
-                *reduce_sender_semaphore_addr_ptr = block + 1;
-
-                uint32_t num_tiles = block == num_all_to_all_workers - 1 ? num_tiles_per_worker_last_bytes : num_tiles_per_worker_bytes;
-
-                noc_async_write_multicast(l1_read_addr_ex_global, multicast_data_noc | l1_read_addr_ex_global, num_tiles, num_blocks-1, true);
-
-                if (block == num_all_to_all_workers-1)
-                    noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_semaphore_noc_addr, num_blocks-1, false);
-                else
-                    noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_semaphore_noc_addr, num_blocks-1, true);
-
-                l1_read_addr_ex_global += num_tiles;
-                noc_async_write_barrier();
-            }
+        for (uint32_t block = 0; block < num_all_to_all_workers_first_stage; ++block) {
+            *reduce_sender_semaphore_addr_ptr = block + 2;
+
+            uint32_t num_tiles_bytes = block == num_all_to_all_workers_first_stage - 1 ? num_tiles_per_worker_last_bytes : num_tiles_per_worker_bytes;
+
+            noc_async_write_multicast(l1_read_addr_ex_global, multicast_data_noc | l1_read_addr_ex_global, num_tiles_bytes, num_blocks-1, false, false);
+            noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_semaphore_noc_addr, num_blocks-1, false, false);
+
+            l1_read_addr_ex_global += num_tiles_bytes;
+            noc_async_write_barrier();
        }
        cb_push_back(cb_ex_global, block_h);
    };

 #ifndef RMSNORM
-    global_reduce_sender(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global);
+    global_reduce_sender(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global, cb_ex);
 #endif
-    global_reduce_sender(cb_ex_partial2, cb_ex_external2, cb_ex2pe, cb_ex_global);
+    global_reduce_sender(cb_ex_partial2, cb_ex_external2, cb_ex2pe, cb_ex_global, cb_ex2);
 }
diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/writer_unary_sharded_ln_rm_gb.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/writer_unary_sharded_ln_rm_gb.cpp
index 1f91ac559ec8..a5572bde1908 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/writer_unary_sharded_ln_rm_gb.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/writer_unary_sharded_ln_rm_gb.cpp
@@ -8,7 +8,6 @@
 #include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"
 #include "tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp"

-// #include "debug/dprint.h"

 void kernel_main() {
    constexpr bool is_all_to_all_worker = get_compile_time_arg_val(0) == 1;
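Note on the semaphore protocol in the two reader kernels: the sender multicasts VALID (1) as the start signal, and this diff removes the receivers' per-round reset of that semaphore back to INVALID; data chunks are therefore announced with values 2, 3, ... (block + 2), which is why the receiver's wait changed from block+1 to block+2. A minimal single-process model of that value sequence (plain C++, no NOC involved; this is one reading of the change, not kernel code):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t INVALID = 0, VALID = 1;
        uint32_t sender_semaphore = INVALID;   // stands in for the L1 semaphore

        sender_semaphore = VALID;              // sender: tell workers to start
        assert(sender_semaphore == VALID);     // receiver: noc_semaphore_wait(..., VALID)

        const uint32_t num_all_to_all_workers_first_stage = 4;
        for (uint32_t block = 0; block < num_all_to_all_workers_first_stage; ++block) {
            sender_semaphore = block + 2;          // sender: block's tiles multicast
            assert(sender_semaphore == block + 2); // receiver: wait for block + 2
        }
        return 0;
    }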
diff --git a/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
index 7666984fcad5..53a256c14224 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
@@ -485,6 +485,16 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
        num_blocks = grid_size.y;
    }

+    // two-stage reduce
+    bool use_two_stage_reduce = false;
+    if (mcast_1d) {
+        // only do this when the row/col dim spans the full grid
+        if (row_wise && grid_size.x == device->compute_with_storage_grid_size().x && grid_size.y > 1) { // row major and multiple rows
+            use_two_stage_reduce = true;
+        } else if (!row_wise && grid_size.x > 1 && grid_size.y == device->compute_with_storage_grid_size().y) { // col major and multiple cols
+            use_two_stage_reduce = true;
+        }
+    }
    uint32_t num_subblocks_w = block_wt / subblock_wt;

    // get sharded addr
@@ -511,6 +521,13 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
    ////////////////////////////////////////////////////////////////////////////
    // block size for in0 (tensor a)
    uint32_t num_rows_per_all_to_all_worker = div_up(block_ht, num_blocks);
+    if (use_two_stage_reduce) {
+        if (row_wise) {
+            num_rows_per_all_to_all_worker = div_up(block_ht, grid_size.x);
+        } else {
+            num_rows_per_all_to_all_worker = div_up(block_ht, grid_size.y);
+        }
+    }
    uint32_t num_rows_per_all_to_all_worker_last = block_ht - (block_ht / num_rows_per_all_to_all_worker) * num_rows_per_all_to_all_worker;
    uint32_t in0_block_tiles = block_wt * block_ht;
    uint32_t in0_CB_tiles = in0_block_tiles;
@@ -548,6 +565,22 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
    uint32_t num_cores_y = grid_size.y;
    uint32_t num_cores = num_cores_x * num_cores_y;
    uint32_t num_cores_all_to_all = div_up(block_ht, num_rows_per_all_to_all_worker);
+    uint32_t num_cores_all_to_all_first_stage = num_cores_all_to_all;
+    uint32_t num_cores_all_to_all_second_stage = 0;
+    uint32_t num_blocks_first_stage = num_blocks;
+    uint32_t num_blocks_second_stage = 0;
+    if (use_two_stage_reduce) {
+        if (row_wise) {
+            num_blocks_first_stage = num_cores_x;
+            num_cores_all_to_all_second_stage = num_cores_y;
+            num_cores_all_to_all *= num_cores_y;
+        } else {
+            num_blocks_first_stage = num_cores_y;
+            num_cores_all_to_all_second_stage = num_cores_x;
+            num_cores_all_to_all *= num_cores_x;
+        }
+        num_blocks_second_stage = num_cores_all_to_all_second_stage;
+    }
    uint32_t num_none_all_to_all_workers = num_blocks - num_cores_all_to_all;
    if (num_rows_per_all_to_all_worker_last == 0)
        num_rows_per_all_to_all_worker_last = num_rows_per_all_to_all_worker;
@@ -561,48 +594,94 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
    uint32_t num_cores_x_mcast, num_cores_y_mcast;
    if (mcast_1d) {
        sender_cores = {start_core, start_core};
-        all_to_all_cores = num_cores_to_corerange_set(start_core, num_cores_all_to_all, grid_size, row_wise);
+        CoreCoord all_core_grid_size;
+        CoreCoord none_core_grid_size;
+        if (use_two_stage_reduce) {
+            if (row_wise) {
+                all_core_grid_size = {num_cores_all_to_all_first_stage, num_cores_y};
+                none_core_grid_size = {num_cores_x - num_cores_all_to_all_first_stage, num_cores_y};
+            } else {
+                all_core_grid_size = {num_cores_x, num_cores_all_to_all_first_stage};
+                none_core_grid_size = {num_cores_x, num_cores_y - num_cores_all_to_all_first_stage};
+            }
+        } else {
+            all_core_grid_size = grid_size;
+            none_core_grid_size = grid_size;
+        }
+        all_to_all_cores = num_cores_to_corerange_set(start_core, num_cores_all_to_all, all_core_grid_size, row_wise);
        if (row_wise) {
            if (use_mcast) {
                CoreCoord all_start_core;
                CoreCoord end_core = sender_cores.end;
-                if (end_core.x == bbox.end.x) {
-                    all_start_core = {0, end_core.y + 1};
+                if (use_two_stage_reduce) {
+                    if (end_core.x == all_core_grid_size.x - 1) {
+                        all_start_core = {0, end_core.y + 1};
+                    } else {
+                        all_start_core = {end_core.x + 1, end_core.y};
+                    }
                } else {
-                    all_start_core = {end_core.x + 1, end_core.y};
+                    if (end_core.x == bbox.end.x) {
+                        all_start_core = {0, end_core.y + 1};
+                    } else {
+                        all_start_core = {end_core.x + 1, end_core.y};
+                    }
                }
-                all_to_all_workers_except_sender = num_cores_to_corerange_set(all_start_core, num_cores_all_to_all - 1, grid_size, row_wise);
+                all_to_all_workers_except_sender = num_cores_to_corerange_set(all_start_core, num_cores_all_to_all - 1, all_core_grid_size, row_wise);
            }
            if (num_none_all_to_all_workers > 0) {
-                CoreCoord none_start_core;
-                CoreCoord end_core = (*all_to_all_cores.ranges().rbegin()).end;
-                if (end_core.x == bbox.end.x) {
-                    none_start_core = {0, end_core.y + 1};
+                if (use_two_stage_reduce) {
+                    CoreCoord none_start_core = {all_core_grid_size.x, sender_cores.end.y};
+                    CoreCoord none_end_core = {num_cores_x - 1, num_cores_y - 1};
+                    CoreRange none_core_range = CoreRange(none_start_core, none_end_core);
+                    std::set<CoreRange> none_core_set; none_core_set.insert(none_core_range);
+                    not_all_to_all_workers = CoreRangeSet(none_core_set);
                } else {
-                    none_start_core = {end_core.x + 1, end_core.y};
+                    CoreCoord none_start_core;
+                    CoreCoord end_core = (*all_to_all_cores.ranges().rbegin()).end;
+                    if (end_core.x == bbox.end.x) {
+                        none_start_core = {0, end_core.y + 1};
+                    } else {
+                        none_start_core = {end_core.x + 1, end_core.y};
+                    }
+                    not_all_to_all_workers = num_cores_to_corerange_set(none_start_core, num_none_all_to_all_workers, none_core_grid_size, row_wise);
                }
-                not_all_to_all_workers = num_cores_to_corerange_set(none_start_core, num_none_all_to_all_workers, grid_size, row_wise);
            }
        } else {
            if (use_mcast) {
                CoreCoord all_start_core;
                CoreCoord end_core = sender_cores.end;
-                if (end_core.y == bbox.end.y) {
-                    all_start_core = {end_core.x + 1, 0};
+                if (use_two_stage_reduce) {
+                    if (end_core.y == all_core_grid_size.y - 1) {
+                        all_start_core = {end_core.x + 1, 0};
+                    } else {
+                        all_start_core = {end_core.x, end_core.y + 1};
+                    }
                } else {
-                    all_start_core = {end_core.x, end_core.y + 1};
+                    if (end_core.y == bbox.end.y) {
+                        all_start_core = {end_core.x + 1, 0};
+                    } else {
+                        all_start_core = {end_core.x, end_core.y + 1};
+                    }
                }
-                all_to_all_workers_except_sender = num_cores_to_corerange_set(CoreCoord{start_core.x, start_core.y + 1}, num_cores_all_to_all - 1, grid_size, row_wise);
+                all_to_all_workers_except_sender = num_cores_to_corerange_set(CoreCoord{start_core.x, start_core.y + 1}, num_cores_all_to_all - 1, all_core_grid_size, row_wise);
            }
            if (num_none_all_to_all_workers > 0) {
-                CoreCoord none_start_core;
-                CoreCoord end_core = (*all_to_all_cores.ranges().rbegin()).end;
-                if (end_core.y == bbox.end.y) {
-                    none_start_core = {end_core.x + 1, 0};
+                if (use_two_stage_reduce) {
+                    CoreCoord none_start_core = {sender_cores.end.x, all_core_grid_size.y};
+                    CoreCoord none_end_core = {num_cores_x - 1, num_cores_y - 1};
+                    CoreRange none_core_range = CoreRange(none_start_core, none_end_core);
+                    std::set<CoreRange> none_core_set; none_core_set.insert(none_core_range);
+                    not_all_to_all_workers = CoreRangeSet(none_core_set);
                } else {
-                    none_start_core = {end_core.x, end_core.y + 1};
+                    CoreCoord none_start_core;
+                    CoreCoord end_core = (*all_to_all_cores.ranges().rbegin()).end;
+                    if (end_core.y == bbox.end.y) {
+                        none_start_core = {end_core.x + 1, 0};
+                    } else {
+                        none_start_core = {end_core.x, end_core.y + 1};
+                    }
+                    not_all_to_all_workers = num_cores_to_corerange_set(none_start_core, num_none_all_to_all_workers, none_core_grid_size, row_wise);
                }
-                not_all_to_all_workers = num_cores_to_corerange_set(none_start_core, num_none_all_to_all_workers, grid_size, row_wise);
            }
        }
        num_cores_x_mcast = num_cores_x;
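Note on the host-side bookkeeping above: with two-stage reduce enabled, the per-worker row count is derived from the full row (or column) length instead of num_blocks, and the all-to-all worker set is replicated across every row (or column). A standalone sketch with illustrative numbers (plain C++; div_up reimplemented locally):

    #include <cstdint>
    #include <cstdio>

    static uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t num_cores_x = 8, num_cores_y = 4;  // assumed full grid, row_wise
        const uint32_t block_ht = 12;                     // shard height in tiles

        // stage one: each row reduces across its own 8 cores...
        uint32_t num_rows_per_worker = div_up(block_ht, num_cores_x);                      // 2
        uint32_t num_cores_all_to_all_first_stage = div_up(block_ht, num_rows_per_worker); // 6
        // ...stage two: one column of cores combines the 4 per-row results
        uint32_t num_blocks_first_stage = num_cores_x;   // 8
        uint32_t num_blocks_second_stage = num_cores_y;  // 4
        uint32_t num_cores_all_to_all = num_cores_all_to_all_first_stage * num_cores_y;    // 24

        std::printf("stage 1: %u blocks, %u a2a cores per row, %u rows each\n",
                    num_blocks_first_stage, num_cores_all_to_all_first_stage, num_rows_per_worker);
        std::printf("stage 2: %u blocks, %u a2a cores in total\n",
                    num_blocks_second_stage, num_cores_all_to_all);
        return 0;
    }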
@@ -651,6 +730,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
    // Mcast args
    auto reduce_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
    auto reduce_receiver_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
+    auto reduce_second_stage_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
    // reader defines
    std::map<string, string> reader_mcast_sender_defines;
    std::map<string, string> reader_mcast_receiver_defines;
@@ -677,7 +757,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
        (std::uint32_t) num_blocks,
        (std::uint32_t) block_ht,
        (std::uint32_t) block_ht * single_tile_size,
-        (std::uint32_t) num_cores_all_to_all,
+        (std::uint32_t) num_cores_all_to_all_first_stage,
        (std::uint32_t) num_rows_per_all_to_all_worker,
        (std::uint32_t) num_rows_per_all_to_all_worker * single_tile_size,
        (std::uint32_t) num_rows_per_all_to_all_worker_last,
@@ -685,6 +765,10 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
        (std::uint32_t) row_wise,
        (std::uint32_t) num_cores_x_mcast,
        (std::uint32_t) num_cores_y_mcast,
+        (std::uint32_t) use_two_stage_reduce,
+        (std::uint32_t) num_blocks_first_stage,
+        (std::uint32_t) num_blocks_second_stage,
+        (std::uint32_t) reduce_second_stage_semaphore
    };
    std::vector<uint32_t> reader_mcast_receiver_all_to_all_compile_time_args = {
        (std::uint32_t) reduce_receiver_semaphore,
@@ -692,12 +776,16 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
        (std::uint32_t) num_blocks,
        (std::uint32_t) block_ht,
        (std::uint32_t) 1,
-        (std::uint32_t) num_cores_all_to_all,
+        (std::uint32_t) num_cores_all_to_all_first_stage,
        (std::uint32_t) num_rows_per_all_to_all_worker,
        (std::uint32_t) num_rows_per_all_to_all_worker_last,
        (std::uint32_t) row_wise,
        (std::uint32_t) num_cores_x_mcast,
        (std::uint32_t) num_cores_y_mcast,
+        (std::uint32_t) use_two_stage_reduce,
+        (std::uint32_t) num_blocks_first_stage,
+        (std::uint32_t) num_blocks_second_stage,
+        (std::uint32_t) reduce_second_stage_semaphore
    };
    std::vector<uint32_t> reader_mcast_receiver_compile_time_args = {
        (std::uint32_t) reduce_receiver_semaphore,
@@ -705,12 +793,16 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
        (std::uint32_t) num_blocks,
        (std::uint32_t) block_ht,
        (std::uint32_t) 0,
-        (std::uint32_t) num_cores_all_to_all,
+        (std::uint32_t) num_cores_all_to_all_first_stage,
        (std::uint32_t) num_rows_per_all_to_all_worker,
        (std::uint32_t) num_rows_per_all_to_all_worker_last,
        (std::uint32_t) row_wise,
        (std::uint32_t) 1,
        (std::uint32_t) 1,
+        (std::uint32_t) 0,
+        (std::uint32_t) 0,
+        (std::uint32_t) 0,
+        (std::uint32_t) reduce_second_stage_semaphore
    };

    tt_metal::NOC reader_noc = detail::GetPreferredNOCForDRAMRead(device->arch());
@@ -831,44 +923,33 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
        compute_defines["RMSNORM"] = "1";
    }
    // compute kernel compile time args
-    std::vector<uint32_t> top_row_compute_compile_time_args = {
-        1,
-        gamma.has_value(),
-        beta.has_value(),
-        num_blocks,
-        block_ht,
-        block_wt,
-        subblock_wt,
-        num_subblocks_w,
-        1,
-        block_ht * block_wt,
-        fp32_dest_acc_en
-    };
    std::vector<uint32_t> all_to_all_except_top_compute_compile_time_args = {
        0,
        gamma.has_value(),
        beta.has_value(),
-        num_blocks,
+        num_blocks_first_stage,
        block_ht,
        block_wt,
        subblock_wt,
        num_subblocks_w,
        1,
        block_ht * block_wt,
-        fp32_dest_acc_en
+        fp32_dest_acc_en,
+        num_blocks_second_stage
    };
    std::vector<uint32_t> not_all_to_all_compute_compile_time_args = {
        0,
        gamma.has_value(),
        beta.has_value(),
-        num_blocks,
+        num_blocks_first_stage,
        block_ht,
        block_wt,
        subblock_wt,
        num_subblocks_w,
        0,
        block_ht * block_wt,
-        fp32_dest_acc_en
+        fp32_dest_acc_en,
+        num_blocks_second_stage
    };
    // compute kernel
    KernelHandle compute_kernels_id = -1;
@@ -997,8 +1078,11 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
    writer_kernel_ids.reserve(cores.size());
    float winv = 1.0f / block_w; // bcast-w scaler
    float cinv = 1.0f / num_blocks; // bcast-cores scaler
+    float cinv_one = 1.0f; // bcast-cores scaler for all-to-all cores not on first row/col
    bfloat16 bfloat_cinv_value = bfloat16(cinv);
    uint32_t packed_cinv_value = pack_two_bfloat16_into_uint32({bfloat_cinv_value, bfloat_cinv_value});
+    bfloat16 bfloat_cinv_value_one = bfloat16(cinv_one);
+    uint32_t packed_cinv_value_one = pack_two_bfloat16_into_uint32({bfloat_cinv_value_one, bfloat_cinv_value_one});
    bfloat16 bfloat_winv_value = bfloat16(winv);
    uint32_t packed_winv_value = pack_two_bfloat16_into_uint32({bfloat_winv_value, bfloat_winv_value});
    union { float f; uint32_t u; } e; e.f = eps;
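Note on the scaler setup above: reduce_tile multiplies by a broadcast scaler tile, so the divide by num_blocks happens inside the reduce itself. Under two-stage reduce, intermediate all-to-all workers must emit raw partial sums, hence the identity scaler cinv_one = 1.0f; only the second-stage readers apply 1/num_blocks. A sketch of the value packing (plain C++; the helper below truncates float to bfloat16 and mirrors pack_two_bfloat16_into_uint32 only in spirit, the real helper may round):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static uint16_t to_bfloat16_bits(float f) {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return static_cast<uint16_t>(u >> 16);  // keep sign/exponent/top mantissa bits
    }

    static uint32_t pack_two(float f) {
        const uint16_t b = to_bfloat16_bits(f);
        return (static_cast<uint32_t>(b) << 16) | b;  // same value in both halves
    }

    int main() {
        const uint32_t num_blocks = 32;
        std::printf("cinv     -> 0x%08x\n", pack_two(1.0f / num_blocks)); // final reducers
        std::printf("cinv_one -> 0x%08x\n", pack_two(1.0f));              // intermediate reducers
        return 0;
    }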
@@ -1029,15 +1113,37 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
                width_index = core.y;
            }
        }
-        uint32_t all_to_all_worker_tile_offset_size_bytes = (width_index * num_rows_per_all_to_all_worker) * single_tile_size;
-        uint32_t in1_tile_start_id = (height_index * block_ht * Kt) + (width_index * block_wt);
+
+        uint32_t width_index_two_stage = width_index % num_blocks_first_stage;
+
+        uint32_t all_to_all_worker_tile_offset_size_bytes;
+        if (use_two_stage_reduce) {
+            all_to_all_worker_tile_offset_size_bytes = (width_index_two_stage * num_rows_per_all_to_all_worker) * single_tile_size;
+        } else {
+            all_to_all_worker_tile_offset_size_bytes = (width_index * num_rows_per_all_to_all_worker) * single_tile_size;
+        }
        uint32_t gamma_tile_start_id = width_index * block_wt;
        uint32_t beta_tile_start_id = width_index * block_wt;

-        if (width_index < num_cores_all_to_all) {
+        if ((not use_two_stage_reduce and width_index < num_cores_all_to_all) or
+            (use_two_stage_reduce and width_index_two_stage < num_cores_all_to_all_first_stage))
+        {
            std::vector<uint32_t> compute_args;
-            uint32_t num_rows = width_index == num_cores_all_to_all - 1 ? num_rows_per_all_to_all_worker_last : num_rows_per_all_to_all_worker;
+            uint32_t num_rows;
+            if (use_two_stage_reduce) {
+                num_rows = width_index_two_stage == num_cores_all_to_all_first_stage - 1 ? num_rows_per_all_to_all_worker_last : num_rows_per_all_to_all_worker;
+            } else {
+                num_rows = width_index == num_cores_all_to_all - 1 ? num_rows_per_all_to_all_worker_last : num_rows_per_all_to_all_worker;
+            }
            compute_args.push_back(num_rows);
+            compute_args.push_back((uint32_t)use_two_stage_reduce);
+            bool is_second_stage_reader;
+            if (use_two_stage_reduce) {
+                is_second_stage_reader = width_index < num_cores_all_to_all_first_stage;
+            } else {
+                is_second_stage_reader = false;
+            }
+            compute_args.push_back((uint32_t)is_second_stage_reader);
            tt_metal::SetRuntimeArgs(program, compute_kernels_id_all_to_all, core, compute_args);
        }
@@ -1094,11 +1200,26 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
            }
            tt_metal::SetRuntimeArgs(program, reader_mcast_sender_kernels_id, core, mcast_sender_args);
-        } else if (width_index < num_cores_all_to_all) {
+        } else if ((not use_two_stage_reduce and width_index < num_cores_all_to_all) or
+                   (use_two_stage_reduce and width_index_two_stage < num_cores_all_to_all_first_stage))
+        {
            std::vector<uint32_t> mcast_receiver_args;
-            bool is_last_all_to_all_worker = width_index == num_cores_all_to_all - 1 ? true : false;
+            bool is_last_all_to_all_worker;
+            if (use_two_stage_reduce) {
+                is_last_all_to_all_worker = width_index_two_stage == num_cores_all_to_all_first_stage - 1 ? true : false;
+            } else {
+                is_last_all_to_all_worker = width_index == num_cores_all_to_all - 1 ? true : false;
+            }
            mcast_receiver_args.push_back(is_last_all_to_all_worker);
            mcast_receiver_args.push_back(all_to_all_worker_tile_offset_size_bytes);
+            bool is_second_stage_reader;
+            if (use_two_stage_reduce and width_index < num_cores_all_to_all_first_stage) {
+                is_second_stage_reader = true;
+                mcast_receiver_args.push_back((uint32_t)is_second_stage_reader);
+            } else {
+                is_second_stage_reader = false;
+                mcast_receiver_args.push_back((uint32_t)is_second_stage_reader);
+            }
            if (mcast_1d) {
                mcast_receiver_args.push_back(core.x - start_core.x);
                mcast_receiver_args.push_back(core.y - start_core.y);
@@ -1125,6 +1246,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
            mcast_receiver_args.push_back(all_to_all_worker_tile_offset_size_bytes);
            mcast_receiver_args.push_back(0);
            mcast_receiver_args.push_back(0);
+            mcast_receiver_args.push_back(0);
            if (mcast_1d) {
                mcast_receiver_args.push_back(in0_mcast_noc_x[0]);
                mcast_receiver_args.push_back(in0_mcast_noc_y[0]);
@@ -1140,10 +1262,22 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
            tt_metal::SetRuntimeArgs(program, reader_mcast_receiver_kernels_id, core, mcast_receiver_args);
        }

-        if (width_index < num_cores_all_to_all) { // all to all workers
+        if ((not use_two_stage_reduce and width_index < num_cores_all_to_all) or
+            (use_two_stage_reduce and width_index_two_stage < num_cores_all_to_all_first_stage))
+        {
            std::vector<uint32_t> writer_mcast_sender_args;
-            writer_mcast_sender_args.push_back(packed_cinv_value);
-            writer_mcast_sender_args.push_back(packed_winv_value);
+            if (use_two_stage_reduce) {
+                if (width_index < num_cores_all_to_all_first_stage) {
+                    writer_mcast_sender_args.push_back(packed_cinv_value);
+                    writer_mcast_sender_args.push_back(packed_winv_value);
+                } else {
+                    writer_mcast_sender_args.push_back(packed_cinv_value_one);
+                    writer_mcast_sender_args.push_back(packed_winv_value);
+                }
+            } else {
+                writer_mcast_sender_args.push_back(packed_cinv_value);
+                writer_mcast_sender_args.push_back(packed_winv_value);
+            }
            writer_mcast_sender_args.push_back(e.u);
            writer_mcast_sender_args.push_back(gamma_dram_addr);
            writer_mcast_sender_args.push_back(beta_dram_addr);
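Note on the runtime-argument layout: the diff shifts the receiver's argument vector by one, inserting is_second_stage_reader at index 2, which matches the kernel now reading start_x/start_y at indices 3/4 and the NOC coordinate lists at get_arg_addr(5) and get_arg_addr(5 + num_x). A standalone sketch of that layout (plain C++; make_receiver_args is a hypothetical helper, not part of the op):

    #include <cstdint>
    #include <vector>

    std::vector<uint32_t> make_receiver_args(bool is_last_all_to_all_worker,
                                             uint32_t tile_offset_bytes,
                                             bool is_second_stage_reader,
                                             uint32_t start_x, uint32_t start_y,
                                             const std::vector<uint32_t>& noc_x,
                                             const std::vector<uint32_t>& noc_y) {
        std::vector<uint32_t> args = {is_last_all_to_all_worker, tile_offset_bytes,
                                      is_second_stage_reader, start_x, start_y};
        args.insert(args.end(), noc_x.begin(), noc_x.end());  // kernel: get_arg_addr(5)
        args.insert(args.end(), noc_y.begin(), noc_y.end());  // kernel: get_arg_addr(5 + num_x)
        return args;
    }

    int main() {
        auto args = make_receiver_args(false, 0x1000, true, 3, 0,
                                       {1, 2, 3, 4}, {1, 2, 3, 4});
        return args.size() == 13 ? 0 : 1;  // 5 scalars + num_x + num_y entries
    }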