From ce0fb7a2d7b0749eae43a3b8fec377dd73d2a50d Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Thu, 23 May 2024 13:38:53 +0000 Subject: [PATCH] #8837: Optimize cache hit updating RTAs for some ops --- ...op_multi_core_reuse_mcast_2d_optimized.cpp | 101 +- .../op_library/downsample/downsample_op.cpp | 512 +++++--- .../multi_core/move_op_multi_core_sharded.cpp | 75 +- .../op_library/pool/max_pool_multi_core.cpp | 973 ++++++++------- .../multi_core_h/reduce_op_multi_core_h.cpp | 200 ++- .../multi_core_w/reduce_op_multi_core_w.cpp | 119 +- .../multi_core/sharded_op_multi_core.cpp | 3 +- .../tilize_op_multi_core.cpp | 96 +- .../multi_core/untilize_op_multi_core.cpp | 55 +- tt_metal/tt_metal.cpp | 1112 +++++++++-------- 10 files changed, 1754 insertions(+), 1492 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp index 4a7455be84e..06181313bbb 100644 --- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp @@ -408,11 +408,13 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines}); + auto in1_receiver = + (CoreRangeSet)(std::set){in0_sender_in1_receiver, in0_receiver_in1_receiver_left_half}; auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp", /* in0_sender_in1_receiver, // If not using half-half noc setup */ - (CoreRangeSet)(std::set){in0_sender_in1_receiver, in0_receiver_in1_receiver_left_half}, + in1_receiver, tt_metal::DataMovementConfig{ .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, @@ -639,9 +641,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( uint32_t last_block_padded_block_tiles_h_skip = (per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h); - std::vector reader_kernel_ids; - std::vector writer_kernel_ids; - uint32_t diff_start_coord; uint32_t diff_end_coord; std::vector in0_mcast_noc_x; @@ -668,6 +667,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( } const auto& cores = grid_to_cores(all_cores.start, all_cores.end, true); + const auto& in0_sender_cores = grid_to_cores(in0_sender.start, in0_sender.end, true); + const auto& in1_sender_cores = grid_to_cores(in1_sender.start, in1_sender.end, true); + const auto& in1_receiver_cores = corerange_to_cores(in1_receiver, std::nullopt, true); + const auto& in1_receiver_other_cores = + grid_to_cores(in0_receiver_in1_receiver_right_half.start, in0_receiver_in1_receiver_right_half.end, true); for (const auto& core : cores) { CoreCoord left_core = {(std::size_t)start_core_x, (std::size_t)core.y}; CoreCoord left_core_plus_one = {(std::size_t)start_core_x + 1, (std::size_t)core.y}; @@ -733,7 +737,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( mm_in0_sender_args.push_back(worker_shard_same_coord); } tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_sender_args); // RISCV_0_default - reader_kernel_ids.push_back(mm_kernel_in0_sender_id); } else if (in1_idx == 0) { std::vector mm_in0_sender_args = { // in0 tensor args @@ -753,7 +756,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( } tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_sender_args); // RISCV_0_default - reader_kernel_ids.push_back(mm_kernel_in0_sender_id); // in0 receiver } else { @@ -765,13 +767,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( // left half if (core.x <= half_core || (!transpose_mcast and core.y == start_core_y)) { tt_metal::SetRuntimeArgs(program, mm_kernel_in0_receiver_id, core, mm_in0_receiver_args); - reader_kernel_ids.push_back(mm_kernel_in0_receiver_id); } // right half else { tt_metal::SetRuntimeArgs( program, mm_kernel_in0_receiver_other_noc_setup_id, core, mm_in0_receiver_args); - reader_kernel_ids.push_back(mm_kernel_in0_receiver_other_noc_setup_id); } } @@ -826,7 +826,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( } tt_metal::SetRuntimeArgs( program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); // RISCV_1_default - writer_kernel_ids.push_back(mm_kernel_in1_sender_writer_id); // in1 receiver } else { @@ -883,20 +882,24 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( // left half if (core.x <= half_core || (transpose_mcast and core.y == start_core_y)) { tt_metal::SetRuntimeArgs(program, mm_kernel_in1_receiver_writer_id, core, mm_in1_receiver_writer_args); - writer_kernel_ids.push_back(mm_kernel_in1_receiver_writer_id); } // right half else { tt_metal::SetRuntimeArgs( program, mm_kernel_in1_receiver_writer_other_noc_setup_id, core, mm_in1_receiver_writer_args); - writer_kernel_ids.push_back(mm_kernel_in1_receiver_writer_other_noc_setup_id); } } } auto override_runtime_arguments_callback = - [reader_kernel_ids, - writer_kernel_ids, + [mm_kernel_in0_sender_id, + in0_sender_cores, + mm_kernel_in1_sender_writer_id, + in1_sender_cores, + mm_kernel_in1_receiver_writer_id, + in1_receiver_cores, + mm_kernel_in1_receiver_writer_other_noc_setup_id, + in1_receiver_other_cores, cb_src2, cb_output, num_cores_r, @@ -910,8 +913,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector& output_tensors) { - TT_FATAL(input_tensors.size() + optional_input_tensors.size() == 3); - TT_FATAL(output_tensors.size() == 1); + TT_ASSERT(input_tensors.size() + optional_input_tensors.size() == 3); + TT_ASSERT(output_tensors.size() == 1); auto src_buffer_a = input_tensors.at(0).buffer(); auto src_buffer_b = input_tensors.at(1).buffer(); @@ -919,47 +922,49 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( auto dst_buffer = output_tensors.at(0).buffer(); - bool src0_sharded = input_tensors.at(0).memory_config().is_sharded(); - bool out_sharded = output_tensors.at(0).memory_config().is_sharded(); - - for (uint32_t i = 0; i < cores.size(); ++i) { - const CoreCoord& core = cores[i]; - - auto reader_kernel_id = reader_kernel_ids.at(i); - auto& reader_runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - - auto writer_kernel_id = writer_kernel_ids.at(i); - auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - - uint32_t in0_idx = core.y - start_core_y; - uint32_t in1_idx = core.x - start_core_x; + bool src0_sharded = input_tensors[0].memory_config().is_sharded(); + bool out_sharded = output_tensors[0].memory_config().is_sharded(); - if (transpose_mcast) { - std::swap(in0_idx, in1_idx); - } + std::optional bias_buffer; + if (bias_tensor.has_value()) { + bias_buffer = bias_tensor.value().buffer(); + } - // in0 sender - if (!src0_sharded && in1_idx == 0) { + // in0 sender + if (src0_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_src2, *src_buffer_a); + } else { + auto& reader_sender_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in0_sender_id); + for (const auto& core : in0_sender_cores) { + auto& reader_runtime_args = reader_sender_runtime_args_by_core[core.x][core.y]; reader_runtime_args[0] = src_buffer_a->address(); - // in0 receiver - } else { } + } - // in1 sender - if (in0_idx == 0) { - writer_runtime_args[0] = src_buffer_b->address(); - writer_runtime_args[6] = dst_buffer->address(); - if (bias_tensor.has_value()) { - writer_runtime_args[16] = bias_tensor.value().buffer()->address(); - } - // in1 receiver - } else { - writer_runtime_args[2] = dst_buffer->address(); + // in1 sender + auto& sender_writer_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in1_sender_writer_id); + for (const auto& core : in1_sender_cores) { + auto& writer_runtime_args = sender_writer_runtime_args_by_core[core.x][core.y]; + writer_runtime_args[0] = src_buffer_b->address(); + writer_runtime_args[6] = dst_buffer->address(); + if (bias_tensor.has_value()) { + writer_runtime_args[16] = (*bias_buffer)->address(); } } - if (src0_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_src2, *src_buffer_a); + // in1 receiver + auto& receiver_writer_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in1_receiver_writer_id); + for (const auto& core : in1_receiver_cores) { + auto& writer_runtime_args = receiver_writer_runtime_args_by_core[core.x][core.y]; + writer_runtime_args[2] = dst_buffer->address(); + } + if (mm_kernel_in1_receiver_writer_id != mm_kernel_in1_receiver_writer_other_noc_setup_id) { + auto& receiver_writer_runtime_args_by_core = + GetRuntimeArgs(program, mm_kernel_in1_receiver_writer_other_noc_setup_id); + for (const auto& core : in1_receiver_other_cores) { + auto& writer_runtime_args = receiver_writer_runtime_args_by_core[core.x][core.y]; + writer_runtime_args[2] = dst_buffer->address(); + } } if (out_sharded) { diff --git a/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp b/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp index 01279831d4e..f4dfc2ea2e4 100644 --- a/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp +++ b/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp @@ -2,15 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tt_dnn/op_library/downsample/downsample_op.hpp" + #include +#include "tt_dnn/op_library/math.hpp" #include "tt_dnn/op_library/untilize/untilize_op.hpp" -#include "tt_dnn/op_library/downsample/downsample_op.hpp" #include "tt_dnn/op_library/work_split.hpp" -#include "tt_dnn/op_library/math.hpp" -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; @@ -18,21 +19,24 @@ namespace tt { namespace tt_metal { - -void Downsample::validate(const std::vector &input_tensors) const { +void Downsample::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to downsample need to be on device!"); - TT_FATAL(input_tensor_a.buffer() != nullptr , "Operands to downsample need to be allocated in buffers on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands to downsample need to be allocated in buffers on device!"); TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Can only downsample tile major data"); TT_FATAL(input_tensor_a.volume() % TILE_HW == 0); TT_FATAL(input_tensor_a.memory_config().is_sharded()); - TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + TT_FATAL( + input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || + input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); } -std::pair get_num_cores_height_width_sliced(CoreRangeSet all_cores, TensorMemoryLayout memory_layout, ShardOrientation shard_orientation) { - TT_ASSERT(memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); - if (memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { +std::pair get_num_cores_height_width_sliced( + CoreRangeSet all_cores, TensorMemoryLayout memory_layout, ShardOrientation shard_orientation) { + TT_ASSERT( + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + if (memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { TT_ASSERT(shard_orientation == ShardOrientation::ROW_MAJOR); } else { TT_ASSERT(shard_orientation == ShardOrientation::COL_MAJOR); @@ -40,51 +44,70 @@ std::pair get_num_cores_height_width_sliced(CoreRangeSet all } uint32_t num_cores = all_cores.num_cores(); auto first_core_range = *all_cores.ranges().begin(); - uint32_t num_cores_height_sliced = memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? num_cores : first_core_range.end.x+1; - uint32_t num_cores_width_sliced = memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? 1 : first_core_range.end.y+1; // width is not sliced when height sharded + uint32_t num_cores_height_sliced = + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? num_cores : first_core_range.end.x + 1; + uint32_t num_cores_width_sliced = memory_layout == TensorMemoryLayout::HEIGHT_SHARDED + ? 1 + : first_core_range.end.y + 1; // width is not sliced when height sharded return {num_cores_height_sliced, num_cores_width_sliced}; } -std::vector Downsample::compute_output_shapes(const std::vector &input_tensors) const { +std::vector Downsample::compute_output_shapes(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); TT_ASSERT(input_tensor_a.get_legacy_shape()[0] == 1 && input_tensor_a.get_legacy_shape()[1] == 1); uint32_t input_height = input_tensor_a.get_legacy_shape()[2]; auto [img_batch_size, img_height, img_width, img_stride_h, img_stride_w] = this->downsample_params; TT_ASSERT(input_height >= img_batch_size * img_height * img_width); - uint32_t output_height_unpadded = img_batch_size * ceil( (double) img_height / (double) img_stride_h) * ceil( (double) img_width / (double) img_stride_w); + uint32_t output_height_unpadded = img_batch_size * ceil((double)img_height / (double)img_stride_h) * + ceil((double)img_width / (double)img_stride_w); uint32_t output_height = round_up(output_height_unpadded, TILE_HEIGHT); uint32_t output_width = input_tensor_a.get_legacy_shape()[3]; - auto output_padding = Padding({{0, 0}, {0, 0}, {0, (output_height - output_height_unpadded)}, {0, 0}}, Padding::PadValue::Any); + auto output_padding = + Padding({{0, 0}, {0, 0}, {0, (output_height - output_height_unpadded)}, {0, 0}}, Padding::PadValue::Any); auto output_tensor_shape = Shape({1, 1, output_height, output_width}, output_padding); log_debug(LogOp, "Downsample output shape: {}", output_tensor_shape); return {output_tensor_shape}; } -std::vector Downsample::create_output_tensors(const std::vector &input_tensors) const { +std::vector Downsample::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); auto output_shape = this->compute_output_shapes(input_tensors).at(0); - auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced(input_tensor.shard_spec().value().grid, - input_tensor.memory_config().memory_layout, - input_tensor.shard_spec().value().orientation); - uint32_t output_shard_height = round_up(output_shape[2], num_cores_height_sliced * TILE_HEIGHT) / num_cores_height_sliced; - uint32_t output_shard_width = round_up(output_shape[3], num_cores_width_sliced * TILE_WIDTH) / num_cores_width_sliced; + auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced( + input_tensor.shard_spec().value().grid, + input_tensor.memory_config().memory_layout, + input_tensor.shard_spec().value().orientation); + uint32_t output_shard_height = + round_up(output_shape[2], num_cores_height_sliced * TILE_HEIGHT) / num_cores_height_sliced; + uint32_t output_shard_width = + round_up(output_shape[3], num_cores_width_sliced * TILE_WIDTH) / num_cores_width_sliced; auto mem_config = input_tensor.memory_config(); - mem_config.shard_spec = ShardSpec {input_tensor.shard_spec().value().grid, std::array{{output_shard_height, output_shard_width}}, input_tensor.shard_spec().value().orientation}; + mem_config.shard_spec = ShardSpec{ + input_tensor.shard_spec().value().grid, + std::array{{output_shard_height, output_shard_width}}, + input_tensor.shard_spec().value().orientation}; return {create_device_tensor(output_shape, this->output_dtype, Layout::TILE, input_tensor.device(), mem_config)}; } -operation::ProgramWithCallbacks Downsample::create_program(const std::vector& input_tensors, std::vector &output_tensors) const { +operation::ProgramWithCallbacks Downsample::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); auto& output_tensor = output_tensors.at(0); return {downsample_single_core(input_tensor_a, downsample_params, output_tensor)}; } -Tensor downsample(const Tensor &input_tensor_a, std::array downsample_params, std::optional output_dtype) { +Tensor downsample( + const Tensor& input_tensor_a, std::array downsample_params, std::optional output_dtype) { std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; operation::launch_op( - [downsample_params, output_dtype] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run_without_autoformat(Downsample{downsample_params, output_dtype.value_or(input_tensors.at(0).get_dtype())}, input_tensors); - }, {input_tensor_a}, output_tensors); + [downsample_params, output_dtype]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + return operation::run_without_autoformat( + Downsample{downsample_params, output_dtype.value_or(input_tensors.at(0).get_dtype())}, input_tensors); + }, + {input_tensor_a}, + output_tensors); return output_tensors.at(0); } @@ -105,40 +128,50 @@ struct DownsampleReadPatternParams { struct ImgTrackingVars { uint32_t img_h = 0; uint32_t img_w = 0; - uint32_t next_img_h = 0; // img_h after stride + uint32_t next_img_h = 0; // img_h after stride uint32_t next_img_w = 0; - uint32_t input_flat_h = 0; // index within sharded input - uint32_t output_flat_h = 0; // index within sharded output + uint32_t input_flat_h = 0; // index within sharded input + uint32_t output_flat_h = 0; // index within sharded output }; -DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v, uint32_t img_height, uint32_t img_width, uint32_t img_stride_h, uint32_t img_stride_w, uint32_t input_end_flat_h, uint32_t output_end_flat_h, bool current_region_is_halo_prev_core, bool current_region_is_halo_next_core) { +DownsampleReadPatternParams generate_downsample_read_pattern( + ImgTrackingVars& v, + uint32_t img_height, + uint32_t img_width, + uint32_t img_stride_h, + uint32_t img_stride_w, + uint32_t input_end_flat_h, + uint32_t output_end_flat_h, + bool current_region_is_halo_prev_core, + bool current_region_is_halo_next_core) { // Sanity checks at the start for local data TT_ASSERT(v.next_img_h >= v.img_h); - TT_ASSERT(v.next_img_w == v.img_w); // assumption that the start is picked and not skipped by stride + TT_ASSERT(v.next_img_w == v.img_w); // assumption that the start is picked and not skipped by stride TT_ASSERT(v.img_h < img_height); TT_ASSERT(v.next_img_w < img_width); if (current_region_is_halo_prev_core) { - //cout << "GENERATING READ PATTERN FOR HALO REGION FROM PREVIOUS CORE" << endl; + // cout << "GENERATING READ PATTERN FOR HALO REGION FROM PREVIOUS CORE" << endl; TT_ASSERT(!current_region_is_halo_next_core); TT_ASSERT(v.input_flat_h != 0); TT_ASSERT(v.output_flat_h == 0); } else if (current_region_is_halo_next_core) { - //cout << "GENERATING READ PATTERN FOR HALO REGION FROM NEXT CORE" << endl; + // cout << "GENERATING READ PATTERN FOR HALO REGION FROM NEXT CORE" << endl; TT_ASSERT(!current_region_is_halo_prev_core); TT_ASSERT(v.input_flat_h == 0); TT_ASSERT(v.output_flat_h != 0); } else { - //cout << "GENERATING READ PATTERN FOR LOCAL REGION" << endl; + // cout << "GENERATING READ PATTERN FOR LOCAL REGION" << endl; } - //cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << v.img_w << endl; - //cout << "v.input_flat_h=" << v.input_flat_h << ", input_end_flat_h=" << input_end_flat_h << ", v.output_flat_h=" << v.output_flat_h << ", output_end_flat_h=" << output_end_flat_h << endl; + // cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << + // v.img_w << endl; cout << "v.input_flat_h=" << v.input_flat_h << ", input_end_flat_h=" << input_end_flat_h << ", + // v.output_flat_h=" << v.output_flat_h << ", output_end_flat_h=" << output_end_flat_h << endl; TT_ASSERT(v.input_flat_h < input_end_flat_h); TT_ASSERT(v.output_flat_h < output_end_flat_h); - uint32_t output_img_height = std::ceil ( (double) img_height / (double) img_stride_h); - uint32_t output_img_width = std::ceil ( (double) img_width / (double) img_stride_w); + uint32_t output_img_height = std::ceil((double)img_height / (double)img_stride_h); + uint32_t output_img_width = std::ceil((double)img_width / (double)img_stride_w); bool found_halo_for_next_core = false; uint32_t top_partial_middle_aligned_row_width = 0; @@ -154,13 +187,13 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v uint32_t skip_bottom_partial_left_aligned_row = 1; if (v.img_w != 0) { // Check if its right aligned or middle aligned (special corner case for halo) - if (v.input_flat_h + img_width - v.img_w <= input_end_flat_h+1) { + if (v.input_flat_h + img_width - v.img_w <= input_end_flat_h + 1) { // top partial right aligned top_partial_right_aligned_row_width = img_width - v.img_w; skip_top_partial_right_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; v.input_flat_h += top_partial_right_aligned_row_width; if (!skip_top_partial_right_aligned_row) { - v.output_flat_h += std::ceil((double) top_partial_right_aligned_row_width / (double) img_stride_w); + v.output_flat_h += std::ceil((double)top_partial_right_aligned_row_width / (double)img_stride_w); TT_ASSERT(v.output_flat_h <= output_end_flat_h); } v.img_w = 0; @@ -184,17 +217,17 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v skip_top_partial_middle_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; v.input_flat_h += top_partial_middle_aligned_row_width; if (!skip_top_partial_middle_aligned_row) { - v.output_flat_h += std::ceil((double) top_partial_middle_aligned_row_width / (double) img_stride_w); + v.output_flat_h += std::ceil((double)top_partial_middle_aligned_row_width / (double)img_stride_w); TT_ASSERT(v.output_flat_h <= output_end_flat_h); } uint32_t img_w_start = v.img_w; - while(v.img_w < img_w_start + top_partial_middle_aligned_row_width) { + while (v.img_w < img_w_start + top_partial_middle_aligned_row_width) { v.img_w += 1; if (v.next_img_w < v.img_w) { v.next_img_w += img_stride_w; } } - TT_ASSERT(v.img_w < img_width-1); + TT_ASSERT(v.img_w < img_width - 1); } } TT_ASSERT(v.next_img_w == v.img_w); @@ -202,16 +235,20 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v TT_ASSERT(v.next_img_h >= v.img_h); if (v.img_w != 0) { // special case for halo - TT_ASSERT(v.input_flat_h == input_end_flat_h+1); + TT_ASSERT(v.input_flat_h == input_end_flat_h + 1); } TT_ASSERT(v.img_h < img_height && v.img_w < img_width); uint32_t num_rows_remaining_of_current_image = (v.img_h == 0) ? 0 : img_height - v.img_h; if (num_rows_remaining_of_current_image > 0) { uint32_t num_rows_to_skip = v.next_img_h - v.img_h; - uint32_t output_h_from_remaining_rows_of_current_image = std::ceil( (double) (num_rows_remaining_of_current_image - num_rows_to_skip) / (double) img_stride_h ) * output_img_width; - bool output_for_partial_top_image = v.output_flat_h + output_h_from_remaining_rows_of_current_image <= output_end_flat_h+1; - bool input_for_partial_top_image = v.input_flat_h + (num_rows_remaining_of_current_image * img_width) <= input_end_flat_h+1; + uint32_t output_h_from_remaining_rows_of_current_image = + std::ceil((double)(num_rows_remaining_of_current_image - num_rows_to_skip) / (double)img_stride_h) * + output_img_width; + bool output_for_partial_top_image = + v.output_flat_h + output_h_from_remaining_rows_of_current_image <= output_end_flat_h + 1; + bool input_for_partial_top_image = + v.input_flat_h + (num_rows_remaining_of_current_image * img_width) <= input_end_flat_h + 1; if (output_for_partial_top_image && input_for_partial_top_image) { // Top partial image section num_rows_top_partial_image = img_height - v.img_h; @@ -222,17 +259,18 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v v.next_img_h = 0; v.input_flat_h += (num_rows_top_partial_image * img_width); v.output_flat_h += output_h_from_remaining_rows_of_current_image; - TT_ASSERT(v.input_flat_h <= input_end_flat_h+1); + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); } - TT_ASSERT(v.output_flat_h <= output_end_flat_h+1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); } TT_ASSERT(v.img_h < img_height && v.img_w < img_width); if (v.img_h == 0 && v.img_w == 0) { // Check for full images - while(1) { - bool output_for_current_full_image = v.output_flat_h + (output_img_height * output_img_width) <= output_end_flat_h+1; - bool input_for_current_full_image = v.input_flat_h + (img_height * img_width) <= input_end_flat_h+1; + while (1) { + bool output_for_current_full_image = + v.output_flat_h + (output_img_height * output_img_width) <= output_end_flat_h + 1; + bool input_for_current_full_image = v.input_flat_h + (img_height * img_width) <= input_end_flat_h + 1; if (!output_for_current_full_image || !input_for_current_full_image) { break; } @@ -242,20 +280,22 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v v.next_img_h = 0; v.next_img_w = 0; num_full_images += 1; - v.output_flat_h += (output_img_height * output_img_width); + v.output_flat_h += (output_img_height * output_img_width); } TT_ASSERT(v.img_h == 0 && v.img_w == 0 && v.next_img_h == 0 && v.next_img_w == 0); } // Sanity check - TT_ASSERT(v.input_flat_h <= input_end_flat_h+1); - TT_ASSERT(v.output_flat_h <= output_end_flat_h+1); + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); bool found_first_unskipped_row_in_bottom_partial_imgage = false; // check for bottom partial image rows while (1) { - bool output_for_bottom_partial_image_row = (v.next_img_h == v.img_h) ? (v.output_flat_h + output_img_width <= output_end_flat_h+1) : true; // true for skipped row - bool input_for_bottom_partial_image_row = v.input_flat_h + img_width <= input_end_flat_h+1; + bool output_for_bottom_partial_image_row = (v.next_img_h == v.img_h) + ? (v.output_flat_h + output_img_width <= output_end_flat_h + 1) + : true; // true for skipped row + bool input_for_bottom_partial_image_row = v.input_flat_h + img_width <= input_end_flat_h + 1; if (!output_for_bottom_partial_image_row || !input_for_bottom_partial_image_row) { break; } @@ -273,11 +313,11 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v } v.img_w = 0; v.next_img_w = 0; - TT_ASSERT(v.img_h < img_height - 1); // this is supposed to be a bottom partial image + TT_ASSERT(v.img_h < img_height - 1); // this is supposed to be a bottom partial image v.img_h += 1; if (v.next_img_h < v.img_h) { v.next_img_h += img_stride_h; - TT_ASSERT(v.next_img_h <= img_height); // odd heights and odd size sharding with stride > 1 not supported + TT_ASSERT(v.next_img_h <= img_height); // odd heights and odd size sharding with stride > 1 not supported if (v.next_img_h == img_height && v.img_h == img_height) { v.next_img_h = 0; v.img_h = 0; @@ -288,8 +328,8 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v } // Sanity check - TT_ASSERT(v.input_flat_h <= input_end_flat_h+1); - TT_ASSERT(v.output_flat_h <= output_end_flat_h+1); + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); TT_ASSERT(v.img_h < img_height && v.img_w < img_width); // check if there is a bottom partial left aligned row @@ -298,7 +338,9 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v // bottom partial left aligned row width can be split between 2 cores uint32_t input_remaining = input_end_flat_h - v.input_flat_h + 1; uint32_t output_remaining = output_end_flat_h - v.output_flat_h + 1; - TT_ASSERT(output_remaining < output_img_width || input_remaining < img_width); // there must be a partial width either on input side or output side + TT_ASSERT( + output_remaining < output_img_width || + input_remaining < img_width); // there must be a partial width either on input side or output side bottom_partial_left_aligned_row_width = input_remaining; if (output_remaining < output_img_width) { bottom_partial_left_aligned_row_width = std::min(input_remaining, output_remaining * img_stride_w); @@ -307,17 +349,17 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v TT_ASSERT(bottom_partial_left_aligned_row_width < img_width); TT_ASSERT(v.next_img_h >= v.img_h); skip_bottom_partial_left_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; - while(v.img_w < bottom_partial_left_aligned_row_width) { + while (v.img_w < bottom_partial_left_aligned_row_width) { v.img_w += 1; if (v.next_img_w < v.img_w) { v.next_img_w += img_stride_w; - TT_ASSERT(v.next_img_w < img_width); // odd widths and odd size sharding with stride > 1 not supported + TT_ASSERT(v.next_img_w < img_width); // odd widths and odd size sharding with stride > 1 not supported } } TT_ASSERT(v.img_w == bottom_partial_left_aligned_row_width && v.next_img_w >= v.img_w); v.input_flat_h += bottom_partial_left_aligned_row_width; if (!skip_bottom_partial_left_aligned_row) { - v.output_flat_h += std::ceil( (double) bottom_partial_left_aligned_row_width / (double) img_stride_w); + v.output_flat_h += std::ceil((double)bottom_partial_left_aligned_row_width / (double)img_stride_w); } } TT_ASSERT(v.img_h < img_height && v.img_w < img_width); @@ -338,12 +380,13 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v log_debug(LogOp, " v.output_flat_h: {}", v.output_flat_h); log_debug(LogOp, " input_end_flat_h: {}", input_end_flat_h); log_debug(LogOp, " output_end_flat_h: {}", output_end_flat_h); - //cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << v.img_w << endl; + // cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << + // v.img_w << endl; } // Sanity check - TT_ASSERT(v.input_flat_h <= input_end_flat_h+1); - TT_ASSERT(v.output_flat_h <= output_end_flat_h+1); + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); if (v.input_flat_h == input_end_flat_h + 1) { v.input_flat_h = 0; @@ -351,21 +394,22 @@ DownsampleReadPatternParams generate_downsample_read_pattern(ImgTrackingVars & v if (v.output_flat_h == output_end_flat_h + 1) { v.output_flat_h = 0; } - return DownsampleReadPatternParams{.top_partial_middle_aligned_row_width=top_partial_middle_aligned_row_width, - .skip_top_partial_middle_aligned_row=skip_top_partial_middle_aligned_row, - .top_partial_right_aligned_row_width=top_partial_right_aligned_row_width, - .skip_top_partial_right_aligned_row=skip_top_partial_right_aligned_row, - .num_rows_top_partial_image=num_rows_top_partial_image, - .num_skip_rows_top_partial_image=num_skip_rows_top_partial_image, - .num_full_images=num_full_images, - .num_rows_bottom_partial_image=num_rows_bottom_partial_image, - .num_skip_rows_bottom_partial_image=num_skip_rows_bottom_partial_image, - .bottom_partial_left_aligned_row_width=bottom_partial_left_aligned_row_width, - .skip_bottom_partial_left_aligned_row=skip_bottom_partial_left_aligned_row}; + return DownsampleReadPatternParams{ + .top_partial_middle_aligned_row_width = top_partial_middle_aligned_row_width, + .skip_top_partial_middle_aligned_row = skip_top_partial_middle_aligned_row, + .top_partial_right_aligned_row_width = top_partial_right_aligned_row_width, + .skip_top_partial_right_aligned_row = skip_top_partial_right_aligned_row, + .num_rows_top_partial_image = num_rows_top_partial_image, + .num_skip_rows_top_partial_image = num_skip_rows_top_partial_image, + .num_full_images = num_full_images, + .num_rows_bottom_partial_image = num_rows_bottom_partial_image, + .num_skip_rows_bottom_partial_image = num_skip_rows_bottom_partial_image, + .bottom_partial_left_aligned_row_width = bottom_partial_left_aligned_row_width, + .skip_bottom_partial_left_aligned_row = skip_bottom_partial_left_aligned_row}; } -operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::array downsample_params, Tensor& output) { - +operation::ProgramWithCallbacks downsample_single_core( + const Tensor& a, std::array downsample_params, Tensor& output) { tt_metal::Program program = tt_metal::CreateProgram(); tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); @@ -375,75 +419,76 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr tt::DataFormat untilized_cb_data_format = DataFormat::Float16_b; uint32_t untilized_single_tile_size = tt_metal::detail::TileSize(untilized_cb_data_format); auto [img_batch_size, img_height, img_width, img_stride_h, img_stride_w] = downsample_params; - tt_metal::Buffer *src0_buffer = a.buffer(); + tt_metal::Buffer* src0_buffer = a.buffer(); TT_ASSERT(a.get_legacy_shape()[0] == 1 && a.get_legacy_shape()[1] == 1); TT_ASSERT(output.get_legacy_shape()[0] == 1 && output.get_legacy_shape()[1] == 1); - tt_metal::Device *device = a.device(); + tt_metal::Device* device = a.device(); - tt_metal::Buffer *dst_buffer = output.buffer(); + tt_metal::Buffer* dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); // Sanity check of output size TT_ASSERT(output.volume() % TILE_HW == 0); uint32_t unpadded_input_volume = img_batch_size * img_height * img_width; TT_ASSERT(a.volume() >= unpadded_input_volume); - uint32_t unpadded_output_volume = ceil((double) unpadded_input_volume / (double) (img_stride_h * img_stride_w)); + uint32_t unpadded_output_volume = ceil((double)unpadded_input_volume / (double)(img_stride_h * img_stride_w)); TT_ASSERT(output.volume() >= unpadded_output_volume); - uint32_t ncores_x_full_grid = device->compute_with_storage_grid_size().x; - auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced(a.shard_spec().value().grid, - a.memory_config().memory_layout, - a.shard_spec().value().orientation); + auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced( + a.shard_spec().value().grid, a.memory_config().memory_layout, a.shard_spec().value().orientation); uint32_t num_cores = num_cores_height_sliced * num_cores_width_sliced; auto all_cores = a.shard_spec().value().grid; auto memory_layout = a.memory_config().memory_layout; TT_ASSERT(all_cores == output.shard_spec().value().grid); TT_ASSERT(memory_layout == output.memory_config().memory_layout); - TT_ASSERT(memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + TT_ASSERT( + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { TT_ASSERT(all_cores.ranges().size() == 1); } else { TT_ASSERT(num_cores_width_sliced == 1); } - uint32_t num_cores_x = memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? ncores_x_full_grid : num_cores_height_sliced; + uint32_t num_cores_x = + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? ncores_x_full_grid : num_cores_height_sliced; auto core_range = all_cores; - uint32_t input_height = a.get_legacy_shape()[2]; // input height == flattened face of input image, multiple images are stacked in H dim - uint32_t input_width = a.get_legacy_shape()[3]; // input width == input image # of channels - uint32_t output_height = output.get_legacy_shape()[2]; // output height == flattened face of output image, multiple images are stacked in H dim + uint32_t input_height = + a.get_legacy_shape()[2]; // input height == flattened face of input image, multiple images are stacked in H dim + uint32_t input_width = a.get_legacy_shape()[3]; // input width == input image # of channels + uint32_t output_height = output.get_legacy_shape()[2]; // output height == flattened face of output image, multiple + // images are stacked in H dim uint32_t output_width = output.get_legacy_shape()[3]; TT_ASSERT(input_width == output_width); uint32_t input_height_unpadded = img_batch_size * img_height * img_width; TT_ASSERT(input_height >= input_height_unpadded); - uint32_t output_height_unpadded = img_batch_size * std::ceil((double) (img_height * img_width) / (double) (img_stride_h * img_stride_w)); + uint32_t output_height_unpadded = + img_batch_size * std::ceil((double)(img_height * img_width) / (double)(img_stride_h * img_stride_w)); TT_ASSERT(output_height >= output_height_unpadded); uint32_t input_shard_height = a.shard_spec().value().shape[0]; - TT_ASSERT(input_shard_height * num_cores_height_sliced >= input_height); // last core shard size may be padded + TT_ASSERT(input_shard_height * num_cores_height_sliced >= input_height); // last core shard size may be padded uint32_t input_height_padded = input_shard_height * num_cores_height_sliced; uint32_t input_shard_width = a.shard_spec().value().shape[1]; TT_ASSERT(input_shard_width * num_cores_width_sliced == input_width); uint32_t input_height_padding = input_height_padded - input_height_unpadded; - TT_ASSERT(input_height_padding < input_shard_height); // last core has padding + TT_ASSERT(input_height_padding < input_shard_height); // last core has padding uint32_t last_core_input_shard_height_unpadded = input_shard_height - input_height_padding; - uint32_t output_shard_height = output.shard_spec().value().shape[0]; uint32_t output_shard_width = output.shard_spec().value().shape[1]; TT_ASSERT(output_shard_width == input_shard_width); - TT_ASSERT(output_shard_height * num_cores_height_sliced >= output_height); // last core shard size may be padded + TT_ASSERT(output_shard_height * num_cores_height_sliced >= output_height); // last core shard size may be padded uint32_t output_height_padded = output_shard_height * num_cores_height_sliced; uint32_t output_height_padding = output_height_padded - output_height_unpadded; TT_ASSERT(output_height_padding < output_shard_height); uint32_t last_core_output_shard_height_unpadded = output_shard_height - output_height_padding; - uint32_t input_shard_width_bytes = input_shard_width * datum_size(untilized_cb_data_format); TT_ASSERT(input_shard_width % TILE_WIDTH == 0); @@ -457,77 +502,135 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr uint32_t input_cb_index = CB::c_in0; uint32_t num_input_tiles = num_input_tiles_in_row * num_rows_of_input_tiles; - tt_metal::CircularBufferConfig input_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}}) - .set_page_size(input_cb_index, input_single_tile_size); + tt_metal::CircularBufferConfig input_cb_config = + tt_metal::CircularBufferConfig( + num_input_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}}) + .set_page_size(input_cb_index, input_single_tile_size); input_cb_config = input_cb_config.set_globally_allocated_address(*a.buffer()); auto input_cb = tt_metal::CreateCircularBuffer(program, core_range, input_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", input_cb_index, input_single_tile_size, num_input_tiles, input_single_tile_size * num_input_tiles); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + input_cb_index, + input_single_tile_size, + num_input_tiles, + input_single_tile_size * num_input_tiles); // CB to store halo data // hardcode to store 1 row of tiles uint32_t halo_prev_input_cb_index = CB::c_in1; uint32_t halo_prev_input_cb_max_rows_of_tiles = 4; - uint32_t num_halo_prev_cb_input_tiles = num_input_tiles_in_row * halo_prev_input_cb_max_rows_of_tiles; - tt_metal::CircularBufferConfig halo_prev_input_cb_config = tt_metal::CircularBufferConfig(num_halo_prev_cb_input_tiles * input_single_tile_size, {{halo_prev_input_cb_index, input_cb_data_format}}) - .set_page_size(halo_prev_input_cb_index, input_single_tile_size); + uint32_t num_halo_prev_cb_input_tiles = num_input_tiles_in_row * halo_prev_input_cb_max_rows_of_tiles; + tt_metal::CircularBufferConfig halo_prev_input_cb_config = + tt_metal::CircularBufferConfig( + num_halo_prev_cb_input_tiles * input_single_tile_size, {{halo_prev_input_cb_index, input_cb_data_format}}) + .set_page_size(halo_prev_input_cb_index, input_single_tile_size); auto halo_prev_input_cb = tt_metal::CreateCircularBuffer(program, core_range, halo_prev_input_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", halo_prev_input_cb_index, input_single_tile_size, num_halo_prev_cb_input_tiles, input_single_tile_size * num_halo_prev_cb_input_tiles); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + halo_prev_input_cb_index, + input_single_tile_size, + num_halo_prev_cb_input_tiles, + input_single_tile_size * num_halo_prev_cb_input_tiles); uint32_t halo_next_input_cb_index = CB::c_in2; - uint32_t halo_next_input_cb_max_rows_of_tiles = 33; // TODO: Remove hardcoding - uint32_t num_halo_next_cb_input_tiles = num_input_tiles_in_row * halo_next_input_cb_max_rows_of_tiles; - tt_metal::CircularBufferConfig halo_next_input_cb_config = tt_metal::CircularBufferConfig(num_halo_next_cb_input_tiles * input_single_tile_size, {{halo_next_input_cb_index, input_cb_data_format}}) - .set_page_size(halo_next_input_cb_index, input_single_tile_size); + uint32_t halo_next_input_cb_max_rows_of_tiles = 33; // TODO: Remove hardcoding + uint32_t num_halo_next_cb_input_tiles = num_input_tiles_in_row * halo_next_input_cb_max_rows_of_tiles; + tt_metal::CircularBufferConfig halo_next_input_cb_config = + tt_metal::CircularBufferConfig( + num_halo_next_cb_input_tiles * input_single_tile_size, {{halo_next_input_cb_index, input_cb_data_format}}) + .set_page_size(halo_next_input_cb_index, input_single_tile_size); auto halo_next_input_cb = tt_metal::CreateCircularBuffer(program, core_range, halo_next_input_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", halo_next_input_cb_index, input_single_tile_size, num_halo_next_cb_input_tiles, input_single_tile_size * num_halo_next_cb_input_tiles); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + halo_next_input_cb_index, + input_single_tile_size, + num_halo_next_cb_input_tiles, + input_single_tile_size * num_halo_next_cb_input_tiles); // CB to store reader pattern array // read pattern array size == output_height uint32_t reader_pattern_array_size = output_shard_height; uint32_t reader_pattern_array_cb_index = CB::c_intermed1; - tt_metal::CircularBufferConfig reader_pattern_array_cb_config = tt_metal::CircularBufferConfig(reader_pattern_array_size * 4, {{reader_pattern_array_cb_index, DataFormat::Float16_b}}) - .set_page_size(reader_pattern_array_cb_index, 4); + tt_metal::CircularBufferConfig reader_pattern_array_cb_config = + tt_metal::CircularBufferConfig( + reader_pattern_array_size * 4, {{reader_pattern_array_cb_index, DataFormat::Float16_b}}) + .set_page_size(reader_pattern_array_cb_index, 4); auto reader_pattern_array_cb = tt_metal::CreateCircularBuffer(program, core_range, reader_pattern_array_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", reader_pattern_array_cb_index, 4, reader_pattern_array_size, 4 * reader_pattern_array_size); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + reader_pattern_array_cb_index, + 4, + reader_pattern_array_size, + 4 * reader_pattern_array_size); // untilized CB has size - [32, full width] uint32_t untilize_cb_index = CB::c_intermed2; uint32_t num_tiles_untilize_cb = num_input_tiles_in_row; - tt_metal::CircularBufferConfig untilize_cb_config = tt_metal::CircularBufferConfig(num_tiles_untilize_cb * untilized_single_tile_size, {{untilize_cb_index, untilized_cb_data_format}}) - .set_page_size(untilize_cb_index, untilized_single_tile_size); + tt_metal::CircularBufferConfig untilize_cb_config = + tt_metal::CircularBufferConfig( + num_tiles_untilize_cb * untilized_single_tile_size, {{untilize_cb_index, untilized_cb_data_format}}) + .set_page_size(untilize_cb_index, untilized_single_tile_size); auto untilize_cb = tt_metal::CreateCircularBuffer(program, core_range, untilize_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", untilize_cb_index, untilized_single_tile_size, num_tiles_untilize_cb, untilized_single_tile_size * num_tiles_untilize_cb); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + untilize_cb_index, + untilized_single_tile_size, + num_tiles_untilize_cb, + untilized_single_tile_size * num_tiles_untilize_cb); uint32_t num_output_tiles = num_output_tiles_in_row * num_rows_of_output_tiles; uint32_t untilize_downsampled_cb_index = CB::c_intermed3; - uint32_t num_tiles_untilize_downsampled_cb = num_output_tiles; // untilize downsampled cb size == output size per core - tt_metal::CircularBufferConfig untilize_downsampled_cb_config = tt_metal::CircularBufferConfig(num_tiles_untilize_downsampled_cb * untilized_single_tile_size, {{untilize_downsampled_cb_index, untilized_cb_data_format}}) - .set_page_size(untilize_downsampled_cb_index, untilized_single_tile_size); + uint32_t num_tiles_untilize_downsampled_cb = + num_output_tiles; // untilize downsampled cb size == output size per core + tt_metal::CircularBufferConfig untilize_downsampled_cb_config = + tt_metal::CircularBufferConfig( + num_tiles_untilize_downsampled_cb * untilized_single_tile_size, + {{untilize_downsampled_cb_index, untilized_cb_data_format}}) + .set_page_size(untilize_downsampled_cb_index, untilized_single_tile_size); auto untilize_downsampled_cb = tt_metal::CreateCircularBuffer(program, core_range, untilize_downsampled_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", untilize_downsampled_cb_index, untilized_single_tile_size, num_tiles_untilize_downsampled_cb, untilized_single_tile_size * num_tiles_untilize_downsampled_cb); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + untilize_downsampled_cb_index, + untilized_single_tile_size, + num_tiles_untilize_downsampled_cb, + untilized_single_tile_size * num_tiles_untilize_downsampled_cb); uint32_t final_tilize_output_cb_index = CB::c_out0; - uint32_t num_tiles_final_tilize_output_cb = num_output_tiles; // final output cb size == output size per core - tt_metal::CircularBufferConfig final_tilize_output_cb_config = tt_metal::CircularBufferConfig(num_tiles_final_tilize_output_cb * output_single_tile_size, {{final_tilize_output_cb_index, output_cb_data_format}}) - .set_page_size(final_tilize_output_cb_index, output_single_tile_size); + uint32_t num_tiles_final_tilize_output_cb = num_output_tiles; // final output cb size == output size per core + tt_metal::CircularBufferConfig final_tilize_output_cb_config = + tt_metal::CircularBufferConfig( + num_tiles_final_tilize_output_cb * output_single_tile_size, + {{final_tilize_output_cb_index, output_cb_data_format}}) + .set_page_size(final_tilize_output_cb_index, output_single_tile_size); final_tilize_output_cb_config = final_tilize_output_cb_config.set_globally_allocated_address(*output.buffer()); auto final_tilize_output_cb = tt_metal::CreateCircularBuffer(program, core_range, final_tilize_output_cb_config); - log_debug(LogOp, "CB {}: PS = {} NP = {} :: TOTAL = {}", final_tilize_output_cb_index, output_single_tile_size, num_tiles_final_tilize_output_cb, output_single_tile_size * num_tiles_final_tilize_output_cb); + log_debug( + LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + final_tilize_output_cb_index, + output_single_tile_size, + num_tiles_final_tilize_output_cb, + output_single_tile_size * num_tiles_final_tilize_output_cb); - uint32_t log_base_2_of_conv_act_size_c_bytes = (uint32_t) std::log2((float) input_shard_width_bytes); + uint32_t log_base_2_of_conv_act_size_c_bytes = (uint32_t)std::log2((float)input_shard_width_bytes); uint32_t stride_h_x_image_width = img_stride_h * img_width; std::vector writer_compile_time_args = { - (std::uint32_t) untilize_cb_index, - (std::uint32_t) untilize_downsampled_cb_index, - (std::uint32_t) final_tilize_output_cb_index, - (std::uint32_t) reader_pattern_array_cb_index, - (std::uint32_t) datum_size(untilized_cb_data_format), - (std::uint32_t) input_shard_width_bytes, - (std::uint32_t) halo_prev_input_cb_index, - (std::uint32_t) halo_next_input_cb_index, + (std::uint32_t)untilize_cb_index, + (std::uint32_t)untilize_downsampled_cb_index, + (std::uint32_t)final_tilize_output_cb_index, + (std::uint32_t)reader_pattern_array_cb_index, + (std::uint32_t)datum_size(untilized_cb_data_format), + (std::uint32_t)input_shard_width_bytes, + (std::uint32_t)halo_prev_input_cb_index, + (std::uint32_t)halo_next_input_cb_index, log_base_2_of_conv_act_size_c_bytes, - stride_h_x_image_width - }; + stride_h_x_image_width}; // Writer to downsample - drops rows from untilized cb tt_metal::KernelHandle downsample_writer_kernel_id = tt_metal::CreateKernel( @@ -549,22 +652,20 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr }; string compute_kernel = "tt_eager/tt_dnn/op_library/downsample/kernels/downsample_compute_kernel.cpp"; if (num_input_tiles_in_row <= MAX_PACK_UNTILIZE_WIDTH) { - compute_kernel = "tt_eager/tt_dnn/op_library/downsample/kernels/downsample_fast_pack_untilize_compute_kernel.cpp"; + compute_kernel = + "tt_eager/tt_dnn/op_library/downsample/kernels/downsample_fast_pack_untilize_compute_kernel.cpp"; } auto downsample_compute_kernel_id = tt_metal::CreateKernel( - program, - compute_kernel, - core_range, - tt_metal::ComputeConfig{.compile_args = compute_args} - ); + program, compute_kernel, core_range, tt_metal::ComputeConfig{.compile_args = compute_args}); // track img h, img w, across cores ImgTrackingVars v; - CoreCoord prev_core = {0,0}; + CoreCoord prev_core = {0, 0}; bool input_flat_h_is_of_current_core = true; + const auto& cores = corerange_to_cores(all_cores, std::nullopt, true); for (uint32_t i = 0; i < num_cores; i++) { - CoreCoord core = {i % num_cores_x, i / num_cores_x}; - //cout << "i=" << i << endl; + const CoreCoord& core = cores[i]; + // cout << "i=" << i << endl; uint32_t input_end_flat_h = input_shard_height - 1; uint32_t output_end_flat_h = output_shard_height - 1; @@ -617,15 +718,20 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr // halo region of previous core TT_ASSERT(i != 0); if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - TT_ASSERT(prev_core.y == core.y); // for block sharding case, prev core is left core + TT_ASSERT(prev_core.y == core.y); // for block sharding case, prev core is left core } halo_prev_read_enabled = true; TT_ASSERT(v.input_flat_h < input_shard_height); // get halo start tile address from height idx uint32_t halo_prev_start_tile_id_h = v.input_flat_h / TILE_HEIGHT; - TT_ASSERT(input_shard_height - v.input_flat_h <= TILE_HEIGHT * halo_prev_input_cb_max_rows_of_tiles); // halo input cb is hardcoded to store only 4 rows of tiles for now. TODO: allocate bigger CB or read in blocks + TT_ASSERT( + input_shard_height - v.input_flat_h <= + TILE_HEIGHT * + halo_prev_input_cb_max_rows_of_tiles); // halo input cb is hardcoded to store only 4 rows of tiles + // for now. TODO: allocate bigger CB or read in blocks // get halo size - halo_prev_size_bytes = (input_shard_height - (halo_prev_start_tile_id_h * TILE_HEIGHT)) * input_shard_width / TILE_HW * input_single_tile_size; + halo_prev_size_bytes = (input_shard_height - (halo_prev_start_tile_id_h * TILE_HEIGHT)) * + input_shard_width / TILE_HW * input_single_tile_size; TT_ASSERT(halo_prev_size_bytes % input_single_tile_size == 0); halo_prev_num_tiles = halo_prev_size_bytes / input_single_tile_size; TT_ASSERT(halo_prev_num_tiles <= num_halo_prev_cb_input_tiles); @@ -634,24 +740,35 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr halo_prev_addr_offset = num_input_tiles_in_row * halo_prev_start_tile_id_h * input_single_tile_size; halo_prev_start_addr = GetCircularBufferConfig(program, input_cb).globally_allocated_address().value(); - TT_ASSERT((halo_prev_start_addr + halo_prev_addr_offset) % 32 == 0); // read address should be 32 byte aligned + TT_ASSERT( + (halo_prev_start_addr + halo_prev_addr_offset) % 32 == 0); // read address should be 32 byte aligned auto halo_noc_coords = device->worker_core_from_logical_core(prev_core); halo_prev_noc_x = halo_noc_coords.x; halo_prev_noc_y = halo_noc_coords.y; TT_ASSERT(v.input_flat_h >= halo_prev_start_tile_id_h * TILE_HEIGHT); halo_prev_read_pattern_offset = v.input_flat_h - (halo_prev_start_tile_id_h * TILE_HEIGHT); local_read_pattern_offset = halo_prev_input_num_rows_of_tiles * TILE_HEIGHT; - halo_prev_read_pattern_params = generate_downsample_read_pattern(v, img_height, img_width, img_stride_h, img_stride_w, input_end_flat_h, output_end_flat_h, true, false); + halo_prev_read_pattern_params = generate_downsample_read_pattern( + v, img_height, img_width, img_stride_h, img_stride_w, input_end_flat_h, output_end_flat_h, true, false); } // local core TT_ASSERT(v.output_flat_h < output_shard_height); uint32_t local_start_h = v.input_flat_h; - DownsampleReadPatternParams local_read_pattern_params = generate_downsample_read_pattern(v, img_height, img_width, img_stride_h, img_stride_w, current_core_input_end_flat_h, output_end_flat_h, false, false); + DownsampleReadPatternParams local_read_pattern_params = generate_downsample_read_pattern( + v, + img_height, + img_width, + img_stride_h, + img_stride_w, + current_core_input_end_flat_h, + output_end_flat_h, + false, + false); TT_ASSERT(v.output_flat_h <= output_shard_height); uint32_t local_end_h_exclusive = v.input_flat_h == 0 ? input_shard_height : v.input_flat_h; uint32_t local_num_rows = local_end_h_exclusive - local_start_h; TT_ASSERT(local_num_rows > 0); - uint32_t local_input_num_rows_of_tiles = std::ceil( (double) local_num_rows / (double) TILE_HEIGHT); + uint32_t local_input_num_rows_of_tiles = std::ceil((double)local_num_rows / (double)TILE_HEIGHT); uint32_t local_input_offset_rows_of_tiles = local_start_h / TILE_HEIGHT; if (local_start_h != 0) { TT_ASSERT(local_read_pattern_offset == 0); @@ -660,7 +777,7 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr if (v.input_flat_h != 0) { input_flat_h_is_of_current_core = false; } else { - input_flat_h_is_of_current_core = true; // updating flag for next core + input_flat_h_is_of_current_core = true; // updating flag for next core } TT_ASSERT(local_input_num_rows_of_tiles <= num_rows_of_input_tiles); @@ -668,18 +785,33 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr // need to read halo from next core TT_ASSERT(i != num_cores - 1); TT_ASSERT(v.input_flat_h == 0); - CoreCoord next_core = {(i+1)% num_cores_x, (i+1) / num_cores_x}; + const CoreCoord& next_core = cores[i + 1]; if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - TT_ASSERT(next_core.y == core.y); // for block sharding case, next core is right core + TT_ASSERT(next_core.y == core.y); // for block sharding case, next core is right core } halo_next_read_enabled = true; - halo_next_read_pattern_params = generate_downsample_read_pattern(v, img_height, img_width, img_stride_h, img_stride_w, next_core_input_end_flat_h, output_end_flat_h, false, true); + halo_next_read_pattern_params = generate_downsample_read_pattern( + v, + img_height, + img_width, + img_stride_h, + img_stride_w, + next_core_input_end_flat_h, + output_end_flat_h, + false, + true); TT_ASSERT(v.output_flat_h == 0); TT_ASSERT(v.input_flat_h != 0 && v.input_flat_h < input_shard_height); - TT_ASSERT(v.input_flat_h <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles, "v.input_flat_h ({}) should be <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles ({})", v.input_flat_h, halo_next_input_cb_max_rows_of_tiles); // halo next input cb is hardcoded to store only 5 rows of tiles for now. TODO: allocate bigger CB or read in blocks + TT_ASSERT( + v.input_flat_h <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles, + "v.input_flat_h ({}) should be <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles ({})", + v.input_flat_h, + halo_next_input_cb_max_rows_of_tiles); // halo next input cb is hardcoded to store only 5 rows of tiles + // for now. TODO: allocate bigger CB or read in blocks uint32_t halo_next_end_tile_id_h = v.input_flat_h / TILE_HEIGHT; // get halo size - halo_next_size_bytes = (halo_next_end_tile_id_h+1) * TILE_HEIGHT * input_shard_width / TILE_HW * input_single_tile_size; + halo_next_size_bytes = + (halo_next_end_tile_id_h + 1) * TILE_HEIGHT * input_shard_width / TILE_HW * input_single_tile_size; TT_ASSERT(halo_next_size_bytes % input_single_tile_size == 0); halo_next_num_tiles = halo_next_size_bytes / input_single_tile_size; TT_ASSERT(halo_next_num_tiles <= num_halo_next_cb_input_tiles); @@ -687,7 +819,8 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr halo_next_input_num_rows_of_tiles = halo_next_num_tiles / num_input_tiles_in_row; halo_next_addr_offset = 0; halo_next_start_addr = GetCircularBufferConfig(program, input_cb).globally_allocated_address().value(); - TT_ASSERT((halo_next_start_addr + halo_next_addr_offset) % 32 == 0); // read address should be 32 byte aligned + TT_ASSERT( + (halo_next_start_addr + halo_next_addr_offset) % 32 == 0); // read address should be 32 byte aligned auto halo_noc_coords = device->worker_core_from_logical_core(next_core); halo_next_noc_x = halo_noc_coords.x; halo_next_noc_y = halo_noc_coords.y; @@ -706,19 +839,14 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr halo_next_input_num_rows_of_tiles, }; - tt_metal::SetRuntimeArgs( - program, - downsample_compute_kernel_id, - core, - compile_rt_kernel_args - ); + tt_metal::SetRuntimeArgs(program, downsample_compute_kernel_id, core, compile_rt_kernel_args); // Writer runtime args vector writer_kernel_args = { - (uint32_t) img_height, - (uint32_t) img_width, - (uint32_t) img_stride_h, - (uint32_t) img_stride_w, + (uint32_t)img_height, + (uint32_t)img_width, + (uint32_t)img_stride_h, + (uint32_t)img_stride_w, // halo prev args halo_prev_read_enabled, @@ -784,46 +912,36 @@ operation::ProgramWithCallbacks downsample_single_core(const Tensor &a, std::arr num_input_tiles_in_row, num_output_tiles, - (uint32_t) false - }; + (uint32_t) false}; - tt_metal::SetRuntimeArgs( - program, - downsample_writer_kernel_id, - core, - writer_kernel_args - ); + tt_metal::SetRuntimeArgs(program, downsample_writer_kernel_id, core, writer_kernel_args); prev_core = core; } - auto override_runtime_args_callback = [ - input_cb=input_cb, - final_tilize_output_cb=final_tilize_output_cb, - downsample_writer_kernel_id=downsample_writer_kernel_id, - num_cores=num_cores, - num_cores_x=num_cores_x - ]( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors - ) { - + auto override_runtime_args_callback = [input_cb = input_cb, + final_tilize_output_cb = final_tilize_output_cb, + downsample_writer_kernel_id = downsample_writer_kernel_id, + cores = cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { auto src_buffer = input_tensors.at(0).buffer(); auto dst_buffer = output_tensors.at(0).buffer(); UpdateDynamicCircularBufferAddress(program, input_cb, *src_buffer); UpdateDynamicCircularBufferAddress(program, final_tilize_output_cb, *dst_buffer); - for (uint32_t i = 0; i < num_cores; i++) { - CoreCoord core = {i % num_cores_x, i / num_cores_x}; - auto &runtime_args = GetRuntimeArgs(program, downsample_writer_kernel_id, core); + + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, downsample_writer_kernel_id); + for (const auto& core : cores) { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; runtime_args[8] = src_buffer->address(); runtime_args[39] = src_buffer->address(); } }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/move/multi_core/move_op_multi_core_sharded.cpp b/tt_eager/tt_dnn/op_library/move/multi_core/move_op_multi_core_sharded.cpp index 6affad184d9..cf44c6f136f 100644 --- a/tt_eager/tt_dnn/op_library/move/multi_core/move_op_multi_core_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/move/multi_core/move_op_multi_core_sharded.cpp @@ -4,13 +4,12 @@ #include +#include "tt_dnn/op_library/math.hpp" #include "tt_dnn/op_library/move/move_op.hpp" #include "tt_dnn/op_library/work_split.hpp" -#include "tt_dnn/op_library/math.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/detail/util.hpp" #include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; @@ -19,7 +18,7 @@ namespace tt { namespace tt_metal { // Sharded buffers are mapped to CBs. Move from top of src CB to dst CB -operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor &input, Tensor &output) { +operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor& input, Tensor& output) { tt_metal::Program program{}; tt::DataFormat cb_data_format = datatype_to_dataformat_converter(input.get_dtype()); @@ -29,7 +28,9 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor &input, Ten auto input_shape = input.get_legacy_shape(); auto input_dtype = input.get_dtype(); auto input_layout = input.get_layout(); - TT_FATAL(input_layout == output.get_layout() && input_dtype == output.get_dtype() && shard_shape == output.shard_spec().value().shape && input_shape == output.get_legacy_shape()); + TT_FATAL( + input_layout == output.get_layout() && input_dtype == output.get_dtype() && + shard_shape == output.shard_spec().value().shape && input_shape == output.get_legacy_shape()); const uint32_t src_cb_sharded = CB::c_in0; const uint32_t dst_cb_sharded = CB::c_in1; uint32_t tile_size_bytes = tile_size(cb_data_format); @@ -41,25 +42,32 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor &input, Ten total_size_bytes = shard_shape_num_tiles * tile_size_bytes; page_size_bytes = tile_size_bytes; } else { - uint32_t datum_size_bytes = datum_size(cb_data_format); + uint32_t datum_size_bytes = datum_size(cb_data_format); total_size_bytes = shard_shape[0] * shard_shape[1] * datum_size_bytes; - page_size_bytes = shard_shape[1] * datum_size_bytes; + page_size_bytes = shard_shape[1] * datum_size_bytes; } - CircularBufferConfig src_cb_sharded_config = CircularBufferConfig(total_size_bytes, {{src_cb_sharded, cb_data_format}}) - .set_page_size(src_cb_sharded, page_size_bytes); + CircularBufferConfig src_cb_sharded_config = + CircularBufferConfig(total_size_bytes, {{src_cb_sharded, cb_data_format}}) + .set_page_size(src_cb_sharded, page_size_bytes); src_cb_sharded_config.set_globally_allocated_address(*input.buffer()); auto src_sharded_cb = tt_metal::CreateCircularBuffer(program, shard_grid, src_cb_sharded_config); - CircularBufferConfig dst_cb_sharded_config = CircularBufferConfig(total_size_bytes, {{dst_cb_sharded, cb_data_format}}) - .set_page_size(dst_cb_sharded, page_size_bytes); + CircularBufferConfig dst_cb_sharded_config = + CircularBufferConfig(total_size_bytes, {{dst_cb_sharded, cb_data_format}}) + .set_page_size(dst_cb_sharded, page_size_bytes); dst_cb_sharded_config.set_globally_allocated_address(*output.buffer()); auto dst_sharded_cb = tt_metal::CreateCircularBuffer(program, shard_grid, dst_cb_sharded_config); auto input_buffer_address = input.buffer()->address(); auto output_buffer_address = output.buffer()->address(); - TT_FATAL(output_buffer_address > input_buffer_address, "Expected output buffer to be allocated at a higher address than input buffer"); + TT_FATAL( + output_buffer_address > input_buffer_address, + "Expected output buffer to be allocated at a higher address than input buffer"); uint32_t move_chunk_size_bytes = output_buffer_address - input_buffer_address; - TT_FATAL(move_chunk_size_bytes % ADDRESS_ALIGNMENT == 0, "Expected chunk size bytes to move to be {} byte aligned.", ADDRESS_ALIGNMENT); + TT_FATAL( + move_chunk_size_bytes % ADDRESS_ALIGNMENT == 0, + "Expected chunk size bytes to move to be {} byte aligned.", + ADDRESS_ALIGNMENT); uint32_t num_chunks = total_size_bytes / move_chunk_size_bytes; uint32_t remainder_chunk_size_bytes = total_size_bytes % move_chunk_size_bytes; @@ -68,20 +76,24 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor &input, Ten program, "tt_eager/tt_dnn/op_library/move/kernels/dataflow/reader_unary_local_l1_copy_backwards.cpp", shard_grid, - DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, - .noc = NOC::NOC_1, - .compile_args = reader_compile_time_args} - ); - std::vector runtime_args = {total_size_bytes, num_chunks, move_chunk_size_bytes, remainder_chunk_size_bytes}; + DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_1, .noc = NOC::NOC_1, .compile_args = reader_compile_time_args}); + std::vector runtime_args = { + total_size_bytes, num_chunks, move_chunk_size_bytes, remainder_chunk_size_bytes}; SetRuntimeArgs(program, kernel_id, shard_grid, runtime_args); - auto override_runtime_args_callback = [shard_grid=shard_grid, kernel_id=kernel_id, src_sharded_cb=src_sharded_cb, dst_sharded_cb=dst_sharded_cb, total_size_bytes=total_size_bytes]( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors - ) { + const auto& cores = corerange_to_cores(shard_grid, std::nullopt, true); + auto override_runtime_args_callback = [shard_grid = shard_grid, + kernel_id = kernel_id, + src_sharded_cb = src_sharded_cb, + dst_sharded_cb = dst_sharded_cb, + total_size_bytes = total_size_bytes, + cores = cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { auto src_buffer = input_tensors.at(0).buffer(); auto dst_buffer = output_tensors.at(0).buffer(); UpdateDynamicCircularBufferAddress(program, src_sharded_cb, *src_buffer); @@ -91,11 +103,16 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor &input, Ten uint32_t move_chunk_size_bytes = output_buffer_address - input_buffer_address; uint32_t num_chunks = total_size_bytes / move_chunk_size_bytes; uint32_t remainder_chunk_size_bytes = total_size_bytes % move_chunk_size_bytes; - std::vector runtime_args = {total_size_bytes, num_chunks, move_chunk_size_bytes, remainder_chunk_size_bytes}; - SetRuntimeArgs(program, kernel_id, shard_grid, runtime_args); + std::vector new_runtime_args = { + total_size_bytes, num_chunks, move_chunk_size_bytes, remainder_chunk_size_bytes}; + auto& runtime_args_by_core = GetRuntimeArgs(program, kernel_id); + for (const auto& core : cores) { + auto& runtime_args = runtime_args_by_core[core.x][core.y]; + std::copy(new_runtime_args.begin(), new_runtime_args.end(), runtime_args.data()); + } }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp b/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp index 289ffcbc19d..a034907e652 100644 --- a/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp @@ -11,9 +11,8 @@ #include "tt_dnn/op_library/pool/max_pool.hpp" #include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils #include "tt_dnn/op_library/sharding_utilities.hpp" -#include "tt_dnn/op_library/sliding_window_op_infra/utils.hpp" -#include "tt_dnn/op_library/work_split.hpp" #include "tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" +#include "tt_dnn/op_library/sliding_window_op_infra/utils.hpp" #include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" @@ -41,7 +40,8 @@ CoreCoord get_ncores_hw(uint32_t h, uint32_t w, uint32_t avail_cores_h, uint32_t return cores_shape; } -std::tuple get_decomposition_h(uint32_t out_h, uint32_t ncores_h, uint32_t ncores_w) { +std::tuple get_decomposition_h( + uint32_t out_h, uint32_t ncores_h, uint32_t ncores_w) { uint32_t out_h_per_core = out_h / (ncores_h * ncores_w); uint32_t out_h_per_core_cliff = out_h % (ncores_h * ncores_w); std::set core_range, core_range_cliff; @@ -53,7 +53,8 @@ std::tuple get_decomp core_range.insert(CoreRange(CoreCoord(0, 0), CoreCoord(ncores_w - 2, ncores_h - 1))); // last row but last core, only the last core is cliff (1D, not 2D) core_range.insert(CoreRange(CoreCoord(0, ncores_h - 1), CoreCoord(ncores_w - 2, ncores_h - 1))); - core_range_cliff.insert(CoreRange(CoreCoord(ncores_w - 1, ncores_h - 1), CoreCoord(ncores_w - 1, ncores_h - 1))); + core_range_cliff.insert( + CoreRange(CoreCoord(ncores_w - 1, ncores_h - 1), CoreCoord(ncores_w - 1, ncores_h - 1))); } CoreRange all_cores(CoreCoord(0, 0), CoreCoord(ncores_w - 1, ncores_h - 1)); return std::make_tuple(all_cores, core_range, core_range_cliff, out_h_per_core, out_h_per_core_cliff); @@ -70,27 +71,27 @@ uint32_t get_num_cores(const Device* device, uint32_t out_nhw, uint32_t nbatch) case 1024: // test case ncores = 32; break; - case 2048: // test case - case 4096: // test case - case 8192: // test case + case 2048: // test case + case 4096: // test case + case 8192: // test case case 16384: // test case case 32768: // test case ncores = 64; break; - case 3136: // nbatch = 1 - case 6272: // nbatch = 2 - case 12544: // nbatch = 4 - case 25088: // nbatch = 8 - case 50176: // nbatch = 16 - case 62720: // nbatch = 20 + case 3136: // nbatch = 1 + case 6272: // nbatch = 2 + case 12544: // nbatch = 4 + case 25088: // nbatch = 8 + case 50176: // nbatch = 16 + case 62720: // nbatch = 20 ncores = 98; break; - case 784: // test case + case 784: // test case ncores = 49; break; default: // TT_ASSERT(false, "General case is not yet handled! Only RN50 shapes supported in multicore."); - uint32_t out_nhw_per_core = (uint32_t) ceil((float) out_nhw / avail_ncores); + uint32_t out_nhw_per_core = (uint32_t)ceil((float)out_nhw / avail_ncores); ncores = out_nhw / out_nhw_per_core; while (avail_ncores > 0) { if (out_nhw % avail_ncores == 0 && (out_nhw / avail_ncores) % TILE_HEIGHT == 0) { @@ -99,11 +100,11 @@ uint32_t get_num_cores(const Device* device, uint32_t out_nhw, uint32_t nbatch) } --avail_ncores; } - ncores = std::max(avail_ncores, (uint32_t) 1); + ncores = std::max(avail_ncores, (uint32_t)1); break; } } else if (device->arch() == ARCH::WORMHOLE_B0) { - uint32_t out_nhw_per_core = (uint32_t) ceil((float) out_nhw / avail_ncores); + uint32_t out_nhw_per_core = (uint32_t)ceil((float)out_nhw / avail_ncores); ncores = out_nhw / out_nhw_per_core; while (avail_ncores > 0) { if (out_nhw % avail_ncores == 0 && (out_nhw / avail_ncores) % TILE_HEIGHT == 0) { @@ -112,11 +113,12 @@ uint32_t get_num_cores(const Device* device, uint32_t out_nhw, uint32_t nbatch) } --avail_ncores; } - ncores = std::max(avail_ncores, (uint32_t) 1); + ncores = std::max(avail_ncores, (uint32_t)1); } else { TT_THROW("Unsupported device arch: {}", device->arch()); } - if (ncores == 0) TT_THROW("ncores = 0!"); + if (ncores == 0) + TT_THROW("ncores = 0!"); return ncores; } @@ -135,10 +137,11 @@ get_decomposition_nhw(const Device* device, uint32_t in_nhw, uint32_t out_nhw, u out_nhw_per_core = out_nhw / ncores; in_nhw_per_core = in_nhw / ncores; - uint32_t ncores_w = grid_size.x; // 12 + uint32_t ncores_w = grid_size.x; // 12 uint32_t ncores_h = ncores / ncores_w; uint32_t ncores_cliff_h = 0; - if (ncores % ncores_w != 0) ncores_cliff_h = 1; + if (ncores % ncores_w != 0) + ncores_cliff_h = 1; uint32_t ncores_cliff_w = ncores % ncores_w; // NOTE: Cliff core is not yet handled, assuming (out_nhw / ncores) is a whole number. uint32_t in_nhw_per_core_cliff = 0; @@ -153,27 +156,43 @@ get_decomposition_nhw(const Device* device, uint32_t in_nhw, uint32_t out_nhw, u all_cores.insert(CoreRange(CoreCoord(0, ncores_h), CoreCoord(ncores_cliff_w - 1, ncores_h))); } - return std::make_tuple(ncores, all_cores, core_range, core_range_cliff, in_nhw_per_core, in_nhw_per_core_cliff, out_nhw_per_core, out_nhw_per_core_cliff); + return std::make_tuple( + ncores, + all_cores, + core_range, + core_range_cliff, + in_nhw_per_core, + in_nhw_per_core_cliff, + out_nhw_per_core, + out_nhw_per_core_cliff); } -} // namespacce max_pool_helpers +} // namespace max_pool_helpers // this version uses distribution along height = N * H * W -operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &input, Tensor& output, - uint32_t in_h, uint32_t in_w, - uint32_t out_h, uint32_t out_w, - uint32_t kernel_size_h, uint32_t kernel_size_w, - uint32_t stride_h, uint32_t stride_w, - uint32_t pad_h, uint32_t pad_w, - uint32_t dilation_h, uint32_t dilation_w, - const MemoryConfig& out_mem_config, - uint32_t nblocks) { +operation::ProgramWithCallbacks max_pool_2d_multi_core_generic( + const Tensor& input, + Tensor& output, + uint32_t in_h, + uint32_t in_w, + uint32_t out_h, + uint32_t out_w, + uint32_t kernel_size_h, + uint32_t kernel_size_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t pad_h, + uint32_t pad_w, + uint32_t dilation_h, + uint32_t dilation_w, + const MemoryConfig& out_mem_config, + uint32_t nblocks) { Program program = CreateProgram(); // This should allocate a DRAM buffer on the device - Device *device = input.device(); - Buffer *src_dram_buffer = input.buffer(); - Buffer *dst_dram_buffer = output.buffer(); + Device* device = input.device(); + Buffer* src_dram_buffer = input.buffer(); + Buffer* dst_dram_buffer = output.buffer(); Shape input_shape = input.get_legacy_shape(); Shape output_shape = output.get_legacy_shape(); @@ -185,24 +204,25 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp DataFormat out_df = datatype_to_dataformat_converter(output.get_dtype()); uint32_t in_nbytes = datum_size(in_df); uint32_t out_nbytes = datum_size(out_df); - uint32_t in_nbytes_c = input_shape[3] * in_nbytes; // row of input (channels) - uint32_t out_nbytes_c = output_shape[3] * out_nbytes; // row of output (channels) - TT_ASSERT((in_nbytes_c & (in_nbytes_c - 1)) == 0, "in_nbytes_c should be power of 2"); // in_nbytes_c is power of 2 - TT_ASSERT((out_nbytes_c & (out_nbytes_c - 1)) == 0, "out_nbytes_c should be power of 2"); // out_nbytes_c is power of 2 + uint32_t in_nbytes_c = input_shape[3] * in_nbytes; // row of input (channels) + uint32_t out_nbytes_c = output_shape[3] * out_nbytes; // row of output (channels) + TT_ASSERT((in_nbytes_c & (in_nbytes_c - 1)) == 0, "in_nbytes_c should be power of 2"); // in_nbytes_c is power of 2 + TT_ASSERT( + (out_nbytes_c & (out_nbytes_c - 1)) == 0, "out_nbytes_c should be power of 2"); // out_nbytes_c is power of 2 uint32_t nbatch = input_shape[0]; TT_ASSERT(nbatch == output_shape[0], "Mismatch in N for input and output!!"); - uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read + uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read uint32_t kernel_size_hw_padded = ceil_multiple_of(kernel_size_hw, constants::TILE_HEIGHT); - uint32_t in_ntiles_hw = (uint32_t) ceil((float) kernel_size_hw_padded / constants::TILE_HEIGHT); - uint32_t in_ntiles_c = (uint32_t) ceil((float) input_shape[3] / constants::TILE_WIDTH); - uint32_t out_ntiles_hw = (uint32_t) ceil((float) output_shape[2] / constants::TILE_HEIGHT); - uint32_t out_ntiles_c = (uint32_t) ceil((float) output_shape[3] / constants::TILE_WIDTH); + uint32_t in_ntiles_hw = (uint32_t)ceil((float)kernel_size_hw_padded / constants::TILE_HEIGHT); + uint32_t in_ntiles_c = (uint32_t)ceil((float)input_shape[3] / constants::TILE_WIDTH); + uint32_t out_ntiles_hw = (uint32_t)ceil((float)output_shape[2] / constants::TILE_HEIGHT); + uint32_t out_ntiles_c = (uint32_t)ceil((float)output_shape[3] / constants::TILE_WIDTH); uint32_t out_nelems = nblocks; // TODO [AS]: Remove hard coding after identifying optimal param val // Also ensure the calculated ncores is good - uint32_t out_w_loop_count = ceil((float) out_w / out_nelems); + uint32_t out_w_loop_count = ceil((float)out_w / out_nelems); uint32_t in_hw = in_h * in_w; uint32_t in_nhw = in_hw * nbatch; @@ -211,7 +231,15 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp // distributing out_hw across the grid auto grid_size = device->compute_with_storage_grid_size(); - auto [ncores, all_cores, core_range, core_range_cliff, in_nhw_per_core, in_nhw_per_core_cliff, out_nhw_per_core, out_nhw_per_core_cliff] = max_pool_helpers::get_decomposition_nhw(device, in_nhw, out_nhw, nbatch); + auto + [ncores, + all_cores, + core_range, + core_range_cliff, + in_nhw_per_core, + in_nhw_per_core_cliff, + out_nhw_per_core, + out_nhw_per_core_cliff] = max_pool_helpers::get_decomposition_nhw(device, in_nhw, out_nhw, nbatch); if (input.memory_config().is_sharded()) { all_cores = input.shard_spec().value().grid; uint32_t ncores = all_cores.num_cores(); @@ -225,11 +253,16 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp uint32_t ncores_w = grid_size.x; // TODO: support generic nblocks - TT_ASSERT(out_nhw_per_core % nblocks == 0, "number of sticks per core ({}) should be divisible by nblocks ({})", out_nhw_per_core, nblocks); + TT_ASSERT( + out_nhw_per_core % nblocks == 0, + "number of sticks per core ({}) should be divisible by nblocks ({})", + out_nhw_per_core, + nblocks); // TODO: support generic values for in_nhw_per_core - TT_ASSERT((in_nhw_per_core & (in_nhw_per_core - 1)) == 0, "in_nhw_per_core {} needs to be power of 2!", in_nhw_per_core); + TT_ASSERT( + (in_nhw_per_core & (in_nhw_per_core - 1)) == 0, "in_nhw_per_core {} needs to be power of 2!", in_nhw_per_core); - uint32_t in_nhw_per_core_rem_mask = in_nhw_per_core - 1; // NOTE: assuming in_nhw_per_core is power of 2 + uint32_t in_nhw_per_core_rem_mask = in_nhw_per_core - 1; // NOTE: assuming in_nhw_per_core is power of 2 // CBs uint32_t multi_buffering_factor = 2; @@ -238,10 +271,9 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp uint32_t in_scalar_cb_id = CB::c_in4; uint32_t in_scalar_cb_pagesize = tile_size(in_df); uint32_t in_scalar_cb_npages = 1; - CircularBufferConfig in_scalar_cb_config = CircularBufferConfig( - in_scalar_cb_npages * in_scalar_cb_pagesize, - {{in_scalar_cb_id, in_df}}) - .set_page_size(in_scalar_cb_id, in_scalar_cb_pagesize); + CircularBufferConfig in_scalar_cb_config = + CircularBufferConfig(in_scalar_cb_npages * in_scalar_cb_pagesize, {{in_scalar_cb_id, in_df}}) + .set_page_size(in_scalar_cb_id, in_scalar_cb_pagesize); auto in_scalar_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_scalar_cb_config); CBHandle raw_in_cb = 0; @@ -250,64 +282,70 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp auto raw_in_cb_id = CB::c_in2; uint32_t raw_in_cb_npages = in_nhw_per_core; uint32_t raw_in_cb_pagesize = in_nbytes_c; - CircularBufferConfig raw_in_cb_config = CircularBufferConfig( - raw_in_cb_npages * raw_in_cb_pagesize, - {{raw_in_cb_id, in_df}}) - .set_page_size(raw_in_cb_id, raw_in_cb_pagesize) - .set_globally_allocated_address(*input.buffer()); + CircularBufferConfig raw_in_cb_config = + CircularBufferConfig(raw_in_cb_npages * raw_in_cb_pagesize, {{raw_in_cb_id, in_df}}) + .set_page_size(raw_in_cb_id, raw_in_cb_pagesize) + .set_globally_allocated_address(*input.buffer()); raw_in_cb = CreateCircularBuffer(program, all_cores, raw_in_cb_config); } // reader output == input to tilize - uint32_t in_cb_id = CB::c_in0; // input rows for "multiple (out_nelems)" output pixels - uint32_t in_cb_page_nelems_padded = ceil_multiple_of(input_shape[3] * kernel_size_hw_padded, constants::TILE_HW); // NOTE: ceil to tile size since triscs work with tilesize instead of pagesize + uint32_t in_cb_id = CB::c_in0; // input rows for "multiple (out_nelems)" output pixels + uint32_t in_cb_page_nelems_padded = ceil_multiple_of( + input_shape[3] * kernel_size_hw_padded, + constants::TILE_HW); // NOTE: ceil to tile size since triscs work with tilesize instead of pagesize uint32_t in_cb_pagesize = in_nbytes * in_cb_page_nelems_padded; uint32_t in_cb_npages = multi_buffering_factor * out_nelems; CircularBufferConfig in_cb_config = CircularBufferConfig(in_cb_npages * in_cb_pagesize, {{in_cb_id, in_df}}) - .set_page_size(in_cb_id, in_cb_pagesize); + .set_page_size(in_cb_id, in_cb_pagesize); auto in_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_cb_config); // output of tilize == input to reduce uint32_t in_tiled_cb_id = CB::c_intermed0; // tiled input uint32_t in_tiled_cb_pagesize = tile_size(in_df); uint32_t in_tiled_cb_npages = in_ntiles_c * in_ntiles_hw * out_nelems; - CircularBufferConfig in_tiled_cb_config = CircularBufferConfig(in_tiled_cb_npages * in_tiled_cb_pagesize, {{in_tiled_cb_id, in_df}}) - .set_page_size(in_tiled_cb_id, in_tiled_cb_pagesize); + CircularBufferConfig in_tiled_cb_config = + CircularBufferConfig(in_tiled_cb_npages * in_tiled_cb_pagesize, {{in_tiled_cb_id, in_df}}) + .set_page_size(in_tiled_cb_id, in_tiled_cb_pagesize); auto in_tiled_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_tiled_cb_config); // output of reduce == writer to write - uint32_t out_cb_id = CB::c_out0; // output rows in RM - uint32_t out_cb_pagesize = tile_size(out_df) * out_ntiles_c * out_nelems; + uint32_t out_cb_id = CB::c_out0; // output rows in RM + uint32_t out_cb_pagesize = tile_size(out_df) * out_ntiles_c * out_nelems; uint32_t out_cb_npages = multi_buffering_factor; CircularBufferConfig cb_out_config = CircularBufferConfig(out_cb_npages * out_cb_pagesize, {{out_cb_id, out_df}}) - .set_page_size(out_cb_id, out_cb_pagesize); + .set_page_size(out_cb_id, out_cb_pagesize); auto cb_out = tt_metal::CreateCircularBuffer(program, all_cores, cb_out_config); CBHandle cb_sharded_out = 0; if (output.memory_config().is_sharded()) { - uint32_t sharded_out_cb_id = CB::c_out1; // output rows in RM + uint32_t sharded_out_cb_id = CB::c_out1; // output rows in RM uint32_t sharded_out_num_pages = output.shard_spec().value().shape[0]; - uint32_t sharded_out_cb_page_size = output.shard_spec().value().shape[1] * out_nbytes; // there is just one row of channels after reduction - CircularBufferConfig cb_sharded_out_config = CircularBufferConfig(sharded_out_num_pages * sharded_out_cb_page_size, {{sharded_out_cb_id, out_df}}) - .set_page_size(sharded_out_cb_id, sharded_out_cb_page_size).set_globally_allocated_address(*output.buffer()); + uint32_t sharded_out_cb_page_size = + output.shard_spec().value().shape[1] * out_nbytes; // there is just one row of channels after reduction + CircularBufferConfig cb_sharded_out_config = + CircularBufferConfig(sharded_out_num_pages * sharded_out_cb_page_size, {{sharded_out_cb_id, out_df}}) + .set_page_size(sharded_out_cb_id, sharded_out_cb_page_size) + .set_globally_allocated_address(*output.buffer()); cb_sharded_out = tt_metal::CreateCircularBuffer(program, all_cores, cb_sharded_out_config); } // Construct const buffer with -INF // uint32_t const_buffer_size = 32; - uint32_t const_buffer_size = input_shape[3]; // set it equal to 1 row + uint32_t const_buffer_size = input_shape[3]; // set it equal to 1 row auto minus_inf_const_buffer = owned_buffer::create(std::vector(const_buffer_size, bfloat16(0xf7ff))); - const Tensor minus_inf_const_tensor = Tensor(OwnedStorage{minus_inf_const_buffer}, - Shape({1, 1, 1, const_buffer_size}), - DataType::BFLOAT16, - Layout::ROW_MAJOR) - .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, - .buffer_type = BufferType::L1}); + const Tensor minus_inf_const_tensor = + Tensor( + OwnedStorage{minus_inf_const_buffer}, + Shape({1, 1, 1, const_buffer_size}), + DataType::BFLOAT16, + Layout::ROW_MAJOR) + .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); auto minus_inf_const_tensor_addr = minus_inf_const_tensor.buffer()->address(); - #if 0 +#if 0 { // debug log_debug(LogOp, "in_cb :: PS = {}, NP = {}", in_cb_pagesize, in_cb_npages); log_debug(LogOp, "in_scalar_cb :: PS = {}, NP = {}", in_scalar_cb_pagesize, in_scalar_cb_npages); @@ -350,14 +388,18 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp log_debug(LogOp, "is_in_sharded: {}", input.memory_config().is_sharded()); log_debug(LogOp, "is_out_sharded: {}", output.memory_config().is_sharded()); } - #endif +#endif const uint32_t reader_noc = 0; const uint32_t writer_noc = 1; std::map left_neighbor_core, right_neighbor_core; if (input.memory_config().is_sharded()) { - utils::init_neighbor_core_xy_mapping(grid_size, left_neighbor_core, right_neighbor_core, input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + utils::init_neighbor_core_xy_mapping( + grid_size, + left_neighbor_core, + right_neighbor_core, + input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); } /** @@ -365,81 +407,97 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp */ float one = 1.; uint32_t bf16_one_u32 = *reinterpret_cast(&one); - std::vector reader_ct_args = {input.memory_config().buffer_type == BufferType::DRAM ? (uint) 1 : (uint) 0, - out_mem_config.buffer_type == BufferType::DRAM ? (uint) 1 : (uint) 0, - bf16_one_u32, - out_nelems, - static_cast(((in_nbytes_c & (in_nbytes_c - 1)) == 0) ? 1 : 0), // is in_nbytes_c power of 2 - stride_h, - stride_w, - reader_noc, - writer_noc}; - uint32_t in_log_base_2_of_page_size = (uint32_t) std::log2((float) in_nbytes_c); - std::vector reader_rt_args = {src_dram_buffer->address(), - dst_dram_buffer->address(), - kernel_size_h, kernel_size_w, kernel_size_hw, kernel_size_hw_padded, - stride_h, stride_w, - pad_h, pad_w, - out_h, out_w, output_shape[2], output_shape[3], - in_nbytes_c, out_nbytes_c, - in_h, in_w, input_shape[2], input_shape[3], - out_ntiles_hw, out_ntiles_c, - in_cb_pagesize, out_cb_pagesize, - in_cb_page_nelems_padded, out_w_loop_count, - in_log_base_2_of_page_size, - nbatch, - in_hw, - out_hw, - // these are set later in the following - 0, // start_out_h_i - 0, // end_out_h_i - 0, // base_start_h - 0, // start_out_row_id - minus_inf_const_tensor_addr, - const_buffer_size * in_nbytes, - (in_cb_page_nelems_padded * out_nelems * 2) >> 5, // TODO: generalize num rows to fill in in_cb - 0, // core_offset_in_row_id - 0, // core_out_w_i_start - 0, // core_out_h_i_start - out_nhw_per_core, // nsticks_per_core - 0, // core_offset_out_row_id - out_nhw_per_core / nblocks, // loop count with blocks - // the following are for sharded input - 0, // 43: local_out_stick_start - out_hw, // out_nsticks_per_batch - 0, // local_in_stick_start - 0, // local_in_stick_end - in_hw, // in_nsticks_per_batch - in_nhw_per_core, // in_nsticks_per_core - 0, // has_left - 0, // left_noc_x - 0, // left_noc_y - 0, // has_right - 0, // right_noc_x - 0, // right_noc_y - in_nhw_per_core_rem_mask, - 0, // 56: has_left_left, - 0, // left_left_noc_x, - 0, // left_left_noc_y, - 0, // has_right_right, - 0, // right_right_noc_x, - 0, // right_right_noc_y, - 0, // left_in_stick_start, - 0, // right_in_stick_end, - 0, // my_core - }; + std::vector reader_ct_args = { + input.memory_config().buffer_type == BufferType::DRAM ? (uint)1 : (uint)0, + out_mem_config.buffer_type == BufferType::DRAM ? (uint)1 : (uint)0, + bf16_one_u32, + out_nelems, + static_cast(((in_nbytes_c & (in_nbytes_c - 1)) == 0) ? 1 : 0), // is in_nbytes_c power of 2 + stride_h, + stride_w, + reader_noc, + writer_noc}; + uint32_t in_log_base_2_of_page_size = (uint32_t)std::log2((float)in_nbytes_c); + std::vector reader_rt_args = { + src_dram_buffer->address(), + dst_dram_buffer->address(), + kernel_size_h, + kernel_size_w, + kernel_size_hw, + kernel_size_hw_padded, + stride_h, + stride_w, + pad_h, + pad_w, + out_h, + out_w, + output_shape[2], + output_shape[3], + in_nbytes_c, + out_nbytes_c, + in_h, + in_w, + input_shape[2], + input_shape[3], + out_ntiles_hw, + out_ntiles_c, + in_cb_pagesize, + out_cb_pagesize, + in_cb_page_nelems_padded, + out_w_loop_count, + in_log_base_2_of_page_size, + nbatch, + in_hw, + out_hw, + // these are set later in the following + 0, // start_out_h_i + 0, // end_out_h_i + 0, // base_start_h + 0, // start_out_row_id + minus_inf_const_tensor_addr, + const_buffer_size * in_nbytes, + (in_cb_page_nelems_padded * out_nelems * 2) >> 5, // TODO: generalize num rows to fill in in_cb + 0, // core_offset_in_row_id + 0, // core_out_w_i_start + 0, // core_out_h_i_start + out_nhw_per_core, // nsticks_per_core + 0, // core_offset_out_row_id + out_nhw_per_core / nblocks, // loop count with blocks + // the following are for sharded input + 0, // 43: local_out_stick_start + out_hw, // out_nsticks_per_batch + 0, // local_in_stick_start + 0, // local_in_stick_end + in_hw, // in_nsticks_per_batch + in_nhw_per_core, // in_nsticks_per_core + 0, // has_left + 0, // left_noc_x + 0, // left_noc_y + 0, // has_right + 0, // right_noc_x + 0, // right_noc_y + in_nhw_per_core_rem_mask, + 0, // 56: has_left_left, + 0, // left_left_noc_x, + 0, // left_left_noc_y, + 0, // has_right_right, + 0, // right_right_noc_x, + 0, // right_right_noc_y, + 0, // left_in_stick_start, + 0, // right_in_stick_end, + 0, // my_core + }; auto reader_config = ReaderDataMovementConfig(reader_ct_args); std::string reader_kernel_fname; if (input.memory_config().is_sharded()) { // sharded, without halo - reader_kernel_fname = std::string("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp"); + reader_kernel_fname = + std::string("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp"); } else { - reader_kernel_fname = std::string("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core.cpp"); + reader_kernel_fname = + std::string("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core.cpp"); } - auto reader_kernel = CreateKernel(program, - reader_kernel_fname, - all_cores, - reader_config); + auto reader_kernel = CreateKernel(program, reader_kernel_fname, all_cores, reader_config); /** * Writer Kernel: output cb -> output rows @@ -450,76 +508,73 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp } std::vector writer_ct_args = reader_ct_args; auto writer_config = WriterDataMovementConfig(writer_ct_args, writer_defines); - std::string writer_kernel_fname("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core.cpp"); - auto writer_kernel = CreateKernel(program, - writer_kernel_fname, - all_cores, - writer_config); + std::string writer_kernel_fname( + "tt_eager/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core.cpp"); + auto writer_kernel = CreateKernel(program, writer_kernel_fname, all_cores, writer_config); /** - * Compute Kernel: input cb -> tilize_block -> input tiles -> reduce_h max -> output tiles -> untilize_block -> output cb + * Compute Kernel: input cb -> tilize_block -> input tiles -> reduce_h max -> output tiles -> untilize_block -> + * output cb */ - std::vector compute_ct_args = {in_ntiles_hw, - in_ntiles_c, - in_ntiles_hw * in_ntiles_c, - kernel_size_hw, - out_h, - out_w, - (uint32_t) ceil((float) output_shape[2] / constants::TILE_HEIGHT), - (uint32_t) ceil((float) output_shape[3] / constants::TILE_WIDTH), - out_nelems, - out_w_loop_count, - nbatch, - out_nhw_per_core, - 0, // Split reader - out_nhw_per_core / nblocks, // loop count with blocks - input_shape[3], - }; + std::vector compute_ct_args = { + in_ntiles_hw, + in_ntiles_c, + in_ntiles_hw * in_ntiles_c, + kernel_size_hw, + out_h, + out_w, + (uint32_t)ceil((float)output_shape[2] / constants::TILE_HEIGHT), + (uint32_t)ceil((float)output_shape[3] / constants::TILE_WIDTH), + out_nelems, + out_w_loop_count, + nbatch, + out_nhw_per_core, + 0, // Split reader + out_nhw_per_core / nblocks, // loop count with blocks + input_shape[3], + }; auto compute_ct_args_cliff = compute_ct_args; auto reduce_op = ReduceOpMath::MAX; auto reduce_dim = ReduceOpDim::H; - auto compute_config = ComputeConfig{.math_fidelity = MathFidelity::HiFi4, - .fp32_dest_acc_en = false, - .math_approx_mode = false, - .compile_args = compute_ct_args, - .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; + auto compute_config = ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = false, + .math_approx_mode = false, + .compile_args = compute_ct_args, + .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; std::string compute_kernel_fname("tt_eager/tt_dnn/op_library/pool/kernels/compute/max_pool_multi_core.cpp"); - auto compute_kernel = CreateKernel(program, - compute_kernel_fname, - core_range, - compute_config); + auto compute_kernel = CreateKernel(program, compute_kernel_fname, core_range, compute_config); if (out_nhw_per_core_cliff > 0) { - TT_ASSERT(false, "The cliff core case is not yet handled"); // TODO + TT_ASSERT(false, "The cliff core case is not yet handled"); // TODO // there is a cliff core compute_ct_args_cliff[11] = out_nhw_per_core_cliff; - auto compute_config_cliff = ComputeConfig{.math_fidelity = MathFidelity::HiFi4, - .fp32_dest_acc_en = false, - .math_approx_mode = false, - .compile_args = compute_ct_args_cliff, - .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; - auto compute_kernel_cliff = CreateKernel(program, - compute_kernel_fname, - core_range_cliff, - compute_config); + auto compute_config_cliff = ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = false, + .math_approx_mode = false, + .compile_args = compute_ct_args_cliff, + .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; + auto compute_kernel_cliff = CreateKernel(program, compute_kernel_fname, core_range_cliff, compute_config); } // calculate and set the start/end h_i for each core // for all but last core (cliff) + const auto& cores = grid_to_cores(ncores, ncores_w, grid_size.y, true); uint32_t core_out_h_i = 0; uint32_t core_out_w_i = 0; - int32_t curr_start_h = - pad_h; + int32_t curr_start_h = -pad_h; if (out_nhw_per_core_cliff > 0) { // TODO? ... not yet handled - TT_ASSERT(false, "The cliff core case is not yet handled"); // TODO + TT_ASSERT(false, "The cliff core case is not yet handled"); // TODO } else { uint32_t core_batch_offset = 0; - uint32_t curr_out_stick_id = 0; // track output sticks with batch folded in - int32_t curr_in_stick_id = 0; // track input sticks with batch folded in + uint32_t curr_out_stick_id = 0; // track output sticks with batch folded in + int32_t curr_in_stick_id = 0; // track input sticks with batch folded in uint32_t core_out_w_i_start = 0; uint32_t core_out_h_i_start = 0; - for (int32_t i = 0; i < ncores; ++ i) { - CoreCoord core_coord(i % ncores_w, i / ncores_w); // logical + for (int32_t i = 0; i < ncores; ++i) { + const CoreCoord& core_coord = cores[i]; // logical reader_rt_args[37] = (curr_in_stick_id / in_hw) * in_hw; core_out_w_i_start = curr_out_stick_id % out_w; core_out_h_i_start = (curr_out_stick_id / out_w) % out_h; @@ -532,23 +587,23 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp reader_rt_args[45] = curr_in_stick_id; reader_rt_args[46] = curr_in_stick_id + in_nhw_per_core; - reader_rt_args[64] = i; // my_core + reader_rt_args[64] = i; // my_core if (left_neighbor_core.count(core_coord) > 0) { CoreCoord left_core = left_neighbor_core.at(core_coord); CoreCoord left_noc = device->worker_core_from_logical_core(left_core); reader_rt_args[49] = 1; - reader_rt_args[50] = (uint32_t) left_noc.x; - reader_rt_args[51] = (uint32_t) left_noc.y; + reader_rt_args[50] = (uint32_t)left_noc.x; + reader_rt_args[51] = (uint32_t)left_noc.y; // left-left if (left_neighbor_core.count(left_core) > 0) { CoreCoord left_left_core = left_neighbor_core.at(left_core); CoreCoord left_left_noc = device->worker_core_from_logical_core(left_left_core); reader_rt_args[56] = 1; - reader_rt_args[57] = (uint32_t) left_left_noc.x; - reader_rt_args[58] = (uint32_t) left_left_noc.y; - reader_rt_args[62] = (uint32_t) (curr_in_stick_id - (int32_t) in_nhw_per_core); + reader_rt_args[57] = (uint32_t)left_left_noc.x; + reader_rt_args[58] = (uint32_t)left_left_noc.y; + reader_rt_args[62] = (uint32_t)(curr_in_stick_id - (int32_t)in_nhw_per_core); } else { reader_rt_args[56] = 0; } @@ -559,17 +614,17 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp CoreCoord right_core = right_neighbor_core.at(core_coord); CoreCoord right_noc = device->worker_core_from_logical_core(right_core); reader_rt_args[52] = 1; - reader_rt_args[53] = (uint32_t) right_noc.x; - reader_rt_args[54] = (uint32_t) right_noc.y; + reader_rt_args[53] = (uint32_t)right_noc.x; + reader_rt_args[54] = (uint32_t)right_noc.y; // right-right if (right_neighbor_core.count(right_core) > 0) { CoreCoord right_right_core = right_neighbor_core.at(right_core); CoreCoord right_right_noc = device->worker_core_from_logical_core(right_right_core); reader_rt_args[59] = 1; - reader_rt_args[60] = (uint32_t) right_right_noc.x; - reader_rt_args[61] = (uint32_t) right_right_noc.y; - reader_rt_args[63] = (uint32_t) (curr_in_stick_id + 2 * in_nhw_per_core); + reader_rt_args[60] = (uint32_t)right_right_noc.x; + reader_rt_args[61] = (uint32_t)right_right_noc.y; + reader_rt_args[63] = (uint32_t)(curr_in_stick_id + 2 * in_nhw_per_core); } else { reader_rt_args[59] = 0; } @@ -587,62 +642,69 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic(const Tensor &inp } } - auto override_runtime_arguments_callback = [ - reader_kernel, writer_kernel, raw_in_cb, cb_sharded_out, ncores, ncores_w - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors - ) { - auto src_buffer = input_tensors.at(0).buffer(); - bool input_sharded = input_tensors.at(0).is_sharded(); - - auto dst_buffer = output_tensors.at(0).buffer(); - bool out_sharded = output_tensors.at(0).is_sharded(); - - for (uint32_t i = 0; i < ncores; ++ i) { - CoreCoord core{i % ncores_w, i / ncores_w }; - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); + auto override_runtime_arguments_callback = + [reader_kernel, writer_kernel, raw_in_cb, cb_sharded_out, cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + auto src_buffer = input_tensors.at(0).buffer(); + bool input_sharded = input_tensors.at(0).is_sharded(); + + auto dst_buffer = output_tensors.at(0).buffer(); + bool out_sharded = output_tensors.at(0).is_sharded(); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel); + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } } - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); + if (input_sharded) { + UpdateDynamicCircularBufferAddress(program, raw_in_cb, *src_buffer); } - } - if (input_sharded) { - UpdateDynamicCircularBufferAddress(program, raw_in_cb, *src_buffer); - } - if (out_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_sharded_out, *dst_buffer); - } - }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; + if (out_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_sharded_out, *dst_buffer); + } + }; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } // this version uses distribution along height = N * H * W -operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl(Program& program, - const Tensor &input, const Tensor &reader_indices, - Tensor& output, - uint32_t in_n, uint32_t in_h, uint32_t in_w, - uint32_t out_h, uint32_t out_w, - uint32_t kernel_size_h, uint32_t kernel_size_w, - uint32_t stride_h, uint32_t stride_w, - uint32_t pad_h, uint32_t pad_w, - uint32_t dilation_h, uint32_t dilation_w, - const MemoryConfig& out_mem_config, - uint32_t nblocks) { +operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl( + Program& program, + const Tensor& input, + const Tensor& reader_indices, + Tensor& output, + uint32_t in_n, + uint32_t in_h, + uint32_t in_w, + uint32_t out_h, + uint32_t out_w, + uint32_t kernel_size_h, + uint32_t kernel_size_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t pad_h, + uint32_t pad_w, + uint32_t dilation_h, + uint32_t dilation_w, + const MemoryConfig& out_mem_config, + uint32_t nblocks) { // This should allocate a DRAM buffer on the device - Device *device = input.device(); - Buffer *src_dram_buffer = input.buffer(); - Buffer *reader_indices_buffer = reader_indices.buffer(); - Buffer *dst_dram_buffer = output.buffer(); + Device* device = input.device(); + Buffer* src_dram_buffer = input.buffer(); + Buffer* reader_indices_buffer = reader_indices.buffer(); + Buffer* dst_dram_buffer = output.buffer(); Shape input_shape = input.get_legacy_shape(); Shape output_shape = output.get_legacy_shape(); @@ -651,19 +713,20 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl DataFormat out_df = datatype_to_dataformat_converter(output.get_dtype()); uint32_t in_nbytes = datum_size(in_df); uint32_t out_nbytes = datum_size(out_df); - uint32_t in_nbytes_c = input_shape[3] * in_nbytes; // row of input (channels) - uint32_t out_nbytes_c = output_shape[3] * out_nbytes; // row of output (channels) - TT_ASSERT((in_nbytes_c & (in_nbytes_c - 1)) == 0, "in_nbytes_c should be power of 2"); // in_nbytes_c is power of 2 - TT_ASSERT((out_nbytes_c & (out_nbytes_c - 1)) == 0, "out_nbytes_c should be power of 2"); // out_nbytes_c is power of 2 + uint32_t in_nbytes_c = input_shape[3] * in_nbytes; // row of input (channels) + uint32_t out_nbytes_c = output_shape[3] * out_nbytes; // row of output (channels) + TT_ASSERT((in_nbytes_c & (in_nbytes_c - 1)) == 0, "in_nbytes_c should be power of 2"); // in_nbytes_c is power of 2 + TT_ASSERT( + (out_nbytes_c & (out_nbytes_c - 1)) == 0, "out_nbytes_c should be power of 2"); // out_nbytes_c is power of 2 - DataFormat indices_df = DataFormat::RawUInt16; //datatype_to_dataformat_converter(reader_indices.get_dtype()); + DataFormat indices_df = DataFormat::RawUInt16; // datatype_to_dataformat_converter(reader_indices.get_dtype()); uint32_t indices_nbytes = datum_size(indices_df); - uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read + uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read uint32_t kernel_size_hw_padded = ceil_multiple_of(kernel_size_hw, constants::TILE_HEIGHT); - uint32_t in_ntiles_hw = (uint32_t) ceil((float) kernel_size_hw_padded / constants::TILE_HEIGHT); - uint32_t in_ntiles_c = (uint32_t) ceil((float) input_shape[3] / constants::TILE_WIDTH); - uint32_t out_ntiles_c = (uint32_t) ceil((float) output_shape[3] / constants::TILE_WIDTH); + uint32_t in_ntiles_hw = (uint32_t)ceil((float)kernel_size_hw_padded / constants::TILE_HEIGHT); + uint32_t in_ntiles_c = (uint32_t)ceil((float)input_shape[3] / constants::TILE_WIDTH); + uint32_t out_ntiles_c = (uint32_t)ceil((float)output_shape[3] / constants::TILE_WIDTH); TT_ASSERT(nblocks == 1, "Multiple blocks not yet supported"); @@ -672,7 +735,7 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl TT_FATAL(input_shape[3] == 16); tile_w = constants::FACE_WIDTH; } - uint32_t out_w_loop_count = ceil((float) out_w / nblocks); + uint32_t out_w_loop_count = ceil((float)out_w / nblocks); // distributing out_hw across the grid auto grid_size = device->compute_with_storage_grid_size(); @@ -687,7 +750,11 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl uint32_t ncores_w = grid_size.x; // TODO: support generic nblocks - TT_ASSERT(out_nhw_per_core % nblocks == 0, "number of sticks per core ({}) should be divisible by nblocks ({})", out_nhw_per_core, nblocks); + TT_ASSERT( + out_nhw_per_core % nblocks == 0, + "number of sticks per core ({}) should be divisible by nblocks ({})", + out_nhw_per_core, + nblocks); // CBs uint32_t multi_buffering_factor = 2; @@ -698,10 +765,9 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl uint32_t in_scalar_cb_id = CB::c_in4; uint32_t in_scalar_cb_pagesize = tile_size(in_df); uint32_t in_scalar_cb_npages = 1; - CircularBufferConfig in_scalar_cb_config = CircularBufferConfig( - in_scalar_cb_npages * in_scalar_cb_pagesize, - {{in_scalar_cb_id, in_df}}) - .set_page_size(in_scalar_cb_id, in_scalar_cb_pagesize); + CircularBufferConfig in_scalar_cb_config = + CircularBufferConfig(in_scalar_cb_npages * in_scalar_cb_pagesize, {{in_scalar_cb_id, in_df}}) + .set_page_size(in_scalar_cb_id, in_scalar_cb_pagesize); auto in_scalar_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_scalar_cb_config); log_debug(LogOp, "CB {} :: PS = {}, NP = {}", in_scalar_cb_id, in_scalar_cb_pagesize, in_scalar_cb_npages); @@ -710,41 +776,48 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl auto raw_in_cb_id = CB::c_in2; uint32_t raw_in_cb_npages = input.shard_spec().value().shape[0]; uint32_t raw_in_cb_pagesize = in_nbytes_c; - CircularBufferConfig raw_in_cb_config = CircularBufferConfig( - raw_in_cb_npages * raw_in_cb_pagesize, - {{raw_in_cb_id, in_df}}) - .set_page_size(raw_in_cb_id, raw_in_cb_pagesize) - .set_globally_allocated_address(*input.buffer()); + CircularBufferConfig raw_in_cb_config = + CircularBufferConfig(raw_in_cb_npages * raw_in_cb_pagesize, {{raw_in_cb_id, in_df}}) + .set_page_size(raw_in_cb_id, raw_in_cb_pagesize) + .set_globally_allocated_address(*input.buffer()); auto raw_in_cb = CreateCircularBuffer(program, all_cores, raw_in_cb_config); log_debug(LogOp, "CB {} :: PS = {}, NP = {}", raw_in_cb_id, raw_in_cb_pagesize, raw_in_cb_npages); // reader indices auto in_reader_indices_cb_id = CB::c_in3; - uint32_t in_reader_indices_cb_pagesize = round_up(out_nhw_per_core * indices_nbytes, 4); // pagesize needs to be multiple of 4 + uint32_t in_reader_indices_cb_pagesize = + round_up(out_nhw_per_core * indices_nbytes, 4); // pagesize needs to be multiple of 4 uint32_t in_reader_indices_cb_npages = 1; - log_debug(LogOp, "CB {} :: PS = {}, NP = {}", in_reader_indices_cb_id, in_reader_indices_cb_pagesize, in_reader_indices_cb_npages); - CircularBufferConfig in_reader_indices_cb_config = CircularBufferConfig( - in_reader_indices_cb_npages * in_reader_indices_cb_pagesize, - {{in_reader_indices_cb_id, indices_df}}) - .set_page_size(in_reader_indices_cb_id, in_reader_indices_cb_pagesize) - .set_globally_allocated_address(*reader_indices_buffer); + log_debug( + LogOp, + "CB {} :: PS = {}, NP = {}", + in_reader_indices_cb_id, + in_reader_indices_cb_pagesize, + in_reader_indices_cb_npages); + CircularBufferConfig in_reader_indices_cb_config = + CircularBufferConfig( + in_reader_indices_cb_npages * in_reader_indices_cb_pagesize, {{in_reader_indices_cb_id, indices_df}}) + .set_page_size(in_reader_indices_cb_id, in_reader_indices_cb_pagesize) + .set_globally_allocated_address(*reader_indices_buffer); auto in_reader_indices_cb = CreateCircularBuffer(program, all_cores, in_reader_indices_cb_config); // reader output == input to tilize - uint32_t in_cb_id_0 = CB::c_in0; // input rows for "multiple (out_nelems)" output pixels - uint32_t in_cb_id_1 = CB::c_in1; // input rows for "multiple (out_nelems)" output pixels - uint32_t in_cb_page_padded = ceil_multiple_of(input_shape[3] * kernel_size_hw_padded, constants::TILE_HW); // NOTE: ceil to tile size since triscs work with tilesize instead of pagesize + uint32_t in_cb_id_0 = CB::c_in0; // input rows for "multiple (out_nelems)" output pixels + uint32_t in_cb_id_1 = CB::c_in1; // input rows for "multiple (out_nelems)" output pixels + uint32_t in_cb_page_padded = ceil_multiple_of( + input_shape[3] * kernel_size_hw_padded, + constants::TILE_HW); // NOTE: ceil to tile size since triscs work with tilesize instead of pagesize uint32_t in_cb_pagesize = in_nbytes * in_cb_page_padded; uint32_t in_cb_npages = multi_buffering_factor * nblocks; CircularBufferConfig in_cb_config_0 = CircularBufferConfig(in_cb_npages * in_cb_pagesize, {{in_cb_id_0, in_df}}) - .set_page_size(in_cb_id_0, in_cb_pagesize); + .set_page_size(in_cb_id_0, in_cb_pagesize); auto in_cb_0 = tt_metal::CreateCircularBuffer(program, all_cores, in_cb_config_0); log_debug(LogOp, "CB {} :: PS = {}, NP = {}", in_cb_id_0, in_cb_pagesize, in_cb_npages); if (split_reader) { CircularBufferConfig in_cb_config_1 = CircularBufferConfig(in_cb_npages * in_cb_pagesize, {{in_cb_id_1, in_df}}) - .set_page_size(in_cb_id_1, in_cb_pagesize); + .set_page_size(in_cb_id_1, in_cb_pagesize); auto in_cb_1 = tt_metal::CreateCircularBuffer(program, all_cores, in_cb_config_1); log_debug(LogOp, "CB {} :: PS = {}, NP = {}", in_cb_id_1, in_cb_pagesize, in_cb_npages); } @@ -753,19 +826,24 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl uint32_t in_tiled_cb_id = CB::c_intermed0; // tiled input uint32_t in_tiled_cb_pagesize = tile_size(in_df); uint32_t in_tiled_cb_npages = in_ntiles_c * in_ntiles_hw * nblocks; - CircularBufferConfig in_tiled_cb_config = CircularBufferConfig(in_tiled_cb_npages * in_tiled_cb_pagesize, {{in_tiled_cb_id, in_df}}) - .set_page_size(in_tiled_cb_id, in_tiled_cb_pagesize); + CircularBufferConfig in_tiled_cb_config = + CircularBufferConfig(in_tiled_cb_npages * in_tiled_cb_pagesize, {{in_tiled_cb_id, in_df}}) + .set_page_size(in_tiled_cb_id, in_tiled_cb_pagesize); auto in_tiled_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_tiled_cb_config); log_debug(LogOp, "CB {} :: PS = {}, NP = {}", in_tiled_cb_id, in_tiled_cb_pagesize, in_tiled_cb_npages); // output of reduce == writer to write - uint32_t out_cb_id = CB::c_out0; // output rows in RM - //uint32_t out_cb_pagesize = tile_size(out_df); - //uint32_t out_cb_npages = out_ntiles_c * nblocks * multi_buffering_factor; // there is just one row of channels after reduction - uint32_t out_cb_pagesize = output.shard_spec().value().shape[1] * out_nbytes; // there is just one row of channels after reduction + uint32_t out_cb_id = CB::c_out0; // output rows in RM + // uint32_t out_cb_pagesize = tile_size(out_df); + // uint32_t out_cb_npages = out_ntiles_c * nblocks * multi_buffering_factor; // there is just one row of channels + // after reduction + uint32_t out_cb_pagesize = + output.shard_spec().value().shape[1] * out_nbytes; // there is just one row of channels after reduction uint32_t out_cb_npages = output.shard_spec().value().shape[0]; CircularBufferConfig cb_out_config = CircularBufferConfig(out_cb_npages * out_cb_pagesize, {{out_cb_id, out_df}}) - .set_page_size(out_cb_id, out_cb_pagesize).set_globally_allocated_address(*output.buffer());; + .set_page_size(out_cb_id, out_cb_pagesize) + .set_globally_allocated_address(*output.buffer()); + ; auto cb_out = tt_metal::CreateCircularBuffer(program, all_cores, cb_out_config); log_debug(LogOp, "CB {} :: PS = {}, NP = {}", out_cb_id, out_cb_pagesize, out_cb_npages); @@ -774,24 +852,29 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl auto shard_shape = output.shard_spec().value().shape; uint32_t sharded_out_num_pages = output.shard_spec().value().shape[0]; uint32_t sharded_out_cb_id = CB::c_out1; // output rows in RM - uint32_t sharded_out_cb_page_size = output.shard_spec().value().shape[1] * out_nbytes; // there is just one row of channels after reduction - CircularBufferConfig cb_sharded_out_config = CircularBufferConfig(sharded_out_num_pages * sharded_out_cb_page_size, {{sharded_out_cb_id, out_df}}) - .set_page_size(sharded_out_cb_id, sharded_out_cb_page_size).set_globally_allocated_address(*output.buffer()); - auto cb_sharded_out = tt_metal::CreateCircularBuffer(program, all_cores, cb_sharded_out_config); - log_debug(LogOp, "CB {} :: PS = {}, NP = {}", sharded_out_cb_id, sharded_out_cb_page_size, sharded_out_num_pages); + uint32_t sharded_out_cb_page_size = output.shard_spec().value().shape[1] * out_nbytes; // there is just one row + of channels after reduction CircularBufferConfig cb_sharded_out_config = CircularBufferConfig(sharded_out_num_pages + * sharded_out_cb_page_size, {{sharded_out_cb_id, out_df}}) .set_page_size(sharded_out_cb_id, + sharded_out_cb_page_size).set_globally_allocated_address(*output.buffer()); auto cb_sharded_out = + tt_metal::CreateCircularBuffer(program, all_cores, cb_sharded_out_config); log_debug(LogOp, "CB {} :: PS = {}, NP = + {}", sharded_out_cb_id, sharded_out_cb_page_size, sharded_out_num_pages); */ - #if 1 - { // debug - //log_debug(LogOp, "OUTPUT SHARD: {} {}", shard_shape[0], shard_shape[1]); - //log_debug(LogOp, "OUTPUT CB: {} {}", sharded_out_cb_page_size, sharded_out_num_pages); +#if 1 + { // debug + // log_debug(LogOp, "OUTPUT SHARD: {} {}", shard_shape[0], shard_shape[1]); + // log_debug(LogOp, "OUTPUT CB: {} {}", sharded_out_cb_page_size, sharded_out_num_pages); log_debug(LogOp, "raw_in_cb :: PS = {}, NP = {}", raw_in_cb_pagesize, raw_in_cb_npages); log_debug(LogOp, "in_cb :: PS = {}, NP = {}", in_cb_pagesize, in_cb_npages); - log_debug(LogOp, "in_reader_indices_cb :: PS = {}, NP = {}", in_reader_indices_cb_pagesize, in_reader_indices_cb_npages); + log_debug( + LogOp, + "in_reader_indices_cb :: PS = {}, NP = {}", + in_reader_indices_cb_pagesize, + in_reader_indices_cb_npages); log_debug(LogOp, "in_scalar_cb :: PS = {}, NP = {}", in_scalar_cb_pagesize, in_scalar_cb_npages); log_debug(LogOp, "in_tiled_cb :: PS = {}, NP = {}", in_tiled_cb_pagesize, in_tiled_cb_npages); log_debug(LogOp, "out_cb :: PS = {}, NP = {}", out_cb_pagesize, out_cb_npages); - //log_debug(LogOp, "sharded_out_cb :: PS = {}, NP = {}", sharded_out_cb_page_size, sharded_out_num_pages); + // log_debug(LogOp, "sharded_out_cb :: PS = {}, NP = {}", sharded_out_cb_page_size, sharded_out_num_pages); log_debug(LogOp, "in_addr: {}", src_dram_buffer->address()); log_debug(LogOp, "in_reader_indices_addr: {}", reader_indices_buffer->address()); log_debug(LogOp, "out_addr: {}", dst_dram_buffer->address()); @@ -820,65 +903,55 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl log_debug(LogOp, "is_in_sharded: {}", input.memory_config().is_sharded()); log_debug(LogOp, "is_out_sharded: {}", output.memory_config().is_sharded()); } - #endif +#endif /** * Reader Kernel: input rows -> input cb */ float one = 1.; uint32_t bf16_one_u32 = *reinterpret_cast(&one); - uint32_t in_nbytes_c_log2 = (uint32_t) std::log2((float) in_nbytes_c); + uint32_t in_nbytes_c_log2 = (uint32_t)std::log2((float)in_nbytes_c); std::vector reader0_ct_args = { - out_nhw_per_core, - kernel_size_h, - kernel_size_w, - pad_w, - in_nbytes_c, - in_nbytes_c_log2, - in_w, - in_cb_page_padded * in_cb_npages / tile_w, - input_shape[3], - nblocks, - split_reader, // enable split reader - 0, // split reader id - bf16_one_u32 - }; + out_nhw_per_core, + kernel_size_h, + kernel_size_w, + pad_w, + in_nbytes_c, + in_nbytes_c_log2, + in_w, + in_cb_page_padded * in_cb_npages / tile_w, + input_shape[3], + nblocks, + split_reader, // enable split reader + 0, // split reader id + bf16_one_u32}; std::vector reader1_ct_args = { - out_nhw_per_core, - kernel_size_h, - kernel_size_w, - pad_w, - in_nbytes_c, - in_nbytes_c_log2, - in_w, - in_cb_page_padded * in_cb_npages / tile_w, - input_shape[3], - nblocks, - split_reader, // enable split reader - 1, // split reader id - bf16_one_u32 - }; - + out_nhw_per_core, + kernel_size_h, + kernel_size_w, + pad_w, + in_nbytes_c, + in_nbytes_c_log2, + in_w, + in_cb_page_padded * in_cb_npages / tile_w, + input_shape[3], + nblocks, + split_reader, // enable split reader + 1, // split reader id + bf16_one_u32}; - std::string reader_kernel_fname("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp"); + std::string reader_kernel_fname( + "tt_eager/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp"); - auto reader0_config = DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, - .noc = NOC::RISCV_0_default, - .compile_args = reader0_ct_args}; + auto reader0_config = DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default, .compile_args = reader0_ct_args}; - auto reader0_kernel = CreateKernel(program, - reader_kernel_fname, - all_cores, - reader0_config); + auto reader0_kernel = CreateKernel(program, reader_kernel_fname, all_cores, reader0_config); - auto reader1_config = DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, - .noc = NOC::RISCV_1_default, - .compile_args = reader1_ct_args}; - auto reader1_kernel = split_reader ? CreateKernel(program, - reader_kernel_fname, - all_cores, - reader1_config) : 0; + auto reader1_config = DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default, .compile_args = reader1_ct_args}; + auto reader1_kernel = split_reader ? CreateKernel(program, reader_kernel_fname, all_cores, reader1_config) : 0; /** * Writer Kernel: output cb -> output rows */ @@ -899,15 +972,14 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl .noc = NOC::RISCV_1_default, .compile_args = writer_ct_args, .defines = writer_defines}; - std::string writer_kernel_fname("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp"); - auto writer_kernel = CreateKernel(program, - writer_kernel_fname, - all_cores, - writer_config); + std::string + writer_kernel_fname("tt_eager/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp"); auto + writer_kernel = CreateKernel(program, writer_kernel_fname, all_cores, writer_config); */ /** - * Compute Kernel: input cb -> tilize_block -> input tiles -> reduce_h max -> output tiles -> untilize_block -> output cb + * Compute Kernel: input cb -> tilize_block -> input tiles -> reduce_h max -> output tiles -> untilize_block -> + * output cb */ std::vector compute_ct_args = { in_ntiles_hw, @@ -929,16 +1001,14 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl auto compute_ct_args_cliff = compute_ct_args; auto reduce_op = ReduceOpMath::MAX; auto reduce_dim = ReduceOpDim::H; - auto compute_config = ComputeConfig{.math_fidelity = MathFidelity::HiFi4, - .fp32_dest_acc_en = false, - .math_approx_mode = false, - .compile_args = compute_ct_args, - .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; + auto compute_config = ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = false, + .math_approx_mode = false, + .compile_args = compute_ct_args, + .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; std::string compute_kernel_fname("tt_eager/tt_dnn/op_library/pool/kernels/compute/max_pool_multi_core.cpp"); - auto compute_kernel = CreateKernel(program, - compute_kernel_fname, - core_range, - compute_config); + auto compute_kernel = CreateKernel(program, compute_kernel_fname, core_range, compute_config); /* uint32_t curr_out_stick_id = 0; // track output sticks with batch folded in @@ -950,62 +1020,95 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl } */ - auto override_runtime_arguments_callback = [ - //reader_kernel, writer_kernel, raw_in_cb, in_reader_indices_cb, cb_sharded_out, ncores, ncores_w - reader0_kernel, reader1_kernel, raw_in_cb, in_reader_indices_cb, cb_out, ncores, ncores_w - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors - ) { - auto src_buffer = input_tensors.at(0).buffer(); - bool input_sharded = input_tensors.at(0).is_sharded(); - auto reader_indices_buffer = input_tensors.at(1).buffer(); - - auto dst_buffer = output_tensors.at(0).buffer(); - bool out_sharded = output_tensors.at(0).is_sharded(); - - /* - for (uint32_t i = 0; i < ncores; ++ i) { - CoreCoord core{i % ncores_w, i / ncores_w }; - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel, core); - runtime_args[0] = dst_buffer->address(); + auto override_runtime_arguments_callback = + [ + // reader_kernel, writer_kernel, raw_in_cb, in_reader_indices_cb, cb_sharded_out, ncores, ncores_w + reader0_kernel, + reader1_kernel, + raw_in_cb, + in_reader_indices_cb, + cb_out, + ncores, + ncores_w]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + auto src_buffer = input_tensors.at(0).buffer(); + bool input_sharded = input_tensors.at(0).is_sharded(); + auto reader_indices_buffer = input_tensors.at(1).buffer(); + + auto dst_buffer = output_tensors.at(0).buffer(); + bool out_sharded = output_tensors.at(0).is_sharded(); + + /* + for (uint32_t i = 0; i < ncores; ++ i) { + CoreCoord core{i % ncores_w, i / ncores_w }; + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel, core); + runtime_args[0] = dst_buffer->address(); + } } - } - */ - if (input_sharded) { - UpdateDynamicCircularBufferAddress(program, raw_in_cb, *src_buffer); - UpdateDynamicCircularBufferAddress(program, in_reader_indices_cb, *reader_indices_buffer); - } - if (out_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_out, *dst_buffer); - } - }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; + */ + if (input_sharded) { + UpdateDynamicCircularBufferAddress(program, raw_in_cb, *src_buffer); + UpdateDynamicCircularBufferAddress(program, in_reader_indices_cb, *reader_indices_buffer); + } + if (out_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_out, *dst_buffer); + } + }; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2(const Tensor &input, const Tensor &reader_indices, - Tensor& output, - uint32_t in_n, uint32_t in_h, uint32_t in_w, - uint32_t out_h, uint32_t out_w, - uint32_t kernel_size_h, uint32_t kernel_size_w, - uint32_t stride_h, uint32_t stride_w, - uint32_t pad_h, uint32_t pad_w, - uint32_t dilation_h, uint32_t dilation_w, - const MemoryConfig& out_mem_config, - uint32_t nblocks) { +operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2( + const Tensor& input, + const Tensor& reader_indices, + Tensor& output, + uint32_t in_n, + uint32_t in_h, + uint32_t in_w, + uint32_t out_h, + uint32_t out_w, + uint32_t kernel_size_h, + uint32_t kernel_size_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t pad_h, + uint32_t pad_w, + uint32_t dilation_h, + uint32_t dilation_w, + const MemoryConfig& out_mem_config, + uint32_t nblocks) { Program program = CreateProgram(); - return max_pool_2d_multi_core_sharded_with_halo_v2_impl(program, input, reader_indices, output, in_n, in_h, in_w, out_h, out_w, kernel_size_h, kernel_size_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, out_mem_config, nblocks); + return max_pool_2d_multi_core_sharded_with_halo_v2_impl( + program, + input, + reader_indices, + output, + in_n, + in_h, + in_w, + out_h, + out_w, + kernel_size_h, + kernel_size_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + out_mem_config, + nblocks); } -operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_new(const Tensor &input, - Tensor& output, - const SlidingWindowConfig& sliding_window_config, - const MemoryConfig& out_mem_config) { +operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_new( + const Tensor& input, + Tensor& output, + const SlidingWindowConfig& sliding_window_config, + const MemoryConfig& out_mem_config) { Program program = CreateProgram(); ParallelConfig parallel_config = ParallelConfig{ @@ -1023,9 +1126,12 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_new( auto pad_metadata = sliding_window::generate_pad_metadata(sliding_window_config); auto op_trace_metadata = sliding_window::generate_op_trace_metadata(sliding_window_config); auto shard_boundaries = sliding_window::generate_shard_boundaries(sliding_window_config, op_trace_metadata); - auto top_left_indices = sliding_window::generate_sliding_window_op_config(op_trace_metadata, shard_boundaries, false, false); - auto reader_indices = sliding_window::construct_on_host_config_tensor(top_left_indices, sliding_window_config, parallel_config); - auto reader_indices_on_device = sliding_window::move_config_tensor_to_device(reader_indices, parallel_config, is_block_sharded, input.device()); + auto top_left_indices = + sliding_window::generate_sliding_window_op_config(op_trace_metadata, shard_boundaries, false, false); + auto reader_indices = + sliding_window::construct_on_host_config_tensor(top_left_indices, sliding_window_config, parallel_config); + auto reader_indices_on_device = + sliding_window::move_config_tensor_to_device(reader_indices, parallel_config, is_block_sharded, input.device()); detail::AddConfigBuffer(program, reader_indices_on_device.device_buffer()); @@ -1041,8 +1147,27 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_new( auto dilation_h = sliding_window_config.dilation_hw_.first; auto dilation_w = sliding_window_config.dilation_hw_.second; - return max_pool_2d_multi_core_sharded_with_halo_v2_impl(program, input, reader_indices_on_device, output, in_n, in_h, in_w, out_h, out_w, kernel_size_h, kernel_size_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, out_mem_config, 1); + return max_pool_2d_multi_core_sharded_with_halo_v2_impl( + program, + input, + reader_indices_on_device, + output, + in_n, + in_h, + in_w, + out_h, + out_w, + kernel_size_h, + kernel_size_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + out_mem_config, + 1); } -} // namespace tt_metal -} // namespace tt +} // namespace tt_metal +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/reduce/multi_core_h/reduce_op_multi_core_h.cpp b/tt_eager/tt_dnn/op_library/reduce/multi_core_h/reduce_op_multi_core_h.cpp index 73f7c5deaea..4f295f37600 100644 --- a/tt_eager/tt_dnn/op_library/reduce/multi_core_h/reduce_op_multi_core_h.cpp +++ b/tt_eager/tt_dnn/op_library/reduce/multi_core_h/reduce_op_multi_core_h.cpp @@ -3,12 +3,12 @@ // SPDX-License-Identifier: Apache-2.0 #include + #include "tt_dnn/op_library/reduce/reduce_op.hpp" #include "tt_dnn/op_library/work_split.hpp" - -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; using uint32_t = std::uint32_t; @@ -17,13 +17,13 @@ namespace tt { namespace tt_metal { -operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& output, ReduceOpMath reduce_op, float scaler) { - +operation::ProgramWithCallbacks reduce_multi_core_h( + const Tensor &a, Tensor &output, ReduceOpMath reduce_op, float scaler) { const auto shape = a.get_legacy_shape(); - uint32_t W = shape[3], H = shape[2], NC = shape[1]*shape[0]; + uint32_t W = shape[3], H = shape[2], NC = shape[1] * shape[0]; - uint32_t Wt = W/TILE_WIDTH; - uint32_t Ht = H/TILE_HEIGHT; + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; uint32_t HtWt = Ht * Wt; tt_metal::Program program = tt_metal::CreateProgram(); @@ -35,7 +35,7 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); - uint32_t num_tiles = a.volume()/TILE_HW; + uint32_t num_tiles = a.volume() / TILE_HW; tt_metal::Device *device = a.device(); @@ -45,7 +45,8 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out uint32_t num_cores_x = compute_with_storage_grid_size.x; uint32_t num_cores_y = compute_with_storage_grid_size.y; auto num_cols = NC * Wt; - auto [num_cores, all_cores, core_group_1, core_group_2, num_cols_per_core_group_1, num_cols_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, num_cols); + auto [num_cores, all_cores, core_group_1, core_group_2, num_cols_per_core_group_1, num_cols_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_cols); // Current sharding only supports width, and that input and output are sharded if (in_sharded) { @@ -64,36 +65,50 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out if (in_sharded) { uint32_t num_shard_tiles = a.shard_spec().value().numel() / TILE_HW; uint32_t num_input_tiles = 2; - tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) - .set_page_size(src0_cb_index, src0_single_tile_size); + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig( + num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); - tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(num_shard_tiles * src0_single_tile_size, {{src1_cb_index, src0_cb_data_format}}) - .set_page_size(src1_cb_index, src0_single_tile_size).set_globally_allocated_address(*a.buffer()); + tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig( + num_shard_tiles * src0_single_tile_size, {{src1_cb_index, src0_cb_data_format}}) + .set_page_size(src1_cb_index, src0_single_tile_size) + .set_globally_allocated_address(*a.buffer()); cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); } else { uint32_t num_input_tiles = 2; - tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) - .set_page_size(src0_cb_index, src0_single_tile_size); + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig( + num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); } uint32_t scaler_cb_index = CB::c_in2; - tt_metal::CircularBufferConfig cb_scaler_config = tt_metal::CircularBufferConfig(1 * scaler_single_tile_size, {{scaler_cb_index, scaler_cb_data_format}}) - .set_page_size(scaler_cb_index, scaler_single_tile_size); + tt_metal::CircularBufferConfig cb_scaler_config = + tt_metal::CircularBufferConfig(1 * scaler_single_tile_size, {{scaler_cb_index, scaler_cb_data_format}}) + .set_page_size(scaler_cb_index, scaler_single_tile_size); auto cb_scaler = tt_metal::CreateCircularBuffer(program, all_cores, cb_scaler_config); - uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 + uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 CBHandle cb_output; if (out_sharded) { uint32_t num_output_tiles = output.shard_spec().value().numel() / TILE_HW; - tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) - .set_page_size(output_cb_index, dst_single_tile_size).set_globally_allocated_address(*output.buffer());; + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig( + num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size) + .set_globally_allocated_address(*output.buffer()); + ; cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); } else { uint32_t num_output_tiles = 2; - tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) - .set_page_size(output_cb_index, dst_single_tile_size); + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig( + num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); } tt_metal::Buffer *src0_buffer = a.buffer(); @@ -101,33 +116,26 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out bfloat16 bfloat_scaler_value = bfloat16(scaler); uint32_t packed_scaler_value = pack_two_bfloat16_into_uint32({bfloat_scaler_value, bfloat_scaler_value}); if (in_sharded) { - std::vector reader_compile_time_args = { - src0_cb_index, - src1_cb_index, - scaler_cb_index - }; + std::vector reader_compile_time_args = {src0_cb_index, src1_cb_index, scaler_cb_index}; std::map reader_defines; reader_defines["REDUCE_SCALER"] = "1"; reader_kernel_id = tt_metal::CreateKernel( program, - "tt_eager/tt_dnn/op_library/reduce/kernels/dataflow/reader_unary_transpose_wh_interleaved_input_cols_partitioned_sharded.cpp", + "tt_eager/tt_dnn/op_library/reduce/kernels/dataflow/" + "reader_unary_transpose_wh_interleaved_input_cols_partitioned_sharded.cpp", all_cores, tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); } else { bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = { - (std::uint32_t) src0_is_dram, - Ht, - Wt, - HtWt, - packed_scaler_value - }; + (std::uint32_t)src0_is_dram, Ht, Wt, HtWt, packed_scaler_value}; std::map reader_defines; reader_defines["REDUCE_SCALER"] = "1"; reader_kernel_id = tt_metal::CreateKernel( program, - "tt_eager/tt_dnn/op_library/reduce/kernels/dataflow/reader_unary_transpose_wh_interleaved_input_cols_partitioned.cpp", + "tt_eager/tt_dnn/op_library/reduce/kernels/dataflow/" + "reader_unary_transpose_wh_interleaved_input_cols_partitioned.cpp", all_cores, tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); } @@ -146,10 +154,7 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out WriterDataMovementConfig(writer_ct_args)); } else { bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_compile_time_args = { - (std::uint32_t) output_cb_index, - (std::uint32_t) dst_is_dram - }; + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram}; writer_kernel_id = tt_metal::CreateKernel( program, @@ -159,65 +164,46 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out } std::map reduce_defines = reduce_op_utils::get_defines(reduce_op, ReduceOpDim::H); vector compute_kernel_args_group_1 = { - Ht, // Ht - num_cols_per_core_group_1, // Wt - 1, // NC + Ht, // Ht + num_cols_per_core_group_1, // Wt + 1, // NC }; auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/reduce/kernels/compute/reduce_h.cpp", core_group_1, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_1, .defines = reduce_defines} - ); + tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_1, .defines = reduce_defines}); - if(!core_group_2.ranges().empty()){ + if (!core_group_2.ranges().empty()) { vector compute_kernel_args_group_2 = { - Ht, // Ht - num_cols_per_core_group_2, // Wt - 1, // NC + Ht, // Ht + num_cols_per_core_group_2, // Wt + 1, // NC }; auto reduce_compute_kernel_group_2_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/reduce/kernels/compute/reduce_h.cpp", core_group_2, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_2, .defines = reduce_defines} - ); + tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_2, .defines = reduce_defines}); } + const auto &cores = + grid_to_cores(num_cores, compute_with_storage_grid_size.x, compute_with_storage_grid_size.y, false); if (in_sharded && out_sharded) { uint32_t shard_Wt = num_cols_per_core_group_1 / NC; uint32_t shard_row_size = shard_Wt * src0_single_tile_size; uint32_t shard_batch_size = shard_row_size * Ht; vector reader_rt_args = { - num_cols_per_core_group_1 * Ht, - shard_Wt, - Ht, - NC, - shard_row_size, - shard_batch_size, - packed_scaler_value - }; - tt_metal::SetRuntimeArgs( - program, - reader_kernel_id, - all_cores, - reader_rt_args - ); + num_cols_per_core_group_1 * Ht, shard_Wt, Ht, NC, shard_row_size, shard_batch_size, packed_scaler_value}; + tt_metal::SetRuntimeArgs(program, reader_kernel_id, all_cores, reader_rt_args); - vector writer_rt_args = { - num_cols_per_core_group_1 - }; - tt_metal::SetRuntimeArgs( - program, - writer_kernel_id, - all_cores, - writer_rt_args - ); + vector writer_rt_args = {num_cols_per_core_group_1}; + tt_metal::SetRuntimeArgs(program, writer_kernel_id, all_cores, writer_rt_args); } else { - for (uint32_t i = 0, num_cols_read = 0; i < num_cores; i++){ - CoreCoord core = {i / num_cores_y, i % num_cores_y}; + for (uint32_t i = 0, num_cols_read = 0; i < num_cores; i++) { + const CoreCoord &core = cores[i]; uint32_t num_cols_per_core = 0; if (core_group_1.core_coord_in_core_ranges(core)) { num_cols_per_core = num_cols_per_core_group_1; @@ -227,43 +213,37 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out TT_ASSERT(false, "Core not in specified core ranges"); } tt_metal::SetRuntimeArgs( - program, reader_kernel_id, core, - { - a.buffer()->address(), - num_cols_read / Wt * HtWt + num_cols_read % Wt, - num_cols_read % Wt, - num_cols_per_core - } - ); + program, + reader_kernel_id, + core, + {a.buffer()->address(), + num_cols_read / Wt * HtWt + num_cols_read % Wt, + num_cols_read % Wt, + num_cols_per_core}); tt_metal::SetRuntimeArgs( - program, writer_kernel_id, core, + program, + writer_kernel_id, + core, { output.buffer()->address(), - num_cols_per_core, // number of tiles to write - num_cols_read // output tile start index - } - ); + num_cols_per_core, // number of tiles to write + num_cols_read // output tile start index + }); num_cols_read += num_cols_per_core; } } - auto override_runtime_arguments_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - cb_src1=cb_src1, - cb_output=cb_output, - num_cores=num_cores, - num_cores_y=num_cores_y - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors - ) { - + auto override_runtime_arguments_callback = [reader_kernel_id = reader_kernel_id, + writer_kernel_id = writer_kernel_id, + cb_src1 = cb_src1, + cb_output = cb_output, + cores = cores]( + const void *operation, + Program &program, + const std::vector &input_tensors, + const std::vector> &, + const std::vector &output_tensors) { auto src_buffer = input_tensors.at(0).buffer(); auto dst_buffer = output_tensors.at(0).buffer(); @@ -274,23 +254,23 @@ operation::ProgramWithCallbacks reduce_multi_core_h(const Tensor &a, Tensor& out UpdateDynamicCircularBufferAddress(program, cb_src1, *src_buffer); UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); } else { - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores; i++){ - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - + auto &reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto &writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + for (const auto &core : cores) { { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = reader_runtime_args_by_core[core.x][core.y]; runtime_args[0] = src_buffer->address(); } { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = writer_runtime_args_by_core[core.x][core.y]; runtime_args[0] = dst_buffer->address(); } } } }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/reduce/multi_core_w/reduce_op_multi_core_w.cpp b/tt_eager/tt_dnn/op_library/reduce/multi_core_w/reduce_op_multi_core_w.cpp index a1e20af72bd..5cd86b74768 100644 --- a/tt_eager/tt_dnn/op_library/reduce/multi_core_w/reduce_op_multi_core_w.cpp +++ b/tt_eager/tt_dnn/op_library/reduce/multi_core_w/reduce_op_multi_core_w.cpp @@ -3,12 +3,12 @@ // SPDX-License-Identifier: Apache-2.0 #include + #include "tt_dnn/op_library/reduce/reduce_op.hpp" #include "tt_dnn/op_library/work_split.hpp" - -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; using uint32_t = std::uint32_t; @@ -17,14 +17,14 @@ namespace tt { namespace tt_metal { -operation::ProgramWithCallbacks reduce_multi_core_w(const Tensor &a, Tensor& output, ReduceOpMath reduce_op, float scaler) { - +operation::ProgramWithCallbacks reduce_multi_core_w( + const Tensor &a, Tensor &output, ReduceOpMath reduce_op, float scaler) { const auto shape = a.get_legacy_shape(); - uint32_t W = shape[3], H = shape[2], NC = shape[1]*shape[0]; - uint32_t HW = H*W; + uint32_t W = shape[3], H = shape[2], NC = shape[1] * shape[0]; + uint32_t HW = H * W; - uint32_t Wt = W/TILE_WIDTH; - uint32_t Ht = H/TILE_HEIGHT; + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; tt_metal::Program program = tt_metal::CreateProgram(); @@ -36,7 +36,7 @@ operation::ProgramWithCallbacks reduce_multi_core_w(const Tensor &a, Tensor& out tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); - uint32_t num_tiles = a.volume()/TILE_HW; + uint32_t num_tiles = a.volume() / TILE_HW; tt_metal::Device *device = a.device(); @@ -44,22 +44,26 @@ operation::ProgramWithCallbacks reduce_multi_core_w(const Tensor &a, Tensor& out uint32_t num_cores_x = compute_with_storage_grid_size.x; uint32_t num_cores_y = compute_with_storage_grid_size.y; auto num_rows = NC * Ht; - auto [num_cores, all_cores, core_group_1, core_group_2, num_rows_per_core_group_1, num_rows_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, num_rows); + auto [num_cores, all_cores, core_group_1, core_group_2, num_rows_per_core_group_1, num_rows_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_rows); uint32_t src0_cb_index = 0; uint32_t num_input_tiles = 2; - tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) - .set_page_size(src0_cb_index, src0_single_tile_size); + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); - tt_metal::CircularBufferConfig cb_scaler_config = tt_metal::CircularBufferConfig(num_input_tiles * scaler_single_tile_size, {{CB::c_in2, scaler_cb_data_format}}) - .set_page_size(CB::c_in2, scaler_single_tile_size); + tt_metal::CircularBufferConfig cb_scaler_config = + tt_metal::CircularBufferConfig(num_input_tiles * scaler_single_tile_size, {{CB::c_in2, scaler_cb_data_format}}) + .set_page_size(CB::c_in2, scaler_single_tile_size); auto cb_scaler = tt_metal::CreateCircularBuffer(program, all_cores, cb_scaler_config); - uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t output_cb_index = 16; // output operands start at index 16 uint32_t num_output_tiles = 2; - tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) - .set_page_size(output_cb_index, dst_single_tile_size); + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); bfloat16 bfloat_scaler_value = bfloat16(scaler); @@ -69,10 +73,7 @@ operation::ProgramWithCallbacks reduce_multi_core_w(const Tensor &a, Tensor& out std::vector reader_compile_time_args = {(uint32_t)src_is_dram, packed_scaler_value}; tt_metal::Buffer *dst_buffer = output.buffer(); bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_compile_time_args = { - (std::uint32_t) output_cb_index, - (std::uint32_t) dst_is_dram - }; + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram}; tt_metal::KernelHandle reader_kernel_id = tt_metal::CreateKernel( program, @@ -88,36 +89,36 @@ operation::ProgramWithCallbacks reduce_multi_core_w(const Tensor &a, Tensor& out std::map reduce_defines = reduce_op_utils::get_defines(reduce_op, ReduceOpDim::W); vector compute_kernel_args_group_1 = { - num_rows_per_core_group_1, // Ht - Wt, // Wt - 1, // NC + num_rows_per_core_group_1, // Ht + Wt, // Wt + 1, // NC }; auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/reduce/kernels/compute/reduce_w.cpp", core_group_1, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_1, .defines = reduce_defines} - ); + tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_1, .defines = reduce_defines}); - if(!core_group_2.ranges().empty()){ + if (!core_group_2.ranges().empty()) { vector compute_kernel_args_group_2 = { - num_rows_per_core_group_2, // Ht - Wt, // Wt - 1, // NC + num_rows_per_core_group_2, // Ht + Wt, // Wt + 1, // NC }; auto reduce_compute_kernel_group_2_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/reduce/kernels/compute/reduce_w.cpp", core_group_2, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_2, .defines = reduce_defines} - ); + tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_2, .defines = reduce_defines}); } uint32_t out_dim_divider = Wt; - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores; i++){ - CoreCoord core = {i / num_cores_y, i % num_cores_y}; + const auto &cores = + grid_to_cores(num_cores, compute_with_storage_grid_size.x, compute_with_storage_grid_size.y, false); + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores; i++) { + const CoreCoord &core = cores[i]; uint32_t num_rows_per_core = 0; if (core_group_1.core_coord_in_core_ranges(core)) { num_rows_per_core = num_rows_per_core_group_1; @@ -126,53 +127,47 @@ operation::ProgramWithCallbacks reduce_multi_core_w(const Tensor &a, Tensor& out } else { TT_ASSERT(false, "Core not in specified core ranges"); } - uint32_t num_tensor_tiles_per_core = num_rows_per_core*Wt; + uint32_t num_tensor_tiles_per_core = num_rows_per_core * Wt; tt_metal::SetRuntimeArgs( - program, reader_kernel_id, core, + program, + reader_kernel_id, + core, { a.buffer()->address(), num_tensor_tiles_per_core, - num_tiles_read // tile index of row to start reading from - } - ); + num_tiles_read // tile index of row to start reading from + }); tt_metal::SetRuntimeArgs( - program, writer_kernel_id, core, + program, + writer_kernel_id, + core, { output.buffer()->address(), - num_tensor_tiles_per_core / out_dim_divider, // number of tiles to write - num_tiles_read / out_dim_divider // output tile start index - } - ); - num_tiles_read+=num_tensor_tiles_per_core; + num_tensor_tiles_per_core / out_dim_divider, // number of tiles to write + num_tiles_read / out_dim_divider // output tile start index + }); + num_tiles_read += num_tensor_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id, - writer_kernel_id, - num_cores, - num_cores_y - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - + auto override_runtime_args_callback = [reader_kernel_id, writer_kernel_id, cores]( + const Program &program, + const std::vector &input_buffers, + const std::vector &output_buffers) { auto src_dram_buffer = input_buffers.at(0); auto dst_dram_buffer = output_buffers.at(0); - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores; i++){ - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - + auto &reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto &writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + for (const auto &core : cores) { { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = reader_runtime_args_by_core[core.x][core.y]; runtime_args[0] = src_dram_buffer->address(); } { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = writer_runtime_args_by_core[core.x][core.y]; runtime_args[0] = dst_dram_buffer->address(); } } diff --git a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp index 19d6748b08e..3ff4fcdcb2c 100644 --- a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp @@ -308,8 +308,9 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( starting_idx_h = calculate_starting_idx_h(input_tensors.at(0), num_slices, runtime_slice_index); } + auto& runtime_args_by_core = GetRuntimeArgs(program, unary_reader_kernel_id); for (const auto& core : cores) { - auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); + auto& runtime_args = runtime_args_by_core[core.x][core.y]; runtime_args[0] = src_buffer->address(); if (partial_op) { runtime_args[7] = starting_idx_h; diff --git a/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp index 36cdfb9986c..33a6774e35b 100644 --- a/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp @@ -104,8 +104,9 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T uint32_t ncores_x = grid_size.x; uint32_t tile_start_id = 0; uint32_t row_start_id = 0; + const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); for (uint32_t i = 0; i < ncores_full; ++i) { - CoreCoord core = {i % ncores_x, i / ncores_x}; + const CoreCoord& core = cores[i]; // reader runtime args vector reader_rt_args = { @@ -134,7 +135,7 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T } if (has_cliff) { // the last core is a cliff core with nblocks_per_core_cliff blocks - CoreCoord core = {ncores_full % ncores_x, ncores_full / ncores_x}; + const CoreCoord& core = cores.back(); // reader runtime args vector reader_rt_args = { @@ -159,28 +160,27 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); } - auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, - writer_kernel_id = unary_writer_kernel_id, - ncores = ncores, - ncores_x = ncores_x]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - for (uint32_t i = 0; i < ncores; ++i) { - CoreCoord core = {i % ncores_x, i / ncores_x}; - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); + auto override_runtime_args_callback = + [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = dst_buffer->address(); + } } - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - } - }; + }; return {std::move(program), override_runtime_args_callback}; } @@ -256,7 +256,7 @@ operation::ProgramWithCallbacks tilize_multi_core_sharded(const Tensor& input, T tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, {num_tiles_per_shard}); auto override_runtime_arguments_callback = - [unary_reader_kernel_id, unary_writer_kernel_id, cb_src0, cb_output, num_cores, num_cores_x]( + [unary_reader_kernel_id, unary_writer_kernel_id, cb_src0, cb_output]( const void* operation, Program& program, const std::vector& input_tensors, @@ -383,7 +383,9 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved( uint32_t row_start_id = 0; uint32_t ncores_x = grid_size.x; + const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); for (uint32_t i = 0; i < ncores; ++i) { + const auto& core = cores[i]; const std::vector& assignment = core_assignments.at(i); // reader runtime args @@ -412,36 +414,33 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved( // writer runtime args vector writer_rt_args = {dst_buffer->address(), num_tiles_per_core, tile_start_id}; - CoreCoord core = {i % ncores_x, i / ncores_x}; - SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); tile_start_id += num_tiles_per_core; } - auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, - writer_kernel_id = unary_writer_kernel_id, - ncores = ncores, - ncores_x = ncores_x]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - for (uint32_t i = 0; i < ncores; ++i) { - CoreCoord core = {i % ncores_x, i / ncores_x}; - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); + auto override_runtime_args_callback = + [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = dst_buffer->address(); + } } - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - } - }; + }; return {std::move(program), override_runtime_args_callback}; } @@ -583,9 +582,6 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_sharded( auto src_buffer = input_tensors.at(0).buffer(); auto dst_buffer = output_tensors.at(0).buffer(); - bool src_sharded = input_tensors.at(0).memory_config().is_sharded(); - bool out_sharded = output_tensors.at(0).memory_config().is_sharded(); - UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); }; diff --git a/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp index 03b36698c97..c406e856638 100644 --- a/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp @@ -412,8 +412,9 @@ operation::ProgramWithCallbacks untilize_multi_core( if (src_sharded) { UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); } else { + auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); for (const CoreCoord& core : cores_with_rtargs) { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto& runtime_args = runtime_args_by_core[core.x][core.y]; runtime_args[0] = src_buffer->address(); } } @@ -421,8 +422,9 @@ operation::ProgramWithCallbacks untilize_multi_core( if (out_sharded) { UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); } else { + auto& runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); for (const CoreCoord& core : cores_with_rtargs) { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto& runtime_args = runtime_args_by_core[core.x][core.y]; runtime_args[0] = dst_buffer->address(); } } @@ -546,7 +548,9 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( uint32_t row_start_id = 0; uint32_t ncores_x = grid_size.x; + const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); for (uint32_t i = 0; i < ncores; ++i) { + const auto& core = cores[i]; const std::vector& assignment = core_assignments.at(i); // writer runtime args @@ -574,36 +578,34 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( // reader runtime args vector reader_rt_args = {src0_buffer->address(), num_tiles_per_core, tile_start_id}; - CoreCoord core = {i % ncores_x, i / ncores_x}; - SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); tile_start_id += num_tiles_per_core; } - auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, - writer_kernel_id = unary_writer_kernel_id, - ncores = ncores, - ncores_x = ncores_x]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - for (uint32_t i = 0; i < ncores; ++i) { - CoreCoord core = {i % ncores_x, i / ncores_x}; - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - } - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); + auto override_runtime_args_callback = + [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = dst_buffer->address(); + } } - } - }; + }; return {std::move(program), override_runtime_args_callback}; } @@ -862,8 +864,9 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded( if (out_sharded) { UpdateDynamicCircularBufferAddress(program, cb_sharded_output, *dst_buffer); } else { + auto& runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); for (const CoreCoord& core : cores) { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto& runtime_args = runtime_args_by_core[core.x][core.y]; runtime_args[0] = dst_buffer->address(); } } diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 115b29412f6..1e4523f3dc8 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -2,24 +2,23 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tt_metal/detail/tt_metal.hpp" + #include #include #include #include -#include #include +#include +#include "dev_msgs.h" #include "impl/allocator/allocator.hpp" -#include "impl/dispatch/command_queue.hpp" -#include "tt_metal/host_api.hpp" #include "impl/debug/dprint_server.hpp" -#include "dev_msgs.h" - +#include "impl/dispatch/command_queue.hpp" #include "tools/profiler/profiler.hpp" -#include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/detail/program.hpp" +#include "tt_metal/host_api.hpp" #include "tt_metal/impl/trace/trace.hpp" - #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" namespace tt { @@ -28,7 +27,8 @@ namespace tt_metal { namespace { -void ConfigureKernelGroup(const Program &program, const KernelGroup *kernel_group, Device *device, const CoreCoord &logical_core) { +void ConfigureKernelGroup( + const Program &program, const KernelGroup *kernel_group, Device *device, const CoreCoord &logical_core) { if (kernel_group->compute_id.has_value()) { detail::GetKernel(program, kernel_group->compute_id.value())->configure(device, logical_core); } @@ -79,56 +79,82 @@ std::optional get_semaphore_address(const Program &program, const Core return address; } - -inline void SetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &c, const std::vector &runtime_args) -{ +inline void SetRuntimeArgs( + const Program &program, KernelHandle kernel_id, const CoreCoord &c, const std::vector &runtime_args) { if (runtime_args.size() != 0) { detail::GetKernel(program, kernel_id)->set_runtime_args(c, runtime_args); } } +inline void SetRuntimeArgs( + const Program &program, + KernelHandle kernel_id, + const CoreRange &core_range, + const std::vector &runtime_args) { + if (runtime_args.size() != 0) { + auto kernel = detail::GetKernel(program, kernel_id); + for (auto x = core_range.start.x; x <= core_range.end.x; ++x) { + for (auto y = core_range.start.y; y <= core_range.end.y; ++y) { + kernel->set_runtime_args(CoreCoord(x, y), runtime_args); + } + } + } +} -inline void SetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreRange &core_range, const std::vector &runtime_args) -{ +inline void SetRuntimeArgs( + const Program &program, + KernelHandle kernel_id, + const CoreRangeSet &core_range_set, + const std::vector &runtime_args) { if (runtime_args.size() != 0) { - for (auto x = core_range.start.x; x <= core_range.end.x; x++) { - for (auto y = core_range.start.y; y <= core_range.end.y; y++) { - // TODO: maybe directly update command queue - SetRuntimeArgs(program, kernel_id, CoreCoord(x,y), runtime_args); + auto kernel = detail::GetKernel(program, kernel_id); + for (const auto &core_range : core_range_set.ranges()) { + for (auto x = core_range.start.x; x <= core_range.end.x; ++x) { + for (auto y = core_range.start.y; y <= core_range.end.y; ++y) { + kernel->set_runtime_args(CoreCoord(x, y), runtime_args); + } } } } } -inline void SetRuntimeArgs(CommandQueue& cq, const std::shared_ptr kernel, const std::variant &core_spec, std::shared_ptr runtime_args, bool blocking) { +inline void SetRuntimeArgs( + CommandQueue &cq, + const std::shared_ptr kernel, + const std::variant &core_spec, + std::shared_ptr runtime_args, + bool blocking) { // SetRuntimeArgs API for Async CQ Mode - std::visit([&](auto&& core_spec) { + std::visit( + [&](auto &&core_spec) { using T = std::decay_t; if constexpr (std::is_same_v) { EnqueueSetRuntimeArgs(cq, kernel, core_spec, runtime_args, blocking); - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { for (auto x = core_spec.start.x; x <= core_spec.end.x; x++) { for (auto y = core_spec.start.y; y <= core_spec.end.y; y++) { EnqueueSetRuntimeArgs(cq, kernel, CoreCoord(x, y), runtime_args, blocking); } } - } - else if constexpr (std::is_same_v) { - for (const auto& core_range : core_spec.ranges()) { + } else if constexpr (std::is_same_v) { + for (const auto &core_range : core_spec.ranges()) { for (auto x = core_range.start.x; x <= core_range.end.x; x++) { for (auto y = core_range.start.y; y <= core_range.end.y; y++) { - EnqueueSetRuntimeArgs(cq, kernel, CoreCoord(x, y), runtime_args, blocking); + EnqueueSetRuntimeArgs(cq, kernel, CoreCoord(x, y), runtime_args, blocking); } } } } }, - core_spec - ); + core_spec); } -inline void SetRuntimeArgs(CommandQueue& cq, const std::shared_ptr kernel, const std::vector< CoreCoord > & core_spec, const std::vector> runtime_args, bool blocking) { +inline void SetRuntimeArgs( + CommandQueue &cq, + const std::shared_ptr kernel, + const std::vector &core_spec, + const std::vector> runtime_args, + bool blocking) { // SetRuntimeArgs API for Async CQ Mode (support vector of runtime args) for (size_t i = 0; i < core_spec.size(); i++) { EnqueueSetRuntimeArgs(cq, kernel, core_spec[i], runtime_args[i], blocking); @@ -137,13 +163,13 @@ inline void SetRuntimeArgs(CommandQueue& cq, const std::shared_ptr kerne } // namespace -//#define DEBUG_PRINT_SHARD +// #define DEBUG_PRINT_SHARD namespace device_pool { // Definition of the global device vector - std::vector devices; +std::vector devices; -} // device_pool +} // namespace device_pool namespace detail { @@ -159,9 +185,9 @@ std::map CreateDevices( if (active_devices.find(mmio_device_id) == active_devices.end()) { for (const auto &mmio_controlled_device_id : tt::Cluster::instance().get_devices_controlled_by_mmio_device(mmio_device_id)) { - //if (mmio_controlled_device_id != mmio_device_id) { - // continue; - //} + // if (mmio_controlled_device_id != mmio_device_id) { + // continue; + // } Device *dev = new Device(mmio_controlled_device_id, num_hw_cqs, l1_small_size, l1_bank_remap); active_devices.insert({mmio_controlled_device_id, dev}); detail::InitDeviceProfiler(dev); @@ -180,332 +206,299 @@ void CloseDevices(std::map devices) { } } - void print_page(uint32_t dev_page_id, CoreCoord core, uint32_t host_page_id, CoreCoord noc_coordinates, uint32_t l1_address, uint32_t bank_id, std::vector page){ - std::cout << "dev_page_index " << dev_page_id << " on core " << core.str() << std::endl; - std::cout << "host_page_index " << host_page_id << std::endl; - std::cout << "noc coordinates " << noc_coordinates.str() << std::endl; - std::cout << "l1_address " << l1_address << std::endl; - std::cout << "bank id " << bank_id << std::endl; - - std::cout << "0x"; - for(auto entry: page){ - std::cout << std::hex << entry << std::dec; - } - std::cout << std::dec << std::endl; - } - - void WriteToDeviceSharded(const Buffer &buffer, const std::vector &host_buffer) { - uint32_t host_buffer_size_bytes = host_buffer.size() * sizeof(uint32_t); - TT_FATAL( - host_buffer_size_bytes <= buffer.size(), - "Bounds-Error -- Attempting to write {} bytes to a {} byte buffer", host_buffer_size_bytes, buffer.size()); - - uint32_t page_size = buffer.page_size(); - TT_ASSERT(buffer.size() % page_size == 0); - - static constexpr uint32_t bytes_per_page_entry = sizeof(uint32_t); - TT_ASSERT(page_size % bytes_per_page_entry == 0); - uint32_t num_entries_per_page = page_size / bytes_per_page_entry; - - auto device = buffer.device(); - - TT_ASSERT(buffer.is_l1(), "Only L1 Buffers support sharding"); - - auto buffer_page_mapping = generate_buffer_page_mapping(buffer); - auto total_pages = buffer.num_pages(); - for(int host_page_id=0; host_page_idbank_ids_from_logical_core(buffer.buffer_type(), core)[0]; - auto absolute_address = buffer.sharded_page_address(bank_id, dev_page_id); - auto data_index = host_page_id * num_entries_per_page; - std::vector page; - page.insert( - page.end(), host_buffer.begin() + data_index, host_buffer.begin() + data_index + num_entries_per_page); - - auto noc_coordinates = buffer.noc_coordinates(bank_id); - llrt::write_hex_vec_to_core(device->id(), noc_coordinates, page, absolute_address); +void print_page( + uint32_t dev_page_id, + CoreCoord core, + uint32_t host_page_id, + CoreCoord noc_coordinates, + uint32_t l1_address, + uint32_t bank_id, + std::vector page) { + std::cout << "dev_page_index " << dev_page_id << " on core " << core.str() << std::endl; + std::cout << "host_page_index " << host_page_id << std::endl; + std::cout << "noc coordinates " << noc_coordinates.str() << std::endl; + std::cout << "l1_address " << l1_address << std::endl; + std::cout << "bank id " << bank_id << std::endl; + + std::cout << "0x"; + for (auto entry : page) { + std::cout << std::hex << entry << std::dec; + } + std::cout << std::dec << std::endl; +} + +void WriteToDeviceSharded(const Buffer &buffer, const std::vector &host_buffer) { + uint32_t host_buffer_size_bytes = host_buffer.size() * sizeof(uint32_t); + TT_FATAL( + host_buffer_size_bytes <= buffer.size(), + "Bounds-Error -- Attempting to write {} bytes to a {} byte buffer", + host_buffer_size_bytes, + buffer.size()); + + uint32_t page_size = buffer.page_size(); + TT_ASSERT(buffer.size() % page_size == 0); + + static constexpr uint32_t bytes_per_page_entry = sizeof(uint32_t); + TT_ASSERT(page_size % bytes_per_page_entry == 0); + uint32_t num_entries_per_page = page_size / bytes_per_page_entry; + + auto device = buffer.device(); + + TT_ASSERT(buffer.is_l1(), "Only L1 Buffers support sharding"); + + auto buffer_page_mapping = generate_buffer_page_mapping(buffer); + auto total_pages = buffer.num_pages(); + for (int host_page_id = 0; host_page_id < total_pages; host_page_id++) { + auto dev_page_id = buffer_page_mapping.host_page_to_dev_page_mapping_[host_page_id]; + auto core = buffer_page_mapping.all_cores_[buffer_page_mapping.dev_page_to_core_mapping_[dev_page_id]]; + auto bank_id = device->bank_ids_from_logical_core(buffer.buffer_type(), core)[0]; + auto absolute_address = buffer.sharded_page_address(bank_id, dev_page_id); + auto data_index = host_page_id * num_entries_per_page; + std::vector page; + page.insert( + page.end(), host_buffer.begin() + data_index, host_buffer.begin() + data_index + num_entries_per_page); + + auto noc_coordinates = buffer.noc_coordinates(bank_id); + llrt::write_hex_vec_to_core(device->id(), noc_coordinates, page, absolute_address); + } +} + +void WriteToDeviceInterleavedContiguous(const Buffer &buffer, const std::vector &host_buffer) { + uint32_t host_buffer_size_bytes = host_buffer.size() * sizeof(uint32_t); + TT_FATAL( + host_buffer_size_bytes <= buffer.size(), + "Bounds-Error -- Attempting to write {} bytes to a {} byte buffer", + host_buffer_size_bytes, + buffer.size()); + + uint32_t page_size = buffer.page_size(); + TT_FATAL(buffer.size() % page_size == 0); + uint32_t num_pages = buffer.size() / page_size; + + static constexpr uint32_t bytes_per_page_entry = sizeof(uint32_t); + TT_FATAL(page_size % bytes_per_page_entry == 0); + uint32_t num_entries_per_page = page_size / bytes_per_page_entry; + + auto device = buffer.device(); + auto num_banks = device->num_banks(buffer.buffer_type()); + uint32_t bank_index = 0; + int data_index = 0; + for (int page_index = 0; page_index < num_pages; page_index++) { + auto absolute_address = buffer.page_address(bank_index, page_index); + std::vector page; + page.insert( + page.end(), host_buffer.begin() + data_index, host_buffer.begin() + data_index + num_entries_per_page); + switch (buffer.buffer_type()) { + case BufferType::DRAM: { + auto dram_channel = buffer.dram_channel_from_bank_id(bank_index); + tt::Cluster::instance().write_dram_vec( + page, tt_target_dram{device->id(), dram_channel, 0}, absolute_address); + } break; + case BufferType::L1: // fallthrough + case BufferType::L1_SMALL: { + auto noc_coordinates = buffer.noc_coordinates(bank_index); + llrt::write_hex_vec_to_core(device->id(), noc_coordinates, page, absolute_address); + } break; + default: TT_FATAL(false && "Unsupported buffer type to write to device!"); } - + bank_index = (bank_index + 1) % num_banks; + data_index += num_entries_per_page; } +} +void WriteToDevice(const Buffer &buffer, const std::vector &host_buffer) { + ZoneScoped; + if (buffer.buffer_layout() == TensorMemoryLayout::INTERLEAVED || + buffer.buffer_layout() == TensorMemoryLayout::SINGLE_BANK) { + WriteToDeviceInterleavedContiguous(buffer, host_buffer); + } else if (is_sharded(buffer.buffer_layout())) { + WriteToDeviceSharded(buffer, host_buffer); + } else { + TT_ASSERT(false && "Unsupported buffer layout"); + } +} +void WriteToBuffer(std::shared_ptr buffer, const std::vector &host_buffer) { + WriteToBuffer(*buffer, host_buffer); +} - void WriteToDeviceInterleavedContiguous(const Buffer &buffer, const std::vector &host_buffer) { - - uint32_t host_buffer_size_bytes = host_buffer.size() * sizeof(uint32_t); - TT_FATAL( - host_buffer_size_bytes <= buffer.size(), - "Bounds-Error -- Attempting to write {} bytes to a {} byte buffer", host_buffer_size_bytes, buffer.size()); - - uint32_t page_size = buffer.page_size(); - TT_FATAL(buffer.size() % page_size == 0); - uint32_t num_pages = buffer.size() / page_size; - - static constexpr uint32_t bytes_per_page_entry = sizeof(uint32_t); - TT_FATAL(page_size % bytes_per_page_entry == 0); - uint32_t num_entries_per_page = page_size / bytes_per_page_entry; - - auto device = buffer.device(); - auto num_banks = device->num_banks(buffer.buffer_type()); - uint32_t bank_index = 0; - int data_index = 0; - for (int page_index = 0; page_index < num_pages; page_index++) { - auto absolute_address = buffer.page_address(bank_index, page_index); - std::vector page; - page.insert( - page.end(), host_buffer.begin() + data_index, host_buffer.begin() + data_index + num_entries_per_page); - switch (buffer.buffer_type()) { - case BufferType::DRAM: { - auto dram_channel = buffer.dram_channel_from_bank_id(bank_index); - tt::Cluster::instance().write_dram_vec(page, tt_target_dram{device->id(), dram_channel, 0}, absolute_address); - } break; - case BufferType::L1: // fallthrough - case BufferType::L1_SMALL: { - auto noc_coordinates = buffer.noc_coordinates(bank_index); - llrt::write_hex_vec_to_core(device->id(), noc_coordinates, page, absolute_address); - } break; - default: TT_FATAL(false && "Unsupported buffer type to write to device!"); - } - - bank_index = (bank_index + 1) % num_banks; - data_index += num_entries_per_page; - } +void WriteToBuffer(const Buffer &buffer, const std::vector &host_buffer) { + switch (buffer.buffer_type()) { + case BufferType::DRAM: // fallthrough + case BufferType::L1: // fallthrough + case BufferType::L1_SMALL: { + WriteToDevice(buffer, host_buffer); + } break; + case BufferType::SYSTEM_MEMORY: { + TT_FATAL(false && "Writing to host memory is unsupported!"); + } break; + default: TT_FATAL(false && "Unsupported buffer type!"); } +} - void WriteToDevice(const Buffer &buffer, const std::vector &host_buffer) { - ZoneScoped; - if(buffer.buffer_layout() == TensorMemoryLayout::INTERLEAVED || buffer.buffer_layout() == TensorMemoryLayout::SINGLE_BANK){ - WriteToDeviceInterleavedContiguous(buffer, host_buffer); - } - else if(is_sharded(buffer.buffer_layout())){ - WriteToDeviceSharded(buffer, host_buffer); - } - else{ - TT_ASSERT(false && "Unsupported buffer layout"); - } - } +void ReadFromDeviceInterleavedContiguous(const Buffer &buffer, std::vector &host_buffer) { + host_buffer.clear(); // overwrite the data + uint32_t page_size = buffer.page_size(); + TT_FATAL(buffer.size() % page_size == 0); + uint32_t num_pages = buffer.size() / page_size; - void WriteToBuffer( std::shared_ptr buffer, const std::vector &host_buffer){ - WriteToBuffer ( *buffer, host_buffer ); - } + auto device = buffer.device(); + auto num_banks = device->num_banks(buffer.buffer_type()); - void WriteToBuffer(const Buffer &buffer, const std::vector &host_buffer) { + uint32_t bank_index = 0; + for (int page_index = 0; page_index < num_pages; page_index++) { + auto absolute_address = buffer.page_address(bank_index, page_index); + std::vector page; switch (buffer.buffer_type()) { - case BufferType::DRAM: // fallthrough - case BufferType::L1: // fallthrough - case BufferType::L1_SMALL: { - WriteToDevice(buffer, host_buffer); + case BufferType::DRAM: { + auto dram_channel = buffer.dram_channel_from_bank_id(bank_index); + tt::Cluster::instance().read_dram_vec( + page, page_size, tt_target_dram{device->id(), dram_channel, 0}, absolute_address); } break; - case BufferType::SYSTEM_MEMORY: { - TT_FATAL(false && "Writing to host memory is unsupported!"); + case BufferType::L1: // fallthrough + case BufferType::L1_SMALL: { + auto noc_coordinates = buffer.noc_coordinates(bank_index); + page = llrt::read_hex_vec_from_core(device->id(), noc_coordinates, absolute_address, page_size); } break; - default: TT_FATAL(false && "Unsupported buffer type!"); + default: TT_FATAL(false && "Unsupported buffer type to write to device!"); } - } - void ReadFromDeviceInterleavedContiguous(const Buffer &buffer, std::vector &host_buffer) { - - host_buffer.clear(); // overwrite the data - uint32_t page_size = buffer.page_size(); - TT_FATAL(buffer.size() % page_size == 0); - uint32_t num_pages = buffer.size() / page_size; - - auto device = buffer.device(); - auto num_banks = device->num_banks(buffer.buffer_type()); - - uint32_t bank_index = 0; - for (int page_index = 0; page_index < num_pages; page_index++) { - auto absolute_address = buffer.page_address(bank_index, page_index); - std::vector page; - switch (buffer.buffer_type()) { - case BufferType::DRAM: { - auto dram_channel = buffer.dram_channel_from_bank_id(bank_index); - tt::Cluster::instance().read_dram_vec(page, page_size, tt_target_dram{device->id(), dram_channel, 0}, absolute_address); - } break; - case BufferType::L1: // fallthrough - case BufferType::L1_SMALL: { - auto noc_coordinates = buffer.noc_coordinates(bank_index); - page = llrt::read_hex_vec_from_core(device->id(), noc_coordinates, absolute_address, page_size); - } break; - default: TT_FATAL(false && "Unsupported buffer type to write to device!"); - } - - // Copy page into host buffer - for (uint32_t entry : page) { - host_buffer.push_back(entry); - } - - bank_index = (bank_index + 1) % num_banks; + // Copy page into host buffer + for (uint32_t entry : page) { + host_buffer.push_back(entry); } + bank_index = (bank_index + 1) % num_banks; } +} - void read_pages_to_host_helper( - Device * device, - const Buffer & dev_buffer, - std::vector &host_buffer, - const uint32_t & page_size, - const uint32_t & host_page_id, - const uint32_t & dev_page_id, - const uint32_t & bank_id - ){ - auto absolute_address = dev_buffer.sharded_page_address(bank_id, dev_page_id); - auto noc_coordinates = dev_buffer.noc_coordinates(bank_id); - - uint32_t num_entries_per_page = page_size/sizeof(uint32_t); - auto page = llrt::read_hex_vec_from_core(device->id(), noc_coordinates, absolute_address, page_size); - uint32_t host_buffer_start = host_page_id * num_entries_per_page; - uint32_t dev_page_index = 0; - for(uint32_t host_buffer_index = host_buffer_start; host_buffer_index < host_buffer_start + num_entries_per_page; host_buffer_index++){ - host_buffer[host_buffer_index] = page[dev_page_index]; - dev_page_index++; - } +void read_pages_to_host_helper( + Device *device, + const Buffer &dev_buffer, + std::vector &host_buffer, + const uint32_t &page_size, + const uint32_t &host_page_id, + const uint32_t &dev_page_id, + const uint32_t &bank_id) { + auto absolute_address = dev_buffer.sharded_page_address(bank_id, dev_page_id); + auto noc_coordinates = dev_buffer.noc_coordinates(bank_id); + uint32_t num_entries_per_page = page_size / sizeof(uint32_t); + auto page = llrt::read_hex_vec_from_core(device->id(), noc_coordinates, absolute_address, page_size); + uint32_t host_buffer_start = host_page_id * num_entries_per_page; + uint32_t dev_page_index = 0; + for (uint32_t host_buffer_index = host_buffer_start; host_buffer_index < host_buffer_start + num_entries_per_page; + host_buffer_index++) { + host_buffer[host_buffer_index] = page[dev_page_index]; + dev_page_index++; } +} - void ReadFromDeviceSharded(const Buffer &buffer, std::vector &host_buffer, bool shard_order){ - - TensorMemoryLayout buffer_layout = buffer.buffer_layout(); - - auto device = buffer.device(); - #ifdef DEBUG_PRINT_SHARD - std::cout << "Reading From Device Height Sharded " << std::endl; - #endif - - - int output_page_index = 0; - auto total_pages = buffer.num_dev_pages(); - uint32_t page_size = buffer.page_size(); - uint32_t bytes_per_page_entry = sizeof(uint32_t); - uint32_t num_entries_per_page = page_size / bytes_per_page_entry; +void ReadFromDeviceSharded(const Buffer &buffer, std::vector &host_buffer, bool shard_order) { + TensorMemoryLayout buffer_layout = buffer.buffer_layout(); - host_buffer = std::vector(total_pages * num_entries_per_page); + auto device = buffer.device(); +#ifdef DEBUG_PRINT_SHARD + std::cout << "Reading From Device Height Sharded " << std::endl; +#endif - auto buffer_page_mapping = generate_buffer_page_mapping(buffer); - for(int dev_page_id=0; dev_page_idbank_ids_from_logical_core(buffer.buffer_type(), core)[0]; - auto host_page_id = buffer_page_mapping.dev_page_to_host_page_mapping_[dev_page_id]; - if(host_page_id.has_value()) { - if(!shard_order){ - read_pages_to_host_helper( - device, - buffer, - host_buffer, - page_size, - host_page_id.value(), - dev_page_id, - bank_id - ); - } - else{ - read_pages_to_host_helper( - device, - buffer, - host_buffer, - page_size, - dev_page_id, - dev_page_id, - bank_id - ); - } + int output_page_index = 0; + auto total_pages = buffer.num_dev_pages(); + uint32_t page_size = buffer.page_size(); + uint32_t bytes_per_page_entry = sizeof(uint32_t); + uint32_t num_entries_per_page = page_size / bytes_per_page_entry; + + host_buffer = std::vector(total_pages * num_entries_per_page); + + auto buffer_page_mapping = generate_buffer_page_mapping(buffer); + for (int dev_page_id = 0; dev_page_id < total_pages; dev_page_id++) { + auto core = buffer_page_mapping.all_cores_[buffer_page_mapping.dev_page_to_core_mapping_[dev_page_id]]; + auto bank_id = device->bank_ids_from_logical_core(buffer.buffer_type(), core)[0]; + auto host_page_id = buffer_page_mapping.dev_page_to_host_page_mapping_[dev_page_id]; + if (host_page_id.has_value()) { + if (!shard_order) { + read_pages_to_host_helper( + device, buffer, host_buffer, page_size, host_page_id.value(), dev_page_id, bank_id); + } else { + read_pages_to_host_helper(device, buffer, host_buffer, page_size, dev_page_id, dev_page_id, bank_id); } } - } +} - - void ReadFromDevice(const Buffer &buffer, std::vector &host_buffer, bool shard_order) { - ZoneScoped; - host_buffer.clear(); // overwrite the data - if(buffer.buffer_layout() == TensorMemoryLayout::INTERLEAVED - || buffer.buffer_layout() == TensorMemoryLayout::SINGLE_BANK){ - ReadFromDeviceInterleavedContiguous(buffer, host_buffer); - } - else if(is_sharded(buffer.buffer_layout())){ - TT_ASSERT(buffer.is_l1(), "Only L1 Buffers support sharding"); - ReadFromDeviceSharded(buffer, host_buffer, shard_order); - } - else{ - TT_ASSERT(false && "Unsupported buffer layout"); - } +void ReadFromDevice(const Buffer &buffer, std::vector &host_buffer, bool shard_order) { + ZoneScoped; + host_buffer.clear(); // overwrite the data + if (buffer.buffer_layout() == TensorMemoryLayout::INTERLEAVED || + buffer.buffer_layout() == TensorMemoryLayout::SINGLE_BANK) { + ReadFromDeviceInterleavedContiguous(buffer, host_buffer); + } else if (is_sharded(buffer.buffer_layout())) { + TT_ASSERT(buffer.is_l1(), "Only L1 Buffers support sharding"); + ReadFromDeviceSharded(buffer, host_buffer, shard_order); + } else { + TT_ASSERT(false && "Unsupported buffer layout"); } +} - void ReadFromBuffer(std::shared_ptr buffer, std::vector &host_buffer, bool shard_order) - { - ReadFromBuffer(*buffer, host_buffer, shard_order); - } +void ReadFromBuffer(std::shared_ptr buffer, std::vector &host_buffer, bool shard_order) { + ReadFromBuffer(*buffer, host_buffer, shard_order); +} - void ReadFromBuffer(const Buffer &buffer, std::vector &host_buffer, bool shard_order) { - Device *device = buffer.device(); - switch (buffer.buffer_type()) { - case BufferType::DRAM: - case BufferType::L1: // fallthrough - case BufferType::L1_SMALL: { - if (buffer.buffer_type() == BufferType::DRAM) { - tt::Cluster::instance().dram_barrier(device->id()); - } else { - tt::Cluster::instance().l1_barrier(device->id()); - } - ReadFromDevice(buffer, host_buffer, shard_order); - } break; - case BufferType::SYSTEM_MEMORY: { - TT_FATAL(false && "Reading from host memory is unsupported!"); - } break; - default: TT_FATAL(false && "Unsupported buffer type!"); - } +void ReadFromBuffer(const Buffer &buffer, std::vector &host_buffer, bool shard_order) { + Device *device = buffer.device(); + switch (buffer.buffer_type()) { + case BufferType::DRAM: + case BufferType::L1: // fallthrough + case BufferType::L1_SMALL: { + if (buffer.buffer_type() == BufferType::DRAM) { + tt::Cluster::instance().dram_barrier(device->id()); + } else { + tt::Cluster::instance().l1_barrier(device->id()); + } + ReadFromDevice(buffer, host_buffer, shard_order); + } break; + case BufferType::SYSTEM_MEMORY: { + TT_FATAL(false && "Reading from host memory is unsupported!"); + } break; + default: TT_FATAL(false && "Unsupported buffer type!"); } +} - void ReadShard(const Buffer &buffer, std::vector &host_buffer, const uint32_t & core_id) { - - Device *device = buffer.device(); - TT_ASSERT(is_sharded(buffer.buffer_layout())); - host_buffer.clear(); // overwrite the data - - uint32_t num_entries_per_page = buffer.page_size() / sizeof(uint32_t); - uint32_t num_entries_per_shard = num_entries_per_page * buffer.shard_spec().size(); - host_buffer = std::vector(num_entries_per_shard); - - - std::vector page_ids; - auto buffer_page_mapping = generate_buffer_page_mapping(buffer); - for(uint32_t i=0; i &host_buffer, const uint32_t &core_id) { + Device *device = buffer.device(); + TT_ASSERT(is_sharded(buffer.buffer_layout())); + host_buffer.clear(); // overwrite the data - uint32_t host_page_id = 0; - for(auto dev_page_id: page_ids){ - auto core = buffer_page_mapping.all_cores_[buffer_page_mapping.dev_page_to_core_mapping_[dev_page_id]]; - auto bank_id = device->bank_ids_from_logical_core(buffer.buffer_type(), core)[0]; - read_pages_to_host_helper( - device, - buffer, - host_buffer, - buffer.page_size(), - host_page_id, - dev_page_id, - bank_id - ); - host_page_id++; + uint32_t num_entries_per_page = buffer.page_size() / sizeof(uint32_t); + uint32_t num_entries_per_shard = num_entries_per_page * buffer.shard_spec().size(); + host_buffer = std::vector(num_entries_per_shard); + std::vector page_ids; + auto buffer_page_mapping = generate_buffer_page_mapping(buffer); + for (uint32_t i = 0; i < buffer_page_mapping.dev_page_to_core_mapping_.size(); i++) { + if (buffer_page_mapping.dev_page_to_core_mapping_[i] == core_id) { + page_ids.push_back(i); } - - } - void LaunchProgram(Device *device, std::shared_ptr program, bool wait_until_cores_done){ - LaunchProgram(device, *program, wait_until_cores_done); + uint32_t host_page_id = 0; + for (auto dev_page_id : page_ids) { + auto core = buffer_page_mapping.all_cores_[buffer_page_mapping.dev_page_to_core_mapping_[dev_page_id]]; + auto bank_id = device->bank_ids_from_logical_core(buffer.buffer_type(), core)[0]; + read_pages_to_host_helper(device, buffer, host_buffer, buffer.page_size(), host_page_id, dev_page_id, bank_id); + host_page_id++; } +} + +void LaunchProgram(Device *device, std::shared_ptr program, bool wait_until_cores_done) { + LaunchProgram(device, *program, wait_until_cores_done); +} - void LaunchProgram(Device *device, Program &program, bool wait_until_cores_done) { - {//Profiler scope start +void LaunchProgram(Device *device, Program &program, bool wait_until_cores_done) { + { // Profiler scope start ZoneScoped; - detail::DispatchStateCheck( false ); + detail::DispatchStateCheck(false); detail::CompileProgram(device, program); detail::WriteRuntimeArgsToDevice(device, program); detail::ConfigureDeviceWithProgram(device, program); @@ -531,188 +524,197 @@ void CloseDevices(std::map devices) { // Wait for all cores to be done llrt::internal_::wait_until_cores_done(device_id, RUN_MSG_GO, not_done_cores); } - }//Profiler scope end - if (wait_until_cores_done) { - DumpDeviceProfileResults(device, program); - } + } // Profiler scope end + if (wait_until_cores_done) { + DumpDeviceProfileResults(device, program); } +} - void WaitProgramDone(Device *device, Program &program) { - auto device_id = device->id(); - std::unordered_map> logical_cores_used_in_program = program.logical_cores(); - std::unordered_set not_done_cores; - for (const auto &[core_type, logical_cores] : logical_cores_used_in_program) { - for (const auto &logical_core : logical_cores) { - auto physical_core = device->physical_core_from_logical_core(logical_core, core_type); - not_done_cores.insert(physical_core); - } +void WaitProgramDone(Device *device, Program &program) { + auto device_id = device->id(); + std::unordered_map> logical_cores_used_in_program = program.logical_cores(); + std::unordered_set not_done_cores; + for (const auto &[core_type, logical_cores] : logical_cores_used_in_program) { + for (const auto &logical_core : logical_cores) { + auto physical_core = device->physical_core_from_logical_core(logical_core, core_type); + not_done_cores.insert(physical_core); } - // Wait for all cores to be done - llrt::internal_::wait_until_cores_done(device_id, RUN_MSG_GO, not_done_cores); - DumpDeviceProfileResults(device, program); - } + // Wait for all cores to be done + llrt::internal_::wait_until_cores_done(device_id, RUN_MSG_GO, not_done_cores); + DumpDeviceProfileResults(device, program); +} - bool ConfigureDeviceWithProgram(Device *device, Program &program, bool fd_bootloader_mode) { - ZoneScoped; - bool pass = true; - // This is function is shared between FD and SD. - // We call this function when initializing HW Command Queues (tracked as fd_bootloader_mode) for Fast Dispatch. - // Used to Launch programs for Slow dispatch. - bool using_fast_dispatch = fd_bootloader_mode; - detail::DispatchStateCheck( using_fast_dispatch ); - - auto device_id = device->id(); - - program.allocate_circular_buffers(); - detail::ValidateCircularBufferRegion(program, device); +bool ConfigureDeviceWithProgram(Device *device, Program &program, bool fd_bootloader_mode) { + ZoneScoped; + bool pass = true; + // This is function is shared between FD and SD. + // We call this function when initializing HW Command Queues (tracked as fd_bootloader_mode) for Fast Dispatch. + // Used to Launch programs for Slow dispatch. + bool using_fast_dispatch = fd_bootloader_mode; + detail::DispatchStateCheck(using_fast_dispatch); - std::unordered_map> logical_cores_used_in_program = program.logical_cores(); - for (const auto &[core_type, logical_cores] : logical_cores_used_in_program) { - for (const auto &logical_core : logical_cores) { - KernelGroup *kernel_group = program.kernels_on_core(logical_core, core_type); - CoreCoord physical_core = device->physical_core_from_logical_core(logical_core, core_type); - - ConfigureKernelGroup( - program, kernel_group, device, logical_core); - // TODO: add support for CB for ethernet cores - if (core_type == CoreType::WORKER) { - // CircularBufferConfigVec -- common across all kernels, so written once to the core - llrt::CircularBufferConfigVec circular_buffer_config_vec = - llrt::create_circular_buffer_config_vector(); - - auto cbs_on_core = program.circular_buffers_on_core(logical_core); - for (auto circular_buffer : cbs_on_core) { - for (uint32_t buffer_index : circular_buffer->buffer_indices()) { - llrt::set_config_for_circular_buffer( - circular_buffer_config_vec, - buffer_index, - circular_buffer->address(), - circular_buffer->size(), - circular_buffer->num_pages(buffer_index)); - } - } // PROF_END("CBS") + auto device_id = device->id(); - if (cbs_on_core.size()) { - llrt::write_circular_buffer_config_vector_to_core( - device_id, - physical_core, - circular_buffer_config_vec); + program.allocate_circular_buffers(); + detail::ValidateCircularBufferRegion(program, device); + + std::unordered_map> logical_cores_used_in_program = program.logical_cores(); + for (const auto &[core_type, logical_cores] : logical_cores_used_in_program) { + for (const auto &logical_core : logical_cores) { + KernelGroup *kernel_group = program.kernels_on_core(logical_core, core_type); + CoreCoord physical_core = device->physical_core_from_logical_core(logical_core, core_type); + + ConfigureKernelGroup(program, kernel_group, device, logical_core); + // TODO: add support for CB for ethernet cores + if (core_type == CoreType::WORKER) { + // CircularBufferConfigVec -- common across all kernels, so written once to the core + llrt::CircularBufferConfigVec circular_buffer_config_vec = llrt::create_circular_buffer_config_vector(); + + auto cbs_on_core = program.circular_buffers_on_core(logical_core); + for (auto circular_buffer : cbs_on_core) { + for (uint32_t buffer_index : circular_buffer->buffer_indices()) { + llrt::set_config_for_circular_buffer( + circular_buffer_config_vec, + buffer_index, + circular_buffer->address(), + circular_buffer->size(), + circular_buffer->num_pages(buffer_index)); } + } // PROF_END("CBS") + if (cbs_on_core.size()) { + llrt::write_circular_buffer_config_vector_to_core( + device_id, physical_core, circular_buffer_config_vec); } - program.init_semaphores(*device, logical_core, core_type); } + program.init_semaphores(*device, logical_core, core_type); } - - return pass; } + return pass; +} - // Return base address in L1 for Runtime Args given processor type (and eth mode in case of ERISC). - uint32_t GetL1ArgBaseAddr(std::shared_ptr kernel) { - - const RISCV &riscv = kernel->processor(); - uint32_t l1_arg_base = 0; +// Return base address in L1 for Runtime Args given processor type (and eth mode in case of ERISC). +uint32_t GetL1ArgBaseAddr(std::shared_ptr kernel) { + const RISCV &riscv = kernel->processor(); + uint32_t l1_arg_base = 0; - switch (riscv) { - case RISCV::BRISC: { - l1_arg_base = BRISC_L1_ARG_BASE; - } break; - case RISCV::NCRISC: { - l1_arg_base = NCRISC_L1_ARG_BASE; - } break; - case RISCV::ERISC: { - auto config = std::get(kernel->config()); - if (config.eth_mode == Eth::IDLE) { - l1_arg_base = IDLE_ERISC_L1_ARG_BASE; - } else { - l1_arg_base = eth_l1_mem::address_map::ERISC_L1_ARG_BASE; - } - } break; - case RISCV::COMPUTE: { - l1_arg_base = TRISC_L1_ARG_BASE; + switch (riscv) { + case RISCV::BRISC: { + l1_arg_base = BRISC_L1_ARG_BASE; + } break; + case RISCV::NCRISC: { + l1_arg_base = NCRISC_L1_ARG_BASE; + } break; + case RISCV::ERISC: { + auto config = std::get(kernel->config()); + if (config.eth_mode == Eth::IDLE) { + l1_arg_base = IDLE_ERISC_L1_ARG_BASE; + } else { + l1_arg_base = eth_l1_mem::address_map::ERISC_L1_ARG_BASE; } - break; - default: TT_THROW("Unsupported {} processor does not support runtime args", riscv); - } - return l1_arg_base; + } break; + case RISCV::COMPUTE: { + l1_arg_base = TRISC_L1_ARG_BASE; + } break; + default: TT_THROW("Unsupported {} processor does not support runtime args", riscv); } + return l1_arg_base; +} - void WriteRuntimeArgsToDevice(Device *device, const Program &program) { - ZoneScoped; - auto device_id = device->id(); - detail::DispatchStateCheck( false ); - - for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) { - const auto kernel = detail::GetKernel(program, kernel_id); - auto args_base_addr = detail::GetL1ArgBaseAddr(kernel); - - for (const auto &logical_core : kernel->cores_with_runtime_args()) { - auto physical_core = device->physical_core_from_logical_core(logical_core, kernel->get_kernel_core_type()); - const auto & rt_args = kernel->runtime_args(logical_core); - log_trace(tt::LogMetal, "{} - Writing {} unique rtargs to core {} (physical: {}) addr 0x{:x} => args: {}", - __FUNCTION__, rt_args.size(), logical_core.str(), physical_core.str(), args_base_addr, rt_args); - tt::llrt::write_hex_vec_to_core(device_id, physical_core, rt_args, args_base_addr); - } - - // Unicast common runtime args to all cores for kernel. Fast-Dispatch will multicast as perf opt. - const auto &common_rt_args = kernel->common_runtime_args(); - auto common_rt_args_offset = kernel->get_common_runtime_args_offset(); +void WriteRuntimeArgsToDevice(Device *device, const Program &program) { + ZoneScoped; + auto device_id = device->id(); + detail::DispatchStateCheck(false); + + for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) { + const auto kernel = detail::GetKernel(program, kernel_id); + auto args_base_addr = detail::GetL1ArgBaseAddr(kernel); + + for (const auto &logical_core : kernel->cores_with_runtime_args()) { + auto physical_core = device->physical_core_from_logical_core(logical_core, kernel->get_kernel_core_type()); + const auto &rt_args = kernel->runtime_args(logical_core); + log_trace( + tt::LogMetal, + "{} - Writing {} unique rtargs to core {} (physical: {}) addr 0x{:x} => args: {}", + __FUNCTION__, + rt_args.size(), + logical_core.str(), + physical_core.str(), + args_base_addr, + rt_args); + tt::llrt::write_hex_vec_to_core(device_id, physical_core, rt_args, args_base_addr); + } - if (common_rt_args.size() > 0) { - for (auto &core_range : kernel->logical_coreranges()) { - for (auto x = core_range.start.x; x <= core_range.end.x; x++) { - for (auto y = core_range.start.y; y <= core_range.end.y; y++) { - CoreCoord logical_core({x, y}); - auto physical_core = device->physical_core_from_logical_core(logical_core, kernel->get_kernel_core_type()); - const auto common_args_addr = args_base_addr + common_rt_args_offset; // Common args are placed after unique args per core. - log_trace(tt::LogMetal, "{} - Writing {} common rtargs to core {} (physical: {}) addr 0x{:x} => args: {}", - __FUNCTION__, common_rt_args.size(), logical_core.str(), physical_core.str(), common_args_addr, common_rt_args); - tt::llrt::write_hex_vec_to_core(device_id, physical_core, common_rt_args, common_args_addr); - } + // Unicast common runtime args to all cores for kernel. Fast-Dispatch will multicast as perf opt. + const auto &common_rt_args = kernel->common_runtime_args(); + auto common_rt_args_offset = kernel->get_common_runtime_args_offset(); + + if (common_rt_args.size() > 0) { + for (auto &core_range : kernel->logical_coreranges()) { + for (auto x = core_range.start.x; x <= core_range.end.x; x++) { + for (auto y = core_range.start.y; y <= core_range.end.y; y++) { + CoreCoord logical_core({x, y}); + auto physical_core = + device->physical_core_from_logical_core(logical_core, kernel->get_kernel_core_type()); + const auto common_args_addr = + args_base_addr + + common_rt_args_offset; // Common args are placed after unique args per core. + log_trace( + tt::LogMetal, + "{} - Writing {} common rtargs to core {} (physical: {}) addr 0x{:x} => args: {}", + __FUNCTION__, + common_rt_args.size(), + logical_core.str(), + physical_core.str(), + common_args_addr, + common_rt_args); + tt::llrt::write_hex_vec_to_core(device_id, physical_core, common_rt_args, common_args_addr); } } } } } +} - void CompileProgram(Device *device, Program &program){ - ZoneScoped; - program.compile(device); - } +void CompileProgram(Device *device, Program &program) { + ZoneScoped; + program.compile(device); +} - void AllocateBuffer(Buffer* buffer, bool bottom_up) { - detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - EnqueueAllocateBuffer(buffer->device()->command_queue(), buffer, bottom_up, false); - } +void AllocateBuffer(Buffer *buffer, bool bottom_up) { + detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); + EnqueueAllocateBuffer(buffer->device()->command_queue(), buffer, bottom_up, false); +} - void DeallocateBuffer(Buffer *buffer) { - detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - EnqueueDeallocateBuffer(buffer->device()->command_queue(), *(buffer->device()->allocator_), buffer->address(), buffer->buffer_type(), false); - } +void DeallocateBuffer(Buffer *buffer) { + detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); + EnqueueDeallocateBuffer( + buffer->device()->command_queue(), + *(buffer->device()->allocator_), + buffer->address(), + buffer->buffer_type(), + false); +} - void GetBufferAddress(const Buffer* buffer, uint32_t* address_on_host) { - detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - EnqueueGetBufferAddr(buffer->device()->command_queue(), address_on_host, buffer, false); - } +void GetBufferAddress(const Buffer *buffer, uint32_t *address_on_host) { + detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); + EnqueueGetBufferAddr(buffer->device()->command_queue(), address_on_host, buffer, false); +} - Device *GetDeviceHandle(chip_id_t device_id) { - ZoneScoped; - TT_ASSERT(device_id < device_pool::devices.size()); - TT_ASSERT(device_pool::devices[device_id] != nullptr); - return device_pool::devices[device_id]; - } +Device *GetDeviceHandle(chip_id_t device_id) { + ZoneScoped; + TT_ASSERT(device_id < device_pool::devices.size()); + TT_ASSERT(device_pool::devices[device_id] != nullptr); + return device_pool::devices[device_id]; +} - void DisableAllocs(Device *device) { - tt::tt_metal::allocator::disable_allocs(*(device->allocator_)); - } +void DisableAllocs(Device *device) { tt::tt_metal::allocator::disable_allocs(*(device->allocator_)); } - void EnableAllocs(Device *device) { - tt::tt_metal::allocator::enable_allocs(*(device->allocator_)); - } +void EnableAllocs(Device *device) { tt::tt_metal::allocator::enable_allocs(*(device->allocator_)); } -} // namespace detail +} // namespace detail size_t GetNumAvailableDevices() { #ifdef TT_METAL_VERSIM_DISABLED @@ -742,8 +744,7 @@ Device *CreateDevice( return dev; } -Device *CreateDeviceMinimal( - chip_id_t device_id) { +Device *CreateDeviceMinimal(chip_id_t device_id) { ZoneScoped; Device *dev = new Device(device_id, 1, DEFAULT_L1_SMALL_SIZE, {}, true); tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true); @@ -761,38 +762,37 @@ bool CloseDevice(Device *device) { return device->close(); } -Program CreateProgram(){ - return Program(); -} +Program CreateProgram() { return Program(); } KernelHandle CreateKernel( Program &program, const std::string &file_name, const std::variant &core_spec, const std::variant &config) { - return std::visit( [&](auto&& cfg) -> KernelHandle - { - CoreRangeSet core_ranges = detail::GetCoreRangeSet(core_spec); - std::shared_ptr kernel; - using T = std::decay_t; - if constexpr (std::is_same_v) { - detail::CheckDataMovementConfig(program, file_name, core_ranges); - kernel = std::make_shared(file_name, core_ranges, cfg); - return detail::AddKernel(program, kernel, CoreType::WORKER); - } - else if constexpr (std::is_same_v) { - kernel = std::make_shared(file_name, core_ranges, cfg); - return detail::AddKernel(program, kernel, CoreType::WORKER); - } else if constexpr (std::is_same_v) { - kernel = std::make_shared(file_name, core_ranges, cfg); - return detail::AddKernel(program, kernel, CoreType::ETH); - } - }, - config - ); -} - -CBHandle CreateCircularBuffer(Program &program, const std::variant &core_spec, const CircularBufferConfig &config) { + return std::visit( + [&](auto &&cfg) -> KernelHandle { + CoreRangeSet core_ranges = detail::GetCoreRangeSet(core_spec); + std::shared_ptr kernel; + using T = std::decay_t; + if constexpr (std::is_same_v) { + detail::CheckDataMovementConfig(program, file_name, core_ranges); + kernel = std::make_shared(file_name, core_ranges, cfg); + return detail::AddKernel(program, kernel, CoreType::WORKER); + } else if constexpr (std::is_same_v) { + kernel = std::make_shared(file_name, core_ranges, cfg); + return detail::AddKernel(program, kernel, CoreType::WORKER); + } else if constexpr (std::is_same_v) { + kernel = std::make_shared(file_name, core_ranges, cfg); + return detail::AddKernel(program, kernel, CoreType::ETH); + } + }, + config); +} + +CBHandle CreateCircularBuffer( + Program &program, + const std::variant &core_spec, + const CircularBufferConfig &config) { CoreRangeSet core_ranges = detail::GetCoreRangeSet(core_spec); return program.add_circular_buffer(core_ranges, config); } @@ -822,20 +822,23 @@ void UpdateDynamicCircularBufferAddress(Program &program, CBHandle cb_handle, co circular_buffer->assign_global_address(); } -uint32_t CreateSemaphore(Program &program, const std::variant &core_spec, uint32_t initial_value, CoreType core_type) { +uint32_t CreateSemaphore( + Program &program, + const std::variant &core_spec, + uint32_t initial_value, + CoreType core_type) { return std::visit( - [&](auto&& c) -> uint32_t - { + [&](auto &&c) -> uint32_t { using T = std::decay_t; CoreRangeSet crs({}); if constexpr (std::is_same_v) { crs = CoreRangeSet({c}); - } else{ + } else { crs = c; } std::optional address; TT_FATAL(crs.ranges().size() > 0, "Expecting a non-empty CoreRangeSet!"); - for (const auto& core_range : crs.ranges()) { + for (const auto &core_range : crs.ranges()) { CoreCoord start_core = core_range.start; CoreCoord end_core = core_range.end; TT_FATAL(start_core == end_core or start_core < end_core && "Invalid core range!"); @@ -872,73 +875,96 @@ std::shared_ptr CreateBuffer(const ShardedBufferConfig &config) { void DeallocateBuffer(Buffer &buffer) { buffer.deallocate(); } -void AssignGlobalBufferToProgram(std::shared_ptr buffer, std::variant, std::shared_ptr> program) { +void AssignGlobalBufferToProgram( + std::shared_ptr buffer, std::variant, std::shared_ptr> program) { detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - EnqueueAddBufferToProgram(buffer-> device()->command_queue(), buffer, program, false); + EnqueueAddBufferToProgram(buffer->device()->command_queue(), buffer, program, false); } -void SetRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::variant &core_spec, const std::vector &runtime_args) { +void SetRuntimeArgs( + const Program &program, + KernelHandle kernel_id, + const std::variant &core_spec, + const std::vector &runtime_args) { ZoneScoped; - TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - std::visit( - [&](auto&& core_spec) - { - using T = std::decay_t; - if constexpr (std::is_same_v || std::is_same_v ) { - SetRuntimeArgs(program, kernel_id, core_spec, runtime_args); - } - else if constexpr (std::is_same_v) { - for (const auto& core_range : core_spec.ranges()) { - SetRuntimeArgs(program, kernel_id, core_range, runtime_args); - } - } - }, - core_spec - ); -} - -void SetRuntimeArgs(const Program &program, KernelHandle kernel, const std::vector< CoreCoord > & core_spec, const std::vector< std::vector > &runtime_args) -{ + TT_FATAL( + not CommandQueue::async_mode_set(), + "This variant of SetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast " + "Dispatch."); + std::visit([&](auto &&core_spec) { SetRuntimeArgs(program, kernel_id, core_spec, runtime_args); }, core_spec); +} + +void SetRuntimeArgs( + const Program &program, + KernelHandle kernel, + const std::vector &core_spec, + const std::vector> &runtime_args) { ZoneScoped; - TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - TT_FATAL( core_spec.size() == runtime_args.size(), "Mistmatch between number of cores {} and number of runtime args {} getting updated", core_spec.size(), runtime_args.size()); + TT_FATAL( + not CommandQueue::async_mode_set(), + "This variant of SetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast " + "Dispatch."); + TT_FATAL( + core_spec.size() == runtime_args.size(), + "Mistmatch between number of cores {} and number of runtime args {} getting updated", + core_spec.size(), + runtime_args.size()); auto k = detail::GetKernel(program, kernel); - for (size_t i = 0; i < core_spec.size(); i++) - k->set_runtime_args(core_spec[i], runtime_args[i]); + for (size_t i = 0; i < core_spec.size(); i++) k->set_runtime_args(core_spec[i], runtime_args[i]); } -void SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::variant &core_spec, std::shared_ptr runtime_args) { +void SetRuntimeArgs( + Device *device, + const std::shared_ptr kernel, + const std::variant &core_spec, + std::shared_ptr runtime_args) { detail::DispatchStateCheck(not device->using_slow_dispatch()); SetRuntimeArgs(device->command_queue(), kernel, core_spec, runtime_args, false); } -void SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::vector< CoreCoord > & core_spec, const std::vector> runtime_args) { - TT_FATAL( core_spec.size() == runtime_args.size(), "Mismatch between number of cores {} and number of runtime args {} getting updated", core_spec.size(), runtime_args.size()); +void SetRuntimeArgs( + Device *device, + const std::shared_ptr kernel, + const std::vector &core_spec, + const std::vector> runtime_args) { + TT_FATAL( + core_spec.size() == runtime_args.size(), + "Mismatch between number of cores {} and number of runtime args {} getting updated", + core_spec.size(), + runtime_args.size()); detail::DispatchStateCheck(not device->using_slow_dispatch()); SetRuntimeArgs(device->command_queue(), kernel, core_spec, runtime_args, false); } - void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args) { ZoneScoped; - TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetCommonRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); + TT_FATAL( + not CommandQueue::async_mode_set(), + "This variant of SetCommonRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for " + "Fast Dispatch."); if (runtime_args.size() != 0) { detail::GetKernel(program, kernel_id)->set_common_runtime_args(runtime_args); } } -RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { - TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); +RuntimeArgsData &GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { + TT_FATAL( + not CommandQueue::async_mode_set(), + "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); return detail::GetKernel(program, kernel_id)->runtime_args_data(logical_core); } -std::vector< std::vector< RuntimeArgsData> >& GetRuntimeArgs(const Program &program, KernelHandle kernel_id) { - TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); +std::vector> &GetRuntimeArgs(const Program &program, KernelHandle kernel_id) { + TT_FATAL( + not CommandQueue::async_mode_set(), + "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); return detail::GetKernel(program, kernel_id)->runtime_args_data(); } -RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { - TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); +RuntimeArgsData &GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { + TT_FATAL( + not CommandQueue::async_mode_set(), + "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); return detail::GetKernel(program, kernel_id)->common_runtime_args_data(); } @@ -948,17 +974,13 @@ uint32_t BeginTraceCapture(Device *device, const uint8_t cq_id, const uint32_t t return tid; } -void EndTraceCapture(Device *device, const uint8_t cq_id, const uint32_t tid) { - device->end_trace(cq_id, tid); -} +void EndTraceCapture(Device *device, const uint8_t cq_id, const uint32_t tid) { device->end_trace(cq_id, tid); } void ReplayTrace(Device *device, const uint8_t cq_id, const uint32_t tid, const bool blocking) { device->replay_trace(cq_id, tid, blocking); } -void ReleaseTrace(Device *device, const uint32_t tid) { - device->release_trace(tid); -} +void ReleaseTrace(Device *device, const uint32_t tid) { device->release_trace(tid); } } // namespace tt_metal