Skip to content

Commit

Permalink
#8060: add 2d reduce for ln sharded
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Jun 21, 2024
1 parent 4657af0 commit d23bdc3
Show file tree
Hide file tree
Showing 5 changed files with 405 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void MAIN {
constexpr uint32_t is_top_row = get_compile_time_arg_val(0);
constexpr uint32_t do_gamma = get_compile_time_arg_val(1);
constexpr uint32_t do_beta = get_compile_time_arg_val(2);
constexpr uint32_t num_blocks = get_compile_time_arg_val(3);
constexpr uint32_t num_blocks_first_stage = get_compile_time_arg_val(3);
constexpr uint32_t block_w = get_compile_time_arg_val(5);
constexpr uint32_t block_h_const = get_compile_time_arg_val(4);
volatile uint32_t block_h_volatile = get_compile_time_arg_val(4);
Expand All @@ -33,9 +33,26 @@ void MAIN {
constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(7);
const bool is_allgather_worker = get_compile_time_arg_val(8) == 1;
constexpr uint32_t num_tiles_per_block = get_compile_time_arg_val(9);
constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(10) == 1;
constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(10) == 1;
constexpr uint32_t num_blocks_second_stage = get_compile_time_arg_val(11);

const uint32_t num_tiles_per_allgather_worker = is_allgather_worker ? get_arg_val<uint32_t>(0) : 0;
const uint32_t num_tiles_per_allgather_worker = is_allgather_worker ? get_arg_val<uint32_t>(0) : 0;
const bool use_two_stage_reduce = is_allgather_worker ? get_arg_val<uint32_t>(1) == 1 : false;
const bool is_second_stage_reader = is_allgather_worker ? get_arg_val<uint32_t>(2) == 1 : false;

uint32_t num_blocks_reduce;
if (is_second_stage_reader) {
num_blocks_reduce = num_blocks_first_stage + num_blocks_second_stage - 1;
} else {
num_blocks_reduce = num_blocks_first_stage;
}

bool enable_sqrt;
if (use_two_stage_reduce and not is_second_stage_reader) {
enable_sqrt = false;
} else {
enable_sqrt = true;
}

constexpr uint32_t dst0 = 0;
constexpr uint32_t scaler0 = 0;
Expand Down Expand Up @@ -153,7 +170,7 @@ void MAIN {
for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
cb_wait_front(cb_scaler_global, 1);
tile_regs_acquire();
for (uint32_t w = 0; w < num_blocks; w++) {
for (uint32_t w = 0; w < num_blocks_reduce; w++) {
cb_wait_front(cb_ex_external, 1);
reduce_tile(cb_ex_external, cb_scaler_global, 0, scaler0, dst0);
cb_pop_front(cb_ex_external, 1);
Expand Down Expand Up @@ -238,6 +255,9 @@ void MAIN {
cb_wait_front(cb_xmm2, num_tiles_per_block);

// Var(x)
#ifdef RMSNORM
cb_wait_front(cb_scaler, 1);
#endif
cb_reserve_back(cb_ex_partial2, block_h);
reduce_init_delta<false>(REDUCE_OP, REDUCE_DIM);
index_h_offset = 0;
Expand Down Expand Up @@ -265,7 +285,7 @@ void MAIN {
cb_wait_front(cb_scaler_global, 1);

tile_regs_acquire();
for (uint32_t w = 0; w < num_blocks; w++) {
for (uint32_t w = 0; w < num_blocks_reduce; w++) {
cb_wait_front(cb_ex_external2, 1);
reduce_tile(cb_ex_external2, cb_scaler_global, 0, scaler0, dst0);
cb_pop_front(cb_ex_external2, 1);
Expand All @@ -278,28 +298,29 @@ void MAIN {
reduce_revert_delta();
cb_push_back(cb_ex2, num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
// 1/[sqrt(Var + eps)],
cb_wait_front(cb_ex2, 1);
cb_reserve_back(cb_ex2pe, 1);
tile_regs_acquire();
add_tiles_init();
add_tiles(cb_ex2, cb_eps, i, 0, dst0);
tile_regs_wait();
// sqrt(Var + eps)
sqrt_tile_init();
sqrt_tile(dst0);
tile_regs_wait();
// 1/[sqrt(Var + eps)]
recip_tile_init();
recip_tile(dst0);
tile_regs_commit();
tile_regs_wait();
pack_tile(dst0, cb_ex2pe);
cb_push_back(cb_ex2pe, 1);
tile_regs_release();
if (enable_sqrt) {
for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
// 1/[sqrt(Var + eps)],
cb_wait_front(cb_ex2, 1);
cb_reserve_back(cb_ex2pe, 1);
tile_regs_acquire();
add_tiles_init();
add_tiles(cb_ex2, cb_eps, i, 0, dst0);
tile_regs_wait();
// sqrt(Var + eps)
sqrt_tile_init();
sqrt_tile(dst0);
tile_regs_wait();
// 1/[sqrt(Var + eps)]
recip_tile_init();
recip_tile(dst0);
tile_regs_commit();
tile_regs_wait();
pack_tile(dst0, cb_ex2pe);
cb_push_back(cb_ex2pe, 1);
tile_regs_release();
}
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "dataflow_api.h"
#include "hostdevcommon/common_values.hpp"


// split REDUCE across cores
void kernel_main() {

Expand All @@ -20,13 +21,18 @@ void kernel_main() {
constexpr bool row_major = (bool) get_compile_time_arg_val(8);
constexpr uint32_t num_x = get_compile_time_arg_val(9);
constexpr uint32_t num_y = get_compile_time_arg_val(10);
constexpr bool use_two_stage_reduce = (bool) get_compile_time_arg_val(11);
constexpr uint32_t num_blocks_first_stage = get_compile_time_arg_val(12);
constexpr uint32_t num_blocks_second_stage = get_compile_time_arg_val(13);
constexpr uint32_t reduce_second_stage_semaphore_addr = get_compile_time_arg_val(14);

const bool is_last_all_to_all_worker = get_arg_val<uint32_t>(0);
const uint32_t all_to_all_tile_offset_bytes = get_arg_val<uint32_t>(1);
const uint32_t start_x = get_arg_val<uint32_t>(2);
const uint32_t start_y = get_arg_val<uint32_t>(3);
tt_l1_ptr uint32_t * in0_remote_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(4));
tt_l1_ptr uint32_t * in0_remote_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(4 + num_x));
const bool is_second_stage_reader = get_arg_val<uint32_t>(2);
const uint32_t start_x = get_arg_val<uint32_t>(3);
const uint32_t start_y = get_arg_val<uint32_t>(4);
volatile tt_l1_ptr uint32_t * in0_remote_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(5));
volatile tt_l1_ptr uint32_t * in0_remote_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(5 + num_x));

const uint32_t num_tiles_to_read = is_last_all_to_all_worker ? num_tiles_per_worker_last : num_tiles_per_worker;

Expand All @@ -42,89 +48,158 @@ void kernel_main() {
const uint32_t single_tile_size_bytes = get_tile_size(cb_ex_partial2); // tile size
const DataFormat data_format = get_dataformat(cb_ex_partial2); // data format

uint64_t remote_noc_addrs[is_all_to_all_worker ? num_blocks : 1];
uint64_t remote_noc_addrs_first_stage[is_all_to_all_worker ? num_blocks_first_stage : 1];
uint64_t remote_noc_addrs_second_stage[is_all_to_all_worker ? num_blocks_second_stage : 1];
if constexpr (is_all_to_all_worker) {
uint32_t x = start_x, y = start_y;
for (uint32_t i = 0; i < num_blocks; ++i) {
remote_noc_addrs[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
if constexpr(row_major) {
++x;
if (x == num_x) {
x = 0;
if constexpr (use_two_stage_reduce) {
uint32_t x = start_x, y = start_y;
for (uint32_t i = 0; i < num_blocks_first_stage; ++i) {
remote_noc_addrs_first_stage[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
if constexpr(row_major) {
++x;
if (x == num_x) {
x = 0;
}
} else {
++y;
if (y == num_y) {
y = 0;
}
}
}
if constexpr(row_major) {
x = start_x;
y = 0;
} else {
++y;
if (y == num_y) {
y = 0;
x = 0;
y = start_y;
}
for (uint32_t i = 0; i < num_blocks_second_stage; ++i) {
remote_noc_addrs_second_stage[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
if constexpr(row_major) {
++y;
} else {
++x;
}
}
} else {
uint32_t x = start_x, y = start_y;
for (uint32_t i = 0; i < num_blocks; ++i) {
remote_noc_addrs_first_stage[i] = get_noc_addr(in0_remote_noc_x[x], in0_remote_noc_y[y], 0);
if constexpr(row_major) {
++x;
if (x == num_x) {
x = 0;
++y;
if (y == num_y) {
y = 0;
}
}
} else {
++y;
if (y == num_y) {
y = 0;
++x;
if (x == num_x) {
x = 0;
}
}
}
}
}
} else {
remote_noc_addrs[0] = get_noc_addr(in0_remote_noc_x[0], in0_remote_noc_y[0], 0);
remote_noc_addrs_first_stage[0] = get_noc_addr(in0_remote_noc_x[0], in0_remote_noc_y[0], 0);
}

volatile tt_l1_ptr uint32_t* reduce_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_receiver_semaphore_addr);
volatile tt_l1_ptr uint32_t* reduce_sender_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_sender_semaphore_addr);
volatile tt_l1_ptr uint32_t* reduce_second_stage_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduce_second_stage_semaphore_addr);

const uint64_t reduce_receiver_semaphore_noc_addr = get_noc_addr(in0_remote_noc_x[0], in0_remote_noc_y[0], reduce_receiver_semaphore_addr);
const uint64_t reduce_second_stage_receiver_semaphore_noc_addr = remote_noc_addrs_second_stage[0] | reduce_second_stage_semaphore_addr;

const auto& global_reduce_receiver = [&](const uint32_t cb_partial, const uint32_t cb_external, const uint32_t cb_ex, const uint32_t cb_ex_global) __attribute__((always_inline))
const auto& global_reduce_receiver = [&](const uint32_t cb_partial, const uint32_t cb_external, const uint32_t cb_ex, const uint32_t cb_ex_global, const uint32_t cb_reduce_first_stage) __attribute__((always_inline))
{
// global reduce
// wait for local data ready
cb_wait_front(cb_partial, block_h);

// inc top core
// inc mcast sender
noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID);
noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, VALID);

if constexpr (is_all_to_all_worker) {
// read data from other cores
// read data from other cores - reduce first stage
uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
l1_read_addr_ex_par += all_to_all_tile_offset_bytes;
for (uint32_t i = 0; i < num_tiles_to_read; i++) {
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
for(uint32_t block = 0; block < num_blocks; block++) {
for(uint32_t block = 0; block < num_blocks_first_stage; block++) {
cb_reserve_back(cb_external, 1);
uint64_t noc_addr_ex_par = remote_noc_addrs[block] | l1_read_addr_ex_par;
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
uint64_t noc_addr_ex_par = remote_noc_addrs_first_stage[block] | l1_read_addr_ex_par;
noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);
l1_write_addr_external += single_tile_size_bytes;

noc_async_read_barrier();
cb_push_back(cb_external, 1);
}
l1_read_addr_ex_par += single_tile_size_bytes;
}

// send result to other cores
cb_wait_front(cb_ex, num_tiles_to_read);
}
// read data from other cores - reduce first stage
if constexpr(use_two_stage_reduce) {
if (is_second_stage_reader) { // gather data from a column of cores (if row major)

// sync with other workers
noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID);
noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, VALID);
noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID);
noc_semaphore_wait(reduce_second_stage_semaphore_addr_ptr, num_blocks_second_stage-1);
noc_semaphore_set(reduce_second_stage_semaphore_addr_ptr, 0);

uint32_t block_index_stride;
if constexpr(row_major) {
block_index_stride = num_x;
} else {
block_index_stride = num_y;
}

// read data from other cores - second stage reduce
uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
cb_reserve_back(cb_external, 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
uint64_t noc_addr_ex = remote_noc_addrs_second_stage[block + 1] | l1_read_addr_ex;
noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
noc_async_read_barrier();

cb_push_back(cb_external, 1);
}
l1_read_addr_ex += single_tile_size_bytes;
}

// sync with the mcast sender
cb_wait_front(cb_ex, num_tiles_to_read);
noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
} else {
// sync with the gather worker
cb_wait_front(cb_reduce_first_stage, num_tiles_to_read);
noc_semaphore_inc(reduce_second_stage_receiver_semaphore_noc_addr, 1);
}
} else {
// send result to other cores
cb_wait_front(cb_ex, num_tiles_to_read);
noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1);
}
}

for (uint32_t block = 0; block < num_all_to_all_workers; ++block) {
uint32_t num_tiles = block == num_all_to_all_workers - 1 ? num_tiles_per_worker_last : num_tiles_per_worker;
cb_reserve_back(cb_ex_global, num_tiles);
noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, block+1);
noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, block+2);
cb_push_back(cb_ex_global, num_tiles);
}
};
#ifndef RMSNORM
global_reduce_receiver(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global);
global_reduce_receiver(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global, cb_ex);
#endif
global_reduce_receiver(cb_ex_partial2, cb_ex_external2, cb_ex2pe, cb_ex_global);
global_reduce_receiver(cb_ex_partial2, cb_ex_external2, cb_ex2pe, cb_ex_global, cb_ex2);

}
Loading

0 comments on commit d23bdc3

Please sign in to comment.