diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index 43fa209acb0..6dab6291762 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -269,6 +269,8 @@ def run_max_pool( [1, 512, 14, 14], # wide yolo kernel [1, 512, 10, 10], + [1, 96, 112, 112], + [1, 192, 132, 20], ) ), ) @@ -335,6 +337,7 @@ def test_run_max_pool( [8, 4096, 10, 16], # wide yolo kernel [1, 32768, 10, 10], + [1, 6144, 6, 6], ) ), ) @@ -429,6 +432,8 @@ def test_run_max_pool_width_shard( [16, 16, 528, 80], # wide yolo kernel [1, 4096, 10, 10], + [1, 768, 56, 56], + [1, 1280, 8, 6], ) ), ) diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp index dae7348ab1d..be922e6da3c 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp @@ -49,24 +49,24 @@ void kernel_main() { const int32_t pad_w = get_compile_time_arg_val(3); - // channel size in bytes, multiple of 32 + // channel size in bytes const uint32_t in_nbytes_c = get_compile_time_arg_val(4); // input tensor height / width / channels - const int32_t in_w = get_compile_time_arg_val(6); - const uint32_t in_cb_nsticks = get_compile_time_arg_val(7); + const int32_t in_w = get_compile_time_arg_val(5); + const uint32_t in_cb_nsticks = get_compile_time_arg_val(6); - const uint32_t in_c = get_compile_time_arg_val(8); + const uint32_t in_c = get_compile_time_arg_val(7); - const uint32_t split_reader = get_compile_time_arg_val(10); - const uint32_t reader_id = get_compile_time_arg_val(11); + const uint32_t split_reader = get_compile_time_arg_val(9); + const uint32_t reader_id = get_compile_time_arg_val(10); // compile time args // value of 1 in bf16 in a uin32_t - constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12); - constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13); - constexpr uint32_t in_cb_sz = get_compile_time_arg_val(14); - constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(15); + constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(11); + constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(12); + constexpr uint32_t in_cb_sz = get_compile_time_arg_val(13); + constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(14); constexpr uint32_t TILE_SIZE = 32 * 32; constexpr uint32_t MAX_TILES_PER_REDUCTION = 8; diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp index 9f692303666..a313d7cf73d 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp @@ -46,23 +46,22 @@ void kernel_main() { const int32_t pad_w = get_compile_time_arg_val(3); - // channel size in bytes, multiple of 32 + // channel size in bytes const uint32_t in_nbytes_c = get_compile_time_arg_val(4); - const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5); // input tensor height / width / channels - const int32_t in_w = get_compile_time_arg_val(6); - const uint32_t in_cb_nsticks = get_compile_time_arg_val(7); + const int32_t in_w = get_compile_time_arg_val(5); + const uint32_t in_cb_nsticks = get_compile_time_arg_val(6); - const uint32_t in_c = get_compile_time_arg_val(8); + const uint32_t in_c = get_compile_time_arg_val(7); - const uint32_t split_reader = get_compile_time_arg_val(10); - const uint32_t reader_id = get_compile_time_arg_val(11); + const uint32_t split_reader = get_compile_time_arg_val(9); + const uint32_t reader_id = get_compile_time_arg_val(10); // value of 1 in bf16 in a uin32_t - constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12); + constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(11); - constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13); + constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(12); constexpr uint32_t TILE_WIDTH = 32; @@ -99,7 +98,7 @@ void kernel_main() { uint32_t h_multiples = 0; for (uint32_t h = 0; h < window_h; ++ h, h_multiples += in_w_padded) { uint32_t stick_offset = top_left_local_index + h_multiples; - uint32_t read_offset = in_l1_read_base_addr + (stick_offset << in_nbytes_c_log2); + uint32_t read_offset = in_l1_read_base_addr + (stick_offset * in_nbytes_c); noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, in_nbytes_c * window_w); out_l1_write_addr += in_nbytes_c * window_w; } diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp index 2556fc53fcb..c7bb703e645 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp @@ -48,21 +48,20 @@ void kernel_main() { // channel size in bytes, multiple of 32 const uint32_t in_nbytes_c = get_compile_time_arg_val(4); - const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5); // input tensor height / width / channels - const int32_t in_w = get_compile_time_arg_val(6); - const uint32_t in_cb_nsticks = get_compile_time_arg_val(7); + const int32_t in_w = get_compile_time_arg_val(5); + const uint32_t in_cb_nsticks = get_compile_time_arg_val(6); - const uint32_t in_c = get_compile_time_arg_val(8); + const uint32_t in_c = get_compile_time_arg_val(7); - const uint32_t split_reader = get_compile_time_arg_val(10); - const uint32_t reader_id = get_compile_time_arg_val(11); + const uint32_t split_reader = get_compile_time_arg_val(9); + const uint32_t reader_id = get_compile_time_arg_val(10); // value of 1 in bf16 in a uin32_t - constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12); + constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(11); - constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13); + constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(12); // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2"); diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp index b57e0de77b4..ed9527415dc 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp @@ -20,11 +20,6 @@ void validate_maxpool(const Tensor& input, const sliding_window::SlidingWindowCo TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Only BFLOAT16 supported for now"); TT_FATAL(input.get_layout() == Layout::ROW_MAJOR, "Only ROW_MAJOR supported for now"); - // NOTE: This is not a hard requirement. If need to support non-power-of-2, simply change the address generator in reader to generic one. - uint32_t in_nbytes_c = (input.get_legacy_shape()[3]) * (input.get_dtype() == DataType::BFLOAT16 ? 2 : 1); - bool is_pow2 = (in_nbytes_c & (in_nbytes_c - 1)) == 0; - TT_FATAL(is_pow2, "Row size (nchannels * bytes = {}) should be power of 2 ({}).", in_nbytes_c, is_pow2); - TT_FATAL(input.memory_config().is_sharded(), "Input needs to be sharded"); TT_FATAL(out_mem_config.is_sharded(), "Output memory config needs to be sharded"); diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp index 577632a9d92..afc92ff0316 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp @@ -49,9 +49,6 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ uint32_t in_nbytes_c = input_shape[3] / num_shards_c * in_nbytes; // row of input (channels) uint32_t out_nbytes_c = output_shape[3] / num_shards_c * out_nbytes; // row of output (channels) - TT_ASSERT((in_nbytes_c & (in_nbytes_c - 1)) == 0, "in_nbytes_c should be power of 2"); // in_nbytes_c is power of 2 - TT_ASSERT( - (out_nbytes_c & (out_nbytes_c - 1)) == 0, "out_nbytes_c should be power of 2"); // out_nbytes_c is power of 2 tt::DataFormat indices_df = tt::DataFormat::RawUInt16; // datatype_to_dataformat_converter(reader_indices.get_dtype()); uint32_t indices_nbytes = datum_size(indices_df); @@ -282,14 +279,12 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ */ float one = 1.; uint32_t bf16_one_u32 = *reinterpret_cast(&one); - uint32_t in_nbytes_c_log2 = (uint32_t)std::log2((float)in_nbytes_c); std::vector reader0_ct_args = { out_nhw_per_core, kernel_size_h, kernel_size_w, pad_w, in_nbytes_c, - in_nbytes_c_log2, in_w, in_cb_page_padded * in_cb_npages / tile_w, input_shape[3] / num_shards_c, @@ -307,7 +302,6 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ kernel_size_w, pad_w, in_nbytes_c, - in_nbytes_c_log2, in_w, in_cb_page_padded * in_cb_npages / tile_w, input_shape[3] / num_shards_c,