From e30f4503158993ac030865380671a950df35654e Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Tue, 10 Sep 2024 19:36:29 +0530 Subject: [PATCH] Nkpatel/maxpool yolo v4 support (#12066) Add support for larger kernel sizes for max pool. Signed-off-by: Nilaykumar K Patel --- .../unit_tests/operations/test_maxpool2d.py | 172 ++++++++++----- .../max_pool_multi_core_large_kernel.cpp | 206 ++++++++++++++++++ ...core_sharded_with_halo_large_kernel_v2.cpp | 156 +++++++++++++ .../max_pool2d_multi_core_program_factory.cpp | 44 +++- .../sliding_window/sliding_window.cpp | 2 +- 5 files changed, 515 insertions(+), 65 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp create mode 100644 ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index 69264b06ed4..d4328ed3dfd 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -15,66 +15,7 @@ from ttnn.operations.conv2d import determine_parallel_config, create_sharded_memory_config_from_parallel_config -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) -@pytest.mark.parametrize( - "act_shape", ## NCHW - ( - ( ## resnet shapes - [1, 64, 112, 112], - [4, 64, 112, 112], - [8, 64, 112, 112], - [16, 64, 112, 112], - # [20, 64, 112, 112], ## oom - ## hpr shapes - [8, 32, 132, 20], - [16, 32, 132, 20], - [32, 32, 132, 20], - [64, 32, 132, 20], - [128, 32, 132, 20], - # [256, 32, 132, 20], ## oom - [8, 32, 264, 40], - [16, 32, 264, 40], - [32, 32, 264, 40], - # [64, 32, 264, 40], ## oom - # [128, 32, 264, 40], ## oom - # [256, 32, 264, 40], ## oom - [4, 16, 1056, 160], - # [8, 16, 1056, 160], ## oom - # [16, 16, 1056, 160], ## oom - # [32, 16, 1056, 160], ## oom - # [64, 16, 1056, 160], ## oom - # [128, 16, 1056, 160], ## oom - # [256, 16, 1056, 160], ## oom - [8, 16, 528, 80], - [16, 16, 528, 80], - # [32, 16, 528, 80], ## oom - # [64, 16, 528, 80], ## oom - # [128, 16, 528, 80], ## oom - # [256, 16, 528, 80], ## oom - ) - ), -) -@pytest.mark.parametrize( - "kernel_size", - ( - (2, 2), - (3, 3), - ), -) -@pytest.mark.parametrize( - "padding", - ( - (0, 0), - (1, 1), - ), -) -@pytest.mark.parametrize( - "stride", - ((2, 2),), -) -@pytest.mark.parametrize("dilation", ((1, 1),)) ## default -@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) -def test_run_max_pool( +def run_max_pool( act_shape, kernel_size, padding, @@ -210,6 +151,117 @@ def test_run_max_pool( assert isequal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "act_shape", ## NCHW + ( + ( ## resnet shapes + [1, 64, 112, 112], + [4, 64, 112, 112], + [8, 64, 112, 112], + [16, 64, 112, 112], + # [20, 64, 112, 112], ## oom + ## hpr shapes + [8, 32, 132, 20], + [16, 32, 132, 20], + [32, 32, 132, 20], + [64, 32, 132, 20], + [128, 32, 132, 20], + # [256, 32, 132, 20], ## oom + [8, 32, 264, 40], + [16, 32, 264, 40], + [32, 32, 264, 40], + # [64, 32, 264, 40], ## oom + # [128, 32, 264, 40], ## oom + # [256, 32, 264, 40], ## oom + [4, 16, 1056, 160], + # [8, 16, 1056, 160], ## oom + # [16, 16, 1056, 160], ## oom + # [32, 16, 1056, 160], ## oom + # [64, 16, 1056, 160], ## oom + # [128, 16, 1056, 160], ## oom + # [256, 16, 1056, 160], ## oom + [8, 16, 528, 80], + [16, 16, 528, 80], + # [32, 16, 528, 80], ## oom + # [64, 16, 528, 80], ## oom + # [128, 16, 528, 80], ## oom + # [256, 16, 528, 80], ## oom + ) + ), +) +@pytest.mark.parametrize( + "kernel_size", + ( + (2, 2), + (3, 3), + ), +) +@pytest.mark.parametrize( + "padding", + ( + (0, 0), + (1, 1), + ), +) +@pytest.mark.parametrize( + "stride", + ((2, 2),), +) +@pytest.mark.parametrize("dilation", ((1, 1),)) ## default +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) +def test_run_max_pool( + act_shape, + kernel_size, + padding, + stride, + dilation, + device, + dtype, +): + run_max_pool(act_shape, kernel_size, padding, stride, dilation, device, dtype) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "act_shape", ## NCHW + (([1, 512, 10, 10],)), ## yolov4 shapes +) +@pytest.mark.parametrize( + "kernel_size", + ( + (5, 5), + (9, 9), + (13, 13), + # (3, 3), + ), +) +@pytest.mark.parametrize( + "padding", + ( + (2, 2), + (4, 4), + (6, 6), + ), +) +@pytest.mark.parametrize( + "stride", + ((1, 1),), +) +@pytest.mark.parametrize("dilation", ((1, 1),)) ## default +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) +def test_run_max_pool_yolov4( + act_shape, + kernel_size, + padding, + stride, + dilation, + device, + dtype, +): + run_max_pool(act_shape, kernel_size, padding, stride, dilation, device, dtype) + + @pytest.mark.skip("See GH issue #12285") @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) @pytest.mark.parametrize( diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp new file mode 100644 index 00000000000..93b75a1f7c1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +// #include "compute_kernel_api.h" +#include "compute_kernel_api/pack_untilize.h" +#include "compute_kernel_api/reduce.h" +#include "compute_kernel_api/tilize.h" +// #include "tools/profiler/kernel_profiler.hpp" + +#define DEBUG_PRINT 0 + +#if DEBUG_PRINT == 1 +#include "debug/dprint.h" +// #include "debug_macros.h" + +// SliceRange srt = SliceRange{.h0 = 0, .h1 = 32, .hs = 8, .w0 = 0, .w1 = 32, .ws = 4}; +// SliceRange srr = SliceRange{.h0 = 0, .h1 = 1, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1}; +// SliceRange srr1 = SliceRange{.h0 = 1, .h1 = 2, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1}; +// SliceRange src = SliceRange{.h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1}; + +inline void print_tile_rows(uint32_t cb_id, uint32_t rows = 32, uint32_t tile_id = 0, bool untilize = false) { + // UNPACK(( DPRINT << "======" << ENDL() )); + for (uint16_t r = 0; r < rows; ++r) { + SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1}; + // UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() )); + UNPACK((DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize))); + } + // UNPACK(( DPRINT << "++++++" << ENDL() )); +} + +inline void print_full_tile(uint32_t cb_id, uint32_t tile_id = 0, bool untilize = false) { + UNPACK((DPRINT << "======" << ENDL())); + for (uint16_t r = 0; r < 32; ++r) { + SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1}; + UNPACK((DPRINT << (uint)r << " : " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL())); + } + UNPACK((DPRINT << "++++++" << ENDL())); +} + +// inline void print_cb_details(uint32_t cb_id) { +// DPRINT << "cb_id " << cb_id << ": { " +// << "size: " << cb_interface[cb_id].fifo_size << ", " +// << "limit: " << cb_interface[cb_id].fifo_limit << ", " +// << "page_size: " << cb_interface[cb_id].fifo_page_size << ", " +// << "num_pages: " << cb_interface[cb_id].fifo_num_pages << ", " +// << "rd_ptr: " << cb_interface[cb_id].fifo_rd_ptr << ", " +// << "wr_ptr: " << cb_interface[cb_id].fifo_wr_ptr << ", " +// << "wr_tile_ptr: " << cb_interface[cb_id].fifo_wr_tile_ptr << " }" << ENDL(); +// } +#endif + +template < + uint32_t in_ntiles_hw, + uint32_t in_ntiles_c, + uint32_t out_ntiles_c, + uint32_t nblocks, + bool is_partial_tile, + uint32_t split_reader> +inline void reduce_h_fused( + const uint32_t in_cb_id, + const uint32_t in_scalar_cb_id, + const uint32_t num_tiles_for_reduction, + const uint32_t in_stick_index, + const uint32_t out_cb_id, + const uint32_t unpA_face_r_dim) { + constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks; + constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2; + constexpr uint32_t num_out_rows = 1; + for (uint32_t out_elem_i = 0; out_elem_i < nblocks; ++out_elem_i) { + const uint32_t curr_in_cb_id = + split_reader ? (in_cb_id + (in_stick_index * nblocks + out_elem_i) & 0x1) : in_cb_id; + cb_wait_front(curr_in_cb_id, 1); + unpack_tilizeA_B_block( + curr_in_cb_id, + in_scalar_cb_id, + num_tiles_for_reduction, + 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, + num_faces_in_tile /* unpack 1 or 2 faces ) */, + unpA_face_r_dim); + for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) { + reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_tile /* reduce 1 or 2 faces */); + } + cb_pop_front(curr_in_cb_id, 1); + } +} + +namespace NAMESPACE { + +void MAIN { + // NOTE: here it is assumed that in_ntiles_hw == 1. General cases not handled yet. + constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0); + constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1); + constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2); + constexpr uint32_t window_size_hw = get_compile_time_arg_val(3); + constexpr uint32_t out_h = get_compile_time_arg_val(4); + constexpr uint32_t out_w = get_compile_time_arg_val(5); + constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7); + constexpr uint32_t nblocks = get_compile_time_arg_val(8); + + constexpr uint32_t split_reader = get_compile_time_arg_val(12); + + constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(13); + constexpr uint32_t in_c = get_compile_time_arg_val(14); + constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks; + + constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader + constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; + constexpr uint32_t in_tiled_cb_id = tt::CB::c_intermed0; + constexpr uint32_t out_cb_id = tt::CB::c_out0; + constexpr uint32_t interm_reduction_cb_id = tt::CB::c_intermed1; + + constexpr bool is_partial_tile = in_c < 32; + static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16"); + constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2; + constexpr uint32_t num_out_rows = 1; + constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16; + constexpr uint32_t MAX_TILES_PER_REDUCTION = 8; + + constexpr uint32_t num_tiles_for_reduction = + in_ntiles_hwc > MAX_TILES_PER_REDUCTION ? MAX_TILES_PER_REDUCTION : in_ntiles_hwc; + uint32_t num_8_tiles_blocks = 1; + if (num_output_tiles > MAX_TILES_PER_REDUCTION) { + num_8_tiles_blocks = + num_output_tiles / MAX_TILES_PER_REDUCTION; // For now, only pow of 2 number of channels are supported. + } + + tilizeA_B_reduce_init( + in_cb_id, + in_scalar_cb_id, + num_tiles_for_reduction, + interm_reduction_cb_id, + num_faces_in_tile, + MAX_ROWS_FOR_REDUCTION); + + uint32_t interm_reduction_chunks = window_size_hw / MAX_ROWS_FOR_REDUCTION; + cb_wait_front(in_scalar_cb_id, 1); + cb_reserve_back(out_cb_id, 1); + for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; ++i) { + for (uint32_t j = 0; j < num_8_tiles_blocks; j++) { + // NOTE: Assuming in_ntiles_hw < 8 for now. + // TODO: subblocking to support this. + uint32_t out_write_idx = i * num_8_tiles_blocks + j; + + pack_untilize_dst_init_short( + interm_reduction_cb_id, num_out_rows, num_faces_in_tile); + cb_reserve_back(interm_reduction_cb_id, 1); + for (uint32_t h = 0; h <= interm_reduction_chunks; h++) { + tile_regs_acquire(); + + reduce_h_fused( + in_cb_id, + in_scalar_cb_id, + num_tiles_for_reduction, + i, + interm_reduction_cb_id, + MAX_ROWS_FOR_REDUCTION); + tile_regs_commit(); + tile_regs_wait(); + pack_untilize_dst( + interm_reduction_cb_id, + 1 /*out_subblock_h*/, + h, + num_out_rows, + num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */ + tile_regs_release(); + } + cb_push_back(interm_reduction_cb_id, 1); + pack_untilize_uninit(interm_reduction_cb_id); + cb_wait_front(interm_reduction_cb_id, 1); + pack_untilize_dst_init_short( + out_cb_id, num_out_rows, num_faces_in_tile); + + tile_regs_acquire(); + unpack_tilizeA_B_block( + interm_reduction_cb_id, + in_scalar_cb_id, + num_tiles_for_reduction, + 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, + num_faces_in_tile /* unpack 1 or 2 faces ) */, + MAX_ROWS_FOR_REDUCTION); + for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) { + reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */); + } + + tile_regs_commit(); + tile_regs_wait(); + pack_untilize_dst( + out_cb_id, + 1 /*out_subblock_h*/, + out_write_idx, + num_out_rows, + num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */ + tile_regs_release(); + cb_pop_front(interm_reduction_cb_id, 1); + pack_untilize_uninit(out_cb_id); + } + } + // print_full_tile(out_cb_id); + cb_push_back(out_cb_id, 1); + cb_pop_front(in_scalar_cb_id, 1); +} + +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp new file mode 100644 index 00000000000..3a63cc8d04d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp @@ -0,0 +1,156 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include +#include + +#include "dataflow_api.h" + +#define ENABLE_DEBUG_PRINT 0 + +#if ENABLE_DEBUG_PRINT == 1 +#include "debug/dprint.h" + +inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } +} +#endif + +#define ALWI inline __attribute__((always_inline)) + +// Fill an L1 buffer with the given val +// WARNING: Use with caution as there's no memory protection. Make sure size is within limits +ALWI bool fill_with_val(uint32_t begin_addr, uint32_t n, uint16_t val) { + // simplest impl: + volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(begin_addr); + for (uint32_t i = 0; i < n / 2; ++i) { + ptr[i] = (val | (val << 16)); + } + return true; +} + +/** + * Max-pool 2D. + */ +void kernel_main() { + const uint32_t reader_nindices = get_compile_time_arg_val(0); + const uint32_t window_h = get_compile_time_arg_val(1); + const uint32_t window_w = get_compile_time_arg_val(2); + + const int32_t pad_w = get_compile_time_arg_val(3); + + // channel size in bytes, multiple of 32 + const uint32_t in_nbytes_c = get_compile_time_arg_val(4); + const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5); + + // input tensor height / width / channels + const int32_t in_w = get_compile_time_arg_val(6); + const uint32_t in_cb_nsticks = get_compile_time_arg_val(7); + + const uint32_t in_c = get_compile_time_arg_val(8); + const uint32_t nblocks = get_compile_time_arg_val(9); + + const uint32_t split_reader = get_compile_time_arg_val(10); + const uint32_t reader_id = get_compile_time_arg_val(11); + + // compile time args + // value of 1 in bf16 in a uin32_t + constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12); + + // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2"); + + constexpr uint32_t TILE_SIZE = 32 * 32; + constexpr uint32_t MAX_TILES_PER_REDUCTION = 8; + constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16; + constexpr uint32_t MAX_ELE_PER_REDUCTION = 512; + + constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0; + constexpr uint32_t in_shard_cb_id = tt::CB::c_in2; // local input shard + constexpr uint32_t in_reader_indices_cb_id = tt::CB::c_in3; + constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; + constexpr uint32_t interm_reduction_cb_id = tt::CB::c_intermed1; + + constexpr uint32_t ROW_HW = 64; + + // minus infinity for bfp16 + uint16_t minus_inf = 63487; + // Reduce scalar = 1 + if (reader_id == 0) { + cb_reserve_back(in_scalar_cb_id, 1); + + uint32_t bf16_one_u16 = bf16_one_u32 >> 16; + // fill 1 row w/ scalar + fill_with_val(get_write_ptr(in_scalar_cb_id), ROW_HW, bf16_one_u16); + // fill interm buffer with minus_inf + fill_with_val(get_write_ptr(interm_reduction_cb_id), TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf); + cb_push_back(in_scalar_cb_id, 1); + } + + uint32_t in_l1_read_base_addr = get_read_ptr(in_shard_cb_id); + uint32_t reader_indices_l1_addr = get_read_ptr(in_reader_indices_cb_id); + volatile tt_l1_ptr uint16_t* reader_indices_ptr = + reinterpret_cast(reader_indices_l1_addr); + + uint32_t in_w_padded = in_w + 2 * pad_w; + + uint32_t npages_to_reserve = nblocks; + uint32_t num_8_tile_blocks = 1; + uint32_t read_bytes = in_nbytes_c; + if (in_nbytes_c > MAX_ELE_PER_REDUCTION) { + num_8_tile_blocks = in_nbytes_c / MAX_ELE_PER_REDUCTION; + read_bytes = MAX_ELE_PER_REDUCTION; // for now, pow of 2 channels are only supported. + } + uint32_t counter = reader_id; + uint32_t total_elems_to_reduce = window_h * window_w; + uint32_t remaining_elems = total_elems_to_reduce % MAX_ROWS_FOR_REDUCTION; + while (counter < reader_nindices) { + for (uint32_t j = 0; j < num_8_tile_blocks; j++) { + for (uint32_t i = 0; i < nblocks; ++i) { + uint16_t top_left_local_index = reader_indices_ptr[counter]; + uint32_t h_multiples = 0; + uint32_t processed_rows = 0; + uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id); + uint32_t out_l1_write_addr = out_l1_write_addr_base; + cb_reserve_back(in_cb_id, npages_to_reserve); + for (uint32_t h = 0; h < window_h; ++h, h_multiples += in_w_padded) { + uint32_t stick_offset = top_left_local_index + h_multiples; + uint32_t read_offset = + j * MAX_ELE_PER_REDUCTION + in_l1_read_base_addr + (stick_offset << in_nbytes_c_log2); + for (uint32_t w = 0; w < window_w; w++) { + noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, read_bytes); + out_l1_write_addr += read_bytes; + read_offset += in_nbytes_c; + processed_rows++; + if ((processed_rows % MAX_ROWS_FOR_REDUCTION) == 0) { + noc_async_read_barrier(); + cb_push_back(in_cb_id, npages_to_reserve); + out_l1_write_addr_base = get_write_ptr(in_cb_id); + out_l1_write_addr = out_l1_write_addr_base; + cb_reserve_back(in_cb_id, npages_to_reserve); + // If next is last chunk, fill whole buffer with -inf. + if ((total_elems_to_reduce - processed_rows) < MAX_ROWS_FOR_REDUCTION) + fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf); + } + } + } + if (remaining_elems) { + noc_async_read_barrier(); + cb_push_back(in_cb_id, npages_to_reserve); + } + } + } + counter++; + if (split_reader) + counter++; // interleave the indices + } +} // kernel_main() diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp index aa35cf37db4..3c3e895f3f9 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp @@ -63,6 +63,10 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ uint32_t in_ntiles_hw = (uint32_t)std::ceil((float)kernel_size_hw_padded / tt::constants::TILE_HEIGHT); uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / tt::constants::TILE_WIDTH); uint32_t out_ntiles_c = (uint32_t)std::ceil((float)output_shape[3] / tt::constants::TILE_WIDTH); + uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16; + // Hardware can do reduction of 8 tiles at a time. + // CB sizes can be restricted to this in case input channels are more than 256 to perform reduction iteratively. + uint32_t MAX_TILES_PER_REDUCTION = 8; TT_ASSERT(nblocks == 1, "Multiple blocks not yet supported"); @@ -137,11 +141,15 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ .set_globally_allocated_address(*reader_indices_buffer); auto in_reader_indices_cb = tt::tt_metal::CreateCircularBuffer(program, all_cores, in_reader_indices_cb_config); + auto in_cb_sz = + (input_shape[3] * kernel_size_hw_padded) > (tt::constants::TILE_HW * MAX_TILES_PER_REDUCTION) + ? (tt::constants::TILE_HW * MAX_TILES_PER_REDUCTION) + : input_shape[3] * kernel_size_hw_padded; // reader output == input to tilize uint32_t in_cb_id_0 = tt::CB::c_in0; // input rows for "multiple (out_nelems)" output pixels uint32_t in_cb_id_1 = tt::CB::c_in1; // input rows for "multiple (out_nelems)" output pixels uint32_t in_cb_page_padded = ceil_multiple_of( - input_shape[3] * kernel_size_hw_padded, + in_cb_sz, tt::constants::TILE_HW); // NOTE: ceil to tile size since triscs work with tilesize instead of pagesize uint32_t in_cb_pagesize = in_nbytes * in_cb_page_padded; uint32_t in_cb_npages = multi_buffering_factor * nblocks; @@ -168,6 +176,23 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ auto in_tiled_cb = tt::tt_metal::CreateCircularBuffer(program, all_cores, in_tiled_cb_config); log_debug(tt::LogOp, "CB {} :: PS = {}, NP = {}", in_tiled_cb_id, in_tiled_cb_pagesize, in_tiled_cb_npages); + if (kernel_size_hw > MAX_SMALL_KERNEL_SIZE_HW) { + uint32_t max_pool_partials_cb_id = tt::CB::c_intermed1; // max_pool partials + uint32_t max_pool_partials_cb_pagesize = in_cb_sz; + uint32_t max_pool_partials_cb_npages = nblocks; + CircularBufferConfig max_pool_partials_cb_config = + CircularBufferConfig( + max_pool_partials_cb_npages * max_pool_partials_cb_pagesize, {{max_pool_partials_cb_id, in_df}}) + .set_page_size(max_pool_partials_cb_id, max_pool_partials_cb_pagesize); + auto max_pool_partials_cb = tt::tt_metal::CreateCircularBuffer(program, all_cores, max_pool_partials_cb_config); + log_debug( + tt::LogOp, + "CB {} :: PS = {}, NP = {}", + max_pool_partials_cb_id, + max_pool_partials_cb_pagesize, + max_pool_partials_cb_npages); + } + // output of reduce == writer to write uint32_t out_cb_id = tt::CB::c_out0; // output rows in RM // uint32_t out_cb_pagesize = tile_size(out_df); @@ -263,8 +288,14 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ 1, // split reader id bf16_one_u32}; - std::string reader_kernel_fname( - "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp"); + std::string reader_kernel_fname; + if(kernel_size_hw > MAX_SMALL_KERNEL_SIZE_HW) + reader_kernel_fname = + "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp"; + else + reader_kernel_fname = + "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp"; + auto reader0_config = DataMovementConfig{ .processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default, .compile_args = reader0_ct_args}; @@ -305,7 +336,12 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ .math_approx_mode = false, .compile_args = compute_ct_args, .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; - std::string compute_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp"); + std::string compute_kernel_fname; + if(kernel_size_hw > MAX_SMALL_KERNEL_SIZE_HW) + compute_kernel_fname = "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp"; + else + compute_kernel_fname = "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp"; + auto compute_kernel = CreateKernel(program, compute_kernel_fname, core_range, compute_config); return { std::move(program), { diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp index b0fd0abf61c..d4b26c6f1f9 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp @@ -89,7 +89,7 @@ std::vector> generate_shard_boundaries(c uint32_t output_index_start = 0; for (uint32_t core = 0; core < num_cores; ++ core) { uint32_t output_index_end = std::min(output_index_start + output_shard_h, max_index) - 1; - uint32_t input_index_start = op_trace_metadata[output_index_start]; + uint32_t input_index_start = op_trace_metadata[std::min(output_index_start, max_index - 1)]; uint32_t input_index_end = op_trace_metadata[output_index_end] + halo_with_pad_len; if (input_index_start == 0 and output_index_start != 0) { input_index_start = op_trace_metadata[output_index_end] + 1;