Skip to content

Commit

Permalink
#8060: opt ln data movement
Browse files: browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Jun 21, 2024
1 parent d23bdc3 commit 6e88ba4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -134,16 +134,16 @@ void kernel_main() {
uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
l1_read_addr_ex_par += all_to_all_tile_offset_bytes;
for (uint32_t i = 0; i < num_tiles_to_read; i++) {
cb_reserve_back(cb_external, num_blocks_first_stage);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
for(uint32_t block = 0; block < num_blocks_first_stage; block++) {
cb_reserve_back(cb_external, 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
uint64_t noc_addr_ex_par = remote_noc_addrs_first_stage[block] | l1_read_addr_ex_par;
noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);

noc_async_read_barrier();
cb_push_back(cb_external, 1);
l1_write_addr_external+=single_tile_size_bytes;
}
l1_read_addr_ex_par += single_tile_size_bytes;
noc_async_read_barrier();
cb_push_back(cb_external, num_blocks_first_stage);
}

// read data from other cores - reduce first stage
Expand All @@ -163,16 +163,16 @@ void kernel_main() {
// read data from other cores - second stage reduce
uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
cb_reserve_back(cb_external, num_blocks_second_stage - 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
cb_reserve_back(cb_external, 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
uint64_t noc_addr_ex = remote_noc_addrs_second_stage[block + 1] | l1_read_addr_ex;
noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
noc_async_read_barrier();

cb_push_back(cb_external, 1);
l1_write_addr_external += single_tile_size_bytes;
}
l1_read_addr_ex += single_tile_size_bytes;
noc_async_read_barrier();
cb_push_back(cb_external, num_blocks_second_stage - 1);
}

// sync with the mcast sender
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -105,15 +105,16 @@ void kernel_main() {
// read data from other cores - first stage reduce
uint32_t l1_read_addr_ex_par = get_read_ptr(cb_partial);
for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
cb_reserve_back(cb_external, num_blocks_first_stage);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
for(uint32_t block = 0; block < num_blocks_first_stage; ++block) {
cb_reserve_back(cb_external, 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
uint64_t noc_addr_ex_par = remote_noc_addrs[block] | l1_read_addr_ex_par;
noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes);
noc_async_read_barrier();
cb_push_back(cb_external, 1);
l1_write_addr_external += single_tile_size_bytes;
}
l1_read_addr_ex_par += single_tile_size_bytes;
noc_async_read_barrier();
cb_push_back(cb_external, num_blocks_first_stage);
}

// sync with second-stage all-to-all workers
Expand All @@ -133,16 +134,17 @@ void kernel_main() {
uint32_t l1_read_addr_ex = get_read_ptr(cb_reduce_first_stage);
for (uint32_t i = 0; i < num_tiles_per_worker; ++i) {
uint32_t curr_block_index = block_index_stride;
cb_reserve_back(cb_external, num_blocks_second_stage - 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
for(uint32_t block = 0; block < num_blocks_second_stage - 1; ++block) {
cb_reserve_back(cb_external, 1);
uint32_t l1_write_addr_external = get_write_ptr(cb_external);
uint64_t noc_addr_ex = remote_noc_addrs[curr_block_index] | l1_read_addr_ex;
noc_async_read_one_packet(noc_addr_ex, l1_write_addr_external, single_tile_size_bytes);
curr_block_index += block_index_stride;
noc_async_read_barrier();
cb_push_back(cb_external, 1);
l1_write_addr_external += single_tile_size_bytes;
}
l1_read_addr_ex += single_tile_size_bytes;
noc_async_read_barrier();
cb_push_back(cb_external, num_blocks_second_stage - 1);
}
}

Expand Down Expand Up @@ -171,6 +173,7 @@ void kernel_main() {

// mcast
uint32_t l1_read_addr_ex_global = get_read_ptr(cb_ex_global);
cb_push_back(cb_ex_global, block_h);
if constexpr(num_blocks > 1) {
for (uint32_t block = 0; block < num_all_to_all_workers_first_stage; ++block) {
*reduce_sender_semaphore_addr_ptr = block + 2;
Expand All @@ -184,7 +187,6 @@ void kernel_main() {
noc_async_write_barrier();
}
}
cb_push_back(cb_ex_global, block_h);
};
#ifndef RMSNORM
global_reduce_sender(cb_ex_partial, cb_ex_external, cb_ex, cb_ex_global, cb_ex);
Expand Down

0 comments on commit 6e88ba4

Please sign in to comment.