From f90e8b5a86f16a970f9937b12f3582c80e9d3cde Mon Sep 17 00:00:00 2001 From: Iva Potkonjak Date: Fri, 29 Nov 2024 14:17:42 +0000 Subject: [PATCH] #15498: adapted read and compute kernels to support fp32 acc and dst full sync --- .../kernels/compute/reduce_h_interleaved.cpp | 16 +++--- ..._wh_interleaved_input_cols_partitioned.cpp | 13 +++-- .../multi_core_h/reduce_op_multi_core_h.cpp | 51 ++++++++++++++----- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp index 3f8db86a76d..0a3d76e2562 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp @@ -12,30 +12,30 @@ void MAIN { uint32_t Ht = get_compile_time_arg_val(0); uint32_t Wt = get_compile_time_arg_val(1); uint32_t NC = get_compile_time_arg_val(2); + uint32_t row_chunk = get_compile_time_arg_val(3); reduce_init(tt::CB::c_in0, tt::CB::c_in2); - cb_wait_front(tt::CB::c_in2, 1); // scaler tile from the reader + cb_wait_front(tt::CB::c_in2, 1); // scaler tile from the reader constexpr int onetile = 1; for (uint32_t nc = 0; nc < NC; ++nc) { - uint32_t row_chunk = 8; - for(uint32_t wt = 0; wt < Wt; wt += row_chunk) { + for (uint32_t wt = 0; wt < Wt; wt += row_chunk) { uint32_t chunk_end = std::min(wt + row_chunk, Wt); uint32_t tile_num = std::min(row_chunk, Wt - wt); int reduce_dst_idx = 0; - //reduce a chunk of columns(max 8) + // reduce a chunk of columns(max 8) acquire_dst(); - for(uint32_t ht = 0; ht < Ht; ++ht) { + for (uint32_t ht = 0; ht < Ht; ++ht) { reduce_dst_idx = 0; - for(uint32_t i = wt; i < chunk_end; ++i) { + for (uint32_t i = wt; i < chunk_end; ++i) { cb_wait_front(tt::CB::c_in0, onetile); reduce_tile(tt::CB::c_in0, tt::CB::c_in2, 0, 0, reduce_dst_idx); cb_pop_front(tt::CB::c_in0, onetile); ++reduce_dst_idx; } } - for(uint32_t i = 0; i < tile_num; i++) { + for (uint32_t i = 0; i < tile_num; i++) { cb_reserve_back(tt::CB::c_out0, onetile); pack_tile(i, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, onetile); @@ -44,4 +44,4 @@ void MAIN { } } } -} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/dataflow/reader_unary_transpose_wh_interleaved_input_cols_partitioned.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/dataflow/reader_unary_transpose_wh_interleaved_input_cols_partitioned.cpp index 7a7529f1c9c..f19dd086643 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/dataflow/reader_unary_transpose_wh_interleaved_input_cols_partitioned.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/dataflow/reader_unary_transpose_wh_interleaved_input_cols_partitioned.cpp @@ -17,6 +17,7 @@ void kernel_main() { constexpr uint32_t Ht = get_compile_time_arg_val(1); constexpr uint32_t Wt = get_compile_time_arg_val(2); constexpr uint32_t HtWt = get_compile_time_arg_val(3); + constexpr uint32_t row_chunk = get_compile_time_arg_val(4); constexpr uint32_t cb_id_in0 = 0; @@ -27,7 +28,7 @@ void kernel_main() { #ifdef REDUCE_SCALER constexpr uint32_t cb_id_in2 = 2; - constexpr uint32_t scalar = get_compile_time_arg_val(4); + constexpr uint32_t scalar = get_compile_time_arg_val(5); generate_reduce_scaler(cb_id_in2, scalar); #endif @@ -36,7 +37,6 @@ void kernel_main() { uint32_t w = curr_col_in_batch; - uint32_t row_chunk = 8; for (uint32_t i = 0; i < num_cols; i += row_chunk) { uint32_t chunk_end = std::min(i + row_chunk, num_cols); uint32_t curr_id = col_start_tile_id; @@ -60,16 +60,15 @@ void kernel_main() { ++w; if (w == Wt) { - col_start_tile_id = curr_id + (Ht - j - 1) * Wt + 1; - curr_id = col_start_tile_id + j*Wt; + col_start_tile_id = curr_id + (Ht - j - 1) * Wt + 1; + curr_id = col_start_tile_id + j * Wt; w = 0; - } - else { + } else { ++curr_id; ++col_start_tile_id; } } - curr_id = reset_curr_id + (j+1) * Wt; // stride in H + curr_id = reset_curr_id + (j + 1) * Wt; // stride in H } } } diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp index 0e3d832008c..a995e0819f5 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp @@ -119,6 +119,10 @@ operation::ProgramWithCallbacks reduce_multi_core_h( tt_metal::KernelHandle reader_kernel_id; bfloat16 bfloat_scaler_value = bfloat16(scaler); uint32_t packed_scaler_value = pack_two_bfloat16_into_uint32({bfloat_scaler_value, bfloat_scaler_value}); + + uint32_t chunk_size = + (fp32_dest_acc_en ? 4 : 8) * (dst_full_sync_en ? 2 : 1); // column chunk size used for interleaved input only + if (in_sharded) { std::vector reader_compile_time_args = {src0_cb_index, src1_cb_index, scaler_cb_index}; std::map reader_defines; @@ -132,7 +136,7 @@ operation::ProgramWithCallbacks reduce_multi_core_h( } else { bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = { - (std::uint32_t)src0_is_dram, Ht, Wt, HtWt, packed_scaler_value}; + (std::uint32_t)src0_is_dram, Ht, Wt, HtWt, chunk_size, packed_scaler_value}; std::map reader_defines; reader_defines["REDUCE_SCALER"] = "1"; @@ -168,15 +172,26 @@ operation::ProgramWithCallbacks reduce_multi_core_h( tt_metal::WriterDataMovementConfig(writer_compile_time_args)); } std::map reduce_defines = reduce_op_utils::get_defines(reduce_op, ReduceOpDim::H); - std::vector compute_kernel_args_group_1 = { - Ht, // Ht - num_cols_per_core_group_1, // Wt - 1, // NC - }; + std::vector compute_kernel_args_group_1; std::string compute_kernel_path; - if(out_sharded) compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp"; - else compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp"; + if (out_sharded) { + compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp"; + compute_kernel_args_group_1 = { + Ht, // Ht + num_cols_per_core_group_1, // Wt + 1, // NC + }; + } else { + compute_kernel_path = + "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp"; + compute_kernel_args_group_1 = { + Ht, // Ht + num_cols_per_core_group_1, // Wt + 1, // NC + chunk_size, // Column Chunk Size + }; + } auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel( program, @@ -189,11 +204,21 @@ operation::ProgramWithCallbacks reduce_multi_core_h( .defines = reduce_defines}); if (!core_group_2.ranges().empty()) { - std::vector compute_kernel_args_group_2 = { - Ht, // Ht - num_cols_per_core_group_2, // Wt - 1, // NC - }; + std::vector compute_kernel_args_group_2; + if (out_sharded) { + compute_kernel_args_group_2 = { + Ht, // Ht + num_cols_per_core_group_2, // Wt + 1, // NC + }; + } else { + compute_kernel_args_group_2 = { + Ht, // Ht + num_cols_per_core_group_2, // Wt + 1, // NC + chunk_size, // Column Chunk Size + }; + } auto reduce_compute_kernel_group_2_id = tt_metal::CreateKernel( program,