Skip to content

Commit

Permalink
#15498: adapted read and compute kernels to support fp32 acc and dst full sync
Browse files Browse the repository at this point in the history
  • Loading branch information
ipotkonjak-tt committed Nov 29, 2024
1 parent 4bb43d2 commit f90e8b5
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@ void MAIN {
uint32_t Ht = get_compile_time_arg_val(0);
uint32_t Wt = get_compile_time_arg_val(1);
uint32_t NC = get_compile_time_arg_val(2);
uint32_t row_chunk = get_compile_time_arg_val(3);

reduce_init<true>(tt::CB::c_in0, tt::CB::c_in2);
cb_wait_front(tt::CB::c_in2, 1); // scaler tile from the reader
cb_wait_front(tt::CB::c_in2, 1); // scaler tile from the reader

constexpr int onetile = 1;
for (uint32_t nc = 0; nc < NC; ++nc) {
uint32_t row_chunk = 8;
for(uint32_t wt = 0; wt < Wt; wt += row_chunk) {
for (uint32_t wt = 0; wt < Wt; wt += row_chunk) {
uint32_t chunk_end = std::min(wt + row_chunk, Wt);
uint32_t tile_num = std::min(row_chunk, Wt - wt);
int reduce_dst_idx = 0;

//reduce a chunk of columns(max 8)
// reduce a chunk of columns (at most row_chunk tiles)
acquire_dst();
for(uint32_t ht = 0; ht < Ht; ++ht) {
for (uint32_t ht = 0; ht < Ht; ++ht) {
reduce_dst_idx = 0;
for(uint32_t i = wt; i < chunk_end; ++i) {
for (uint32_t i = wt; i < chunk_end; ++i) {
cb_wait_front(tt::CB::c_in0, onetile);
reduce_tile(tt::CB::c_in0, tt::CB::c_in2, 0, 0, reduce_dst_idx);
cb_pop_front(tt::CB::c_in0, onetile);
++reduce_dst_idx;
}
}
for(uint32_t i = 0; i < tile_num; i++) {
for (uint32_t i = 0; i < tile_num; i++) {
cb_reserve_back(tt::CB::c_out0, onetile);
pack_tile(i, tt::CB::c_out0);
cb_push_back(tt::CB::c_out0, onetile);
Expand All @@ -44,4 +44,4 @@ void MAIN {
}
}
}
}
} // namespace NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ void kernel_main() {
constexpr uint32_t Ht = get_compile_time_arg_val(1);
constexpr uint32_t Wt = get_compile_time_arg_val(2);
constexpr uint32_t HtWt = get_compile_time_arg_val(3);
constexpr uint32_t row_chunk = get_compile_time_arg_val(4);

constexpr uint32_t cb_id_in0 = 0;

Expand All @@ -27,7 +28,7 @@ void kernel_main() {

#ifdef REDUCE_SCALER
constexpr uint32_t cb_id_in2 = 2;
constexpr uint32_t scalar = get_compile_time_arg_val(4);
constexpr uint32_t scalar = get_compile_time_arg_val(5);
generate_reduce_scaler(cb_id_in2, scalar);
#endif

Expand All @@ -36,7 +37,6 @@ void kernel_main() {

uint32_t w = curr_col_in_batch;

uint32_t row_chunk = 8;
for (uint32_t i = 0; i < num_cols; i += row_chunk) {
uint32_t chunk_end = std::min(i + row_chunk, num_cols);
uint32_t curr_id = col_start_tile_id;
Expand All @@ -60,16 +60,15 @@ void kernel_main() {
++w;

if (w == Wt) {
col_start_tile_id = curr_id + (Ht - j - 1) * Wt + 1;
curr_id = col_start_tile_id + j*Wt;
col_start_tile_id = curr_id + (Ht - j - 1) * Wt + 1;
curr_id = col_start_tile_id + j * Wt;
w = 0;
}
else {
} else {
++curr_id;
++col_start_tile_id;
}
}
curr_id = reset_curr_id + (j+1) * Wt; // stride in H
curr_id = reset_curr_id + (j + 1) * Wt; // stride in H
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ operation::ProgramWithCallbacks reduce_multi_core_h(
tt_metal::KernelHandle reader_kernel_id;
bfloat16 bfloat_scaler_value = bfloat16(scaler);
uint32_t packed_scaler_value = pack_two_bfloat16_into_uint32({bfloat_scaler_value, bfloat_scaler_value});

uint32_t chunk_size =
(fp32_dest_acc_en ? 4 : 8) * (dst_full_sync_en ? 2 : 1); // column chunk size used for interleaved input only

if (in_sharded) {
std::vector<uint32_t> reader_compile_time_args = {src0_cb_index, src1_cb_index, scaler_cb_index};
std::map<string, string> reader_defines;
Expand All @@ -132,7 +136,7 @@ operation::ProgramWithCallbacks reduce_multi_core_h(
} else {
bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> reader_compile_time_args = {
(std::uint32_t)src0_is_dram, Ht, Wt, HtWt, packed_scaler_value};
(std::uint32_t)src0_is_dram, Ht, Wt, HtWt, chunk_size, packed_scaler_value};

std::map<string, string> reader_defines;
reader_defines["REDUCE_SCALER"] = "1";
Expand Down Expand Up @@ -168,15 +172,26 @@ operation::ProgramWithCallbacks reduce_multi_core_h(
tt_metal::WriterDataMovementConfig(writer_compile_time_args));
}
std::map<string, string> reduce_defines = reduce_op_utils::get_defines(reduce_op, ReduceOpDim::H);
std::vector<uint32_t> compute_kernel_args_group_1 = {
Ht, // Ht
num_cols_per_core_group_1, // Wt
1, // NC
};
std::vector<uint32_t> compute_kernel_args_group_1;

std::string compute_kernel_path;
if(out_sharded) compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp";
else compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp";
if (out_sharded) {
compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp";
compute_kernel_args_group_1 = {
Ht, // Ht
num_cols_per_core_group_1, // Wt
1, // NC
};
} else {
compute_kernel_path =
"ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp";
compute_kernel_args_group_1 = {
Ht, // Ht
num_cols_per_core_group_1, // Wt
1, // NC
chunk_size, // Column Chunk Size
};
}

auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel(
program,
Expand All @@ -189,11 +204,21 @@ operation::ProgramWithCallbacks reduce_multi_core_h(
.defines = reduce_defines});

if (!core_group_2.ranges().empty()) {
std::vector<uint32_t> compute_kernel_args_group_2 = {
Ht, // Ht
num_cols_per_core_group_2, // Wt
1, // NC
};
std::vector<uint32_t> compute_kernel_args_group_2;
if (out_sharded) {
compute_kernel_args_group_2 = {
Ht, // Ht
num_cols_per_core_group_2, // Wt
1, // NC
};
} else {
compute_kernel_args_group_2 = {
Ht, // Ht
num_cols_per_core_group_2, // Wt
1, // NC
chunk_size, // Column Chunk Size
};
}

auto reduce_compute_kernel_group_2_id = tt_metal::CreateKernel(
program,
Expand Down

0 comments on commit f90e8b5

Please sign in to comment.