diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 6886f1b8a61..3de0178fac3 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -63,6 +63,7 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/concat.cpp @@ -71,11 +72,15 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/tilize/device/tilize_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/downsample/device/downsample_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/transformer/device/transformer_device_operation.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_op.cpp index 1582000ab46..50b22c15332 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_op.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "downsample_op.hpp" +#include "downsample_program_factory.hpp" #include @@ -15,11 +16,7 @@ using namespace tt::constants; -namespace ttnn { - -namespace operations { - -namespace data_movement { +namespace ttnn::operations::data_movement { void Downsample::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); @@ -34,25 +31,7 @@ void Downsample::validate(const std::vector& input_tensors) const { input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); } -std::pair get_num_cores_height_width_sliced( - CoreRangeSet all_cores, TensorMemoryLayout memory_layout, ShardOrientation shard_orientation) { - TT_ASSERT( - memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); - if (memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - TT_ASSERT(shard_orientation == ShardOrientation::ROW_MAJOR); - } else { - TT_ASSERT(shard_orientation == ShardOrientation::COL_MAJOR); - TT_ASSERT(all_cores.ranges().size() == 1); - } - uint32_t num_cores = all_cores.num_cores(); - auto first_core_range = *all_cores.ranges().begin(); - uint32_t num_cores_height_sliced = - memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? num_cores : first_core_range.end_coord.x + 1; - uint32_t num_cores_width_sliced = memory_layout == TensorMemoryLayout::HEIGHT_SHARDED - ? 1 - : first_core_range.end_coord.y + 1; // width is not sliced when height sharded - return {num_cores_height_sliced, num_cores_width_sliced}; -} + std::vector Downsample::compute_output_shapes(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); @@ -74,7 +53,7 @@ std::vector Downsample::compute_output_shapes(const std::ve std::vector Downsample::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); auto output_shape = this->compute_output_shapes(input_tensors).at(0); - auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced( + auto [num_cores_height_sliced, num_cores_width_sliced] = detail::get_num_cores_height_width_sliced( input_tensor.shard_spec().value().grid, input_tensor.memory_config().memory_layout, input_tensor.shard_spec().value().orientation); @@ -94,7 +73,7 @@ operation::ProgramWithCallbacks Downsample::create_program( const std::vector& input_tensors, std::vector& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); auto& output_tensor = output_tensors.at(0); - return {downsample_single_core(input_tensor_a, downsample_params, output_tensor)}; + return {detail::downsample_single_core(input_tensor_a, downsample_params, output_tensor)}; } Tensor downsample( @@ -104,840 +83,11 @@ Tensor downsample( return output_tensors.at(0); } -struct DownsampleReadPatternParams { - uint32_t top_partial_middle_aligned_row_width; - uint32_t skip_top_partial_middle_aligned_row; - uint32_t top_partial_right_aligned_row_width; - uint32_t skip_top_partial_right_aligned_row; - uint32_t num_rows_top_partial_image; - uint32_t num_skip_rows_top_partial_image; - uint32_t num_full_images; - uint32_t num_rows_bottom_partial_image; - uint32_t num_skip_rows_bottom_partial_image; - uint32_t bottom_partial_left_aligned_row_width; - uint32_t skip_bottom_partial_left_aligned_row; -}; - -struct ImgTrackingVars { - uint32_t img_h = 0; - uint32_t img_w = 0; - uint32_t next_img_h = 0; // img_h after stride - uint32_t next_img_w = 0; - uint32_t input_flat_h = 0; // index within sharded input - uint32_t output_flat_h = 0; // index within sharded output -}; - -DownsampleReadPatternParams generate_downsample_read_pattern( - ImgTrackingVars& v, - uint32_t img_height, - uint32_t img_width, - uint32_t img_stride_h, - uint32_t img_stride_w, - uint32_t input_end_flat_h, - uint32_t output_end_flat_h, - bool current_region_is_halo_prev_core, - bool current_region_is_halo_next_core) { - // Sanity checks at the start for local data - TT_ASSERT(v.next_img_h >= v.img_h); - TT_ASSERT(v.next_img_w == v.img_w); // assumption that the start is picked and not skipped by stride - TT_ASSERT(v.img_h < img_height); - TT_ASSERT(v.next_img_w < img_width); - if (current_region_is_halo_prev_core) { - // cout << "GENERATING READ PATTERN FOR HALO REGION FROM PREVIOUS CORE" << endl; - TT_ASSERT(!current_region_is_halo_next_core); - TT_ASSERT(v.input_flat_h != 0); - TT_ASSERT(v.output_flat_h == 0); - } else if (current_region_is_halo_next_core) { - // cout << "GENERATING READ PATTERN FOR HALO REGION FROM NEXT CORE" << endl; - TT_ASSERT(!current_region_is_halo_prev_core); - TT_ASSERT(v.input_flat_h == 0); - TT_ASSERT(v.output_flat_h != 0); - } else { - // cout << "GENERATING READ PATTERN FOR LOCAL REGION" << endl; - } - - // cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << - // v.img_w << endl; cout << "v.input_flat_h=" << v.input_flat_h << ", input_end_flat_h=" << input_end_flat_h << ", - // v.output_flat_h=" << v.output_flat_h << ", output_end_flat_h=" << output_end_flat_h << endl; - - TT_ASSERT(v.input_flat_h < input_end_flat_h); - TT_ASSERT(v.output_flat_h < output_end_flat_h); - - uint32_t output_img_height = std::ceil((double)img_height / (double)img_stride_h); - uint32_t output_img_width = std::ceil((double)img_width / (double)img_stride_w); - bool found_halo_for_next_core = false; - - uint32_t top_partial_middle_aligned_row_width = 0; - uint32_t skip_top_partial_middle_aligned_row = 1; - uint32_t top_partial_right_aligned_row_width = 0; - uint32_t skip_top_partial_right_aligned_row = 1; - uint32_t num_rows_top_partial_image = 0; - uint32_t num_skip_rows_top_partial_image = 0; - uint32_t num_full_images = 0; - uint32_t num_rows_bottom_partial_image = 0; - uint32_t num_skip_rows_bottom_partial_image = 0; - uint32_t bottom_partial_left_aligned_row_width = 0; - uint32_t skip_bottom_partial_left_aligned_row = 1; - if (v.img_w != 0) { - // Check if its right aligned or middle aligned (special corner case for halo) - if (v.input_flat_h + img_width - v.img_w <= input_end_flat_h + 1) { - // top partial right aligned - top_partial_right_aligned_row_width = img_width - v.img_w; - skip_top_partial_right_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; - v.input_flat_h += top_partial_right_aligned_row_width; - if (!skip_top_partial_right_aligned_row) { - v.output_flat_h += std::ceil((double)top_partial_right_aligned_row_width / (double)img_stride_w); - TT_ASSERT(v.output_flat_h <= output_end_flat_h); - } - v.img_w = 0; - v.next_img_w = 0; - if (v.img_h == img_height - 1) { - v.img_h = 0; - v.next_img_h = 0; - } else { - v.img_h += 1; - if (v.next_img_h < v.img_h) { - v.next_img_h += img_stride_h; - } - } - } else { - // special corner case for halo region - // middle aligned - TT_ASSERT(input_end_flat_h - v.input_flat_h + 1 < img_width); - TT_ASSERT(current_region_is_halo_prev_core || current_region_is_halo_next_core); - // top partial middle aligned - top_partial_middle_aligned_row_width = input_end_flat_h - v.input_flat_h + 1; - skip_top_partial_middle_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; - v.input_flat_h += top_partial_middle_aligned_row_width; - if (!skip_top_partial_middle_aligned_row) { - v.output_flat_h += std::ceil((double)top_partial_middle_aligned_row_width / (double)img_stride_w); - TT_ASSERT(v.output_flat_h <= output_end_flat_h); - } - uint32_t img_w_start = v.img_w; - while (v.img_w < img_w_start + top_partial_middle_aligned_row_width) { - v.img_w += 1; - if (v.next_img_w < v.img_w) { - v.next_img_w += img_stride_w; - } - } - TT_ASSERT(v.img_w < img_width - 1); - } - } - TT_ASSERT(v.next_img_w == v.img_w); - TT_ASSERT(v.output_flat_h <= output_end_flat_h); - TT_ASSERT(v.next_img_h >= v.img_h); - if (v.img_w != 0) { - // special case for halo - TT_ASSERT(v.input_flat_h == input_end_flat_h + 1); - } - TT_ASSERT(v.img_h < img_height && v.img_w < img_width); - - uint32_t num_rows_remaining_of_current_image = (v.img_h == 0) ? 0 : img_height - v.img_h; - if (num_rows_remaining_of_current_image > 0) { - uint32_t num_rows_to_skip = v.next_img_h - v.img_h; - uint32_t output_h_from_remaining_rows_of_current_image = - std::ceil((double)(num_rows_remaining_of_current_image - num_rows_to_skip) / (double)img_stride_h) * - output_img_width; - bool output_for_partial_top_image = - v.output_flat_h + output_h_from_remaining_rows_of_current_image <= output_end_flat_h + 1; - bool input_for_partial_top_image = - v.input_flat_h + (num_rows_remaining_of_current_image * img_width) <= input_end_flat_h + 1; - if (output_for_partial_top_image && input_for_partial_top_image) { - // Top partial image section - num_rows_top_partial_image = img_height - v.img_h; - num_skip_rows_top_partial_image = v.next_img_h - v.img_h; - // Sanity check - TT_ASSERT((v.img_h + num_rows_top_partial_image == img_height)); - v.img_h = 0; - v.next_img_h = 0; - v.input_flat_h += (num_rows_top_partial_image * img_width); - v.output_flat_h += output_h_from_remaining_rows_of_current_image; - TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); - } - TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); - } - TT_ASSERT(v.img_h < img_height && v.img_w < img_width); - - if (v.img_h == 0 && v.img_w == 0) { - // Check for full images - while (1) { - bool output_for_current_full_image = - v.output_flat_h + (output_img_height * output_img_width) <= output_end_flat_h + 1; - bool input_for_current_full_image = v.input_flat_h + (img_height * img_width) <= input_end_flat_h + 1; - if (!output_for_current_full_image || !input_for_current_full_image) { - break; - } - v.input_flat_h += (img_height * img_width); - v.img_h = 0; - v.img_w = 0; - v.next_img_h = 0; - v.next_img_w = 0; - num_full_images += 1; - v.output_flat_h += (output_img_height * output_img_width); - } - TT_ASSERT(v.img_h == 0 && v.img_w == 0 && v.next_img_h == 0 && v.next_img_w == 0); - } - - // Sanity check - TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); - TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); - - bool found_first_unskipped_row_in_bottom_partial_imgage = false; - // check for bottom partial image rows - while (1) { - bool output_for_bottom_partial_image_row = (v.next_img_h == v.img_h) - ? (v.output_flat_h + output_img_width <= output_end_flat_h + 1) - : true; // true for skipped row - bool input_for_bottom_partial_image_row = v.input_flat_h + img_width <= input_end_flat_h + 1; - if (!output_for_bottom_partial_image_row || !input_for_bottom_partial_image_row) { - break; - } - if (!found_first_unskipped_row_in_bottom_partial_imgage) { - if (v.next_img_h == v.img_h) { - found_first_unskipped_row_in_bottom_partial_imgage = true; - } else { - TT_ASSERT(v.next_img_h > v.img_h); - num_skip_rows_bottom_partial_image += 1; - } - } - v.input_flat_h += img_width; - if (v.next_img_h == v.img_h) { - v.output_flat_h += output_img_width; - } - v.img_w = 0; - v.next_img_w = 0; - TT_ASSERT(v.img_h < img_height - 1); // this is supposed to be a bottom partial image - v.img_h += 1; - if (v.next_img_h < v.img_h) { - v.next_img_h += img_stride_h; - TT_ASSERT(v.next_img_h <= img_height); // odd heights and odd size sharding with stride > 1 not supported - if (v.next_img_h == img_height && v.img_h == img_height) { - v.next_img_h = 0; - v.img_h = 0; - break; - } - } - num_rows_bottom_partial_image += 1; - } - - // Sanity check - TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); - TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); - TT_ASSERT(v.img_h < img_height && v.img_w < img_width); - - // check if there is a bottom partial left aligned row - if (v.input_flat_h <= input_end_flat_h && v.output_flat_h <= output_end_flat_h) { - TT_ASSERT(v.img_w == 0 && v.next_img_w == 0); - // bottom partial left aligned row width can be split between 2 cores - uint32_t input_remaining = input_end_flat_h - v.input_flat_h + 1; - uint32_t output_remaining = output_end_flat_h - v.output_flat_h + 1; - TT_ASSERT( - output_remaining < output_img_width || - input_remaining < img_width); // there must be a partial width either on input side or output side - bottom_partial_left_aligned_row_width = input_remaining; - if (output_remaining < output_img_width) { - bottom_partial_left_aligned_row_width = std::min(input_remaining, output_remaining * img_stride_w); - } - // sanity - TT_ASSERT(bottom_partial_left_aligned_row_width < img_width); - TT_ASSERT(v.next_img_h >= v.img_h); - skip_bottom_partial_left_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; - while (v.img_w < bottom_partial_left_aligned_row_width) { - v.img_w += 1; - if (v.next_img_w < v.img_w) { - v.next_img_w += img_stride_w; - TT_ASSERT(v.next_img_w < img_width); // odd widths and odd size sharding with stride > 1 not supported - } - } - TT_ASSERT(v.img_w == bottom_partial_left_aligned_row_width && v.next_img_w >= v.img_w); - v.input_flat_h += bottom_partial_left_aligned_row_width; - if (!skip_bottom_partial_left_aligned_row) { - v.output_flat_h += std::ceil((double)bottom_partial_left_aligned_row_width / (double)img_stride_w); - } - } - TT_ASSERT(v.img_h < img_height && v.img_w < img_width); - - if (0) { - log_debug(tt::LogOp, " top_partial_middle_aligned_row_width: {}", top_partial_middle_aligned_row_width); - log_debug(tt::LogOp, " skip_top_partial_middle_aligned_row: {}", skip_top_partial_middle_aligned_row); - log_debug(tt::LogOp, " top_partial_right_aligned_row_width: {}", top_partial_right_aligned_row_width); - log_debug(tt::LogOp, " skip_top_partial_right_aligned_row: {}", skip_top_partial_right_aligned_row); - log_debug(tt::LogOp, " num_rows_top_partial_image: {}", num_rows_top_partial_image); - log_debug(tt::LogOp, " num_skip_rows_top_partial_image: {}", num_skip_rows_top_partial_image); - log_debug(tt::LogOp, " num_full_images: {}", num_full_images); - log_debug(tt::LogOp, " num_rows_bottom_partial_image: {}", num_rows_bottom_partial_image); - log_debug(tt::LogOp, " num_skip_rows_bottom_partial_image: {}", num_skip_rows_bottom_partial_image); - log_debug(tt::LogOp, " bottom_partial_left_aligned_row_width: {}", bottom_partial_left_aligned_row_width); - log_debug(tt::LogOp, " skip_bottom_partial_left_aligned_row: {}", skip_bottom_partial_left_aligned_row); - log_debug(tt::LogOp, " v.input_flat_h: {}", v.input_flat_h); - log_debug(tt::LogOp, " v.output_flat_h: {}", v.output_flat_h); - log_debug(tt::LogOp, " input_end_flat_h: {}", input_end_flat_h); - log_debug(tt::LogOp, " output_end_flat_h: {}", output_end_flat_h); - // cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << - // v.img_w << endl; - } - - // Sanity check - TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); - TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); - - if (v.input_flat_h == input_end_flat_h + 1) { - v.input_flat_h = 0; - } - if (v.output_flat_h == output_end_flat_h + 1) { - v.output_flat_h = 0; - } - return DownsampleReadPatternParams{ - .top_partial_middle_aligned_row_width = top_partial_middle_aligned_row_width, - .skip_top_partial_middle_aligned_row = skip_top_partial_middle_aligned_row, - .top_partial_right_aligned_row_width = top_partial_right_aligned_row_width, - .skip_top_partial_right_aligned_row = skip_top_partial_right_aligned_row, - .num_rows_top_partial_image = num_rows_top_partial_image, - .num_skip_rows_top_partial_image = num_skip_rows_top_partial_image, - .num_full_images = num_full_images, - .num_rows_bottom_partial_image = num_rows_bottom_partial_image, - .num_skip_rows_bottom_partial_image = num_skip_rows_bottom_partial_image, - .bottom_partial_left_aligned_row_width = bottom_partial_left_aligned_row_width, - .skip_bottom_partial_left_aligned_row = skip_bottom_partial_left_aligned_row}; -} - -operation::ProgramWithCallbacks downsample_single_core( - const Tensor& a, std::array downsample_params, Tensor& output) { - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - tt::DataFormat untilized_cb_data_format = tt::DataFormat::Float16_b; - uint32_t untilized_single_tile_size = tt::tt_metal::detail::TileSize(untilized_cb_data_format); - auto [img_batch_size, img_height, img_width, img_stride_h, img_stride_w] = downsample_params; - tt::tt_metal::Buffer* src0_buffer = a.buffer(); - - TT_ASSERT(a.get_legacy_shape()[0] == 1 && a.get_legacy_shape()[1] == 1); - TT_ASSERT(output.get_legacy_shape()[0] == 1 && output.get_legacy_shape()[1] == 1); - - tt::tt_metal::Device* device = a.device(); - - tt::tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - // Sanity check of output size - TT_ASSERT(output.volume() % TILE_HW == 0); - uint32_t unpadded_input_volume = img_batch_size * img_height * img_width; - TT_ASSERT(a.volume() >= unpadded_input_volume); - uint32_t unpadded_output_volume = ceil((double)unpadded_input_volume / (double)(img_stride_h * img_stride_w)); - TT_ASSERT(output.volume() >= unpadded_output_volume); - - uint32_t ncores_x_full_grid = device->compute_with_storage_grid_size().x; - auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced( - a.shard_spec().value().grid, a.memory_config().memory_layout, a.shard_spec().value().orientation); - uint32_t num_cores = num_cores_height_sliced * num_cores_width_sliced; - auto all_cores = a.shard_spec().value().grid; - auto memory_layout = a.memory_config().memory_layout; - TT_ASSERT(all_cores == output.shard_spec().value().grid); - TT_ASSERT(memory_layout == output.memory_config().memory_layout); - TT_ASSERT( - memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); - if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - TT_ASSERT(all_cores.ranges().size() == 1); - } else { - TT_ASSERT(num_cores_width_sliced == 1); - } - uint32_t num_cores_x = - memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? ncores_x_full_grid : num_cores_height_sliced; - - auto core_range = all_cores; - - uint32_t input_height = - a.get_legacy_shape()[2]; // input height == flattened face of input image, multiple images are stacked in H dim - uint32_t input_width = a.get_legacy_shape()[3]; // input width == input image # of channels - uint32_t output_height = output.get_legacy_shape()[2]; // output height == flattened face of output image, multiple - // images are stacked in H dim - uint32_t output_width = output.get_legacy_shape()[3]; - TT_ASSERT(input_width == output_width); - - uint32_t input_height_unpadded = img_batch_size * img_height * img_width; - TT_ASSERT(input_height >= input_height_unpadded); - uint32_t output_height_unpadded = - img_batch_size * std::ceil((double)(img_height * img_width) / (double)(img_stride_h * img_stride_w)); - TT_ASSERT(output_height >= output_height_unpadded); - - uint32_t input_shard_height = a.shard_spec().value().shape[0]; - TT_ASSERT(input_shard_height * num_cores_height_sliced >= input_height); // last core shard size may be padded - uint32_t input_height_padded = input_shard_height * num_cores_height_sliced; - uint32_t input_shard_width = a.shard_spec().value().shape[1]; - TT_ASSERT(input_shard_width * num_cores_width_sliced == input_width); - - uint32_t input_height_padding = input_height_padded - input_height_unpadded; - TT_ASSERT(input_height_padding < input_shard_height); // last core has padding - uint32_t last_core_input_shard_height_unpadded = input_shard_height - input_height_padding; - - uint32_t output_shard_height = output.shard_spec().value().shape[0]; - uint32_t output_shard_width = output.shard_spec().value().shape[1]; - TT_ASSERT(output_shard_width == input_shard_width); - TT_ASSERT(output_shard_height * num_cores_height_sliced >= output_height); // last core shard size may be padded - uint32_t output_height_padded = output_shard_height * num_cores_height_sliced; - uint32_t output_height_padding = output_height_padded - output_height_unpadded; - TT_ASSERT(output_height_padding < output_shard_height); - uint32_t last_core_output_shard_height_unpadded = output_shard_height - output_height_padding; - uint32_t input_shard_width_bytes = input_shard_width * datum_size(untilized_cb_data_format); - TT_ASSERT(input_shard_width % TILE_WIDTH == 0); - uint32_t num_input_tiles_in_row = input_shard_width / TILE_WIDTH; - TT_ASSERT(input_shard_height % TILE_HEIGHT == 0); - uint32_t num_rows_of_input_tiles = input_shard_height / TILE_HEIGHT; - uint32_t num_output_tiles_in_row = num_input_tiles_in_row; - TT_ASSERT(output_shard_height % TILE_HEIGHT == 0); - uint32_t num_rows_of_output_tiles = output_shard_height / TILE_HEIGHT; - uint32_t input_cb_index = tt::CB::c_in0; - uint32_t num_input_tiles = num_input_tiles_in_row * num_rows_of_input_tiles; - tt::tt_metal::CircularBufferConfig input_cb_config = - tt::tt_metal::CircularBufferConfig( - num_input_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}}) - .set_page_size(input_cb_index, input_single_tile_size); - input_cb_config = input_cb_config.set_globally_allocated_address(*a.buffer()); - auto input_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, input_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - input_cb_index, - input_single_tile_size, - num_input_tiles, - input_single_tile_size * num_input_tiles); - - // CB to store halo data - // hardcode to store 1 row of tiles - uint32_t halo_prev_input_cb_index = tt::CB::c_in1; - uint32_t halo_prev_input_cb_max_rows_of_tiles = 4; - uint32_t num_halo_prev_cb_input_tiles = num_input_tiles_in_row * halo_prev_input_cb_max_rows_of_tiles; - tt::tt_metal::CircularBufferConfig halo_prev_input_cb_config = - tt::tt_metal::CircularBufferConfig( - num_halo_prev_cb_input_tiles * input_single_tile_size, {{halo_prev_input_cb_index, input_cb_data_format}}) - .set_page_size(halo_prev_input_cb_index, input_single_tile_size); - auto halo_prev_input_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, halo_prev_input_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - halo_prev_input_cb_index, - input_single_tile_size, - num_halo_prev_cb_input_tiles, - input_single_tile_size * num_halo_prev_cb_input_tiles); - - uint32_t halo_next_input_cb_index = tt::CB::c_in2; - uint32_t halo_next_input_cb_max_rows_of_tiles = 33; // TODO: Remove hardcoding - uint32_t num_halo_next_cb_input_tiles = num_input_tiles_in_row * halo_next_input_cb_max_rows_of_tiles; - tt::tt_metal::CircularBufferConfig halo_next_input_cb_config = - tt::tt_metal::CircularBufferConfig( - num_halo_next_cb_input_tiles * input_single_tile_size, {{halo_next_input_cb_index, input_cb_data_format}}) - .set_page_size(halo_next_input_cb_index, input_single_tile_size); - auto halo_next_input_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, halo_next_input_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - halo_next_input_cb_index, - input_single_tile_size, - num_halo_next_cb_input_tiles, - input_single_tile_size * num_halo_next_cb_input_tiles); - - // CB to store reader pattern array - // read pattern array size == output_height - uint32_t reader_pattern_array_size = output_shard_height; - uint32_t reader_pattern_array_cb_index = tt::CB::c_intermed1; - tt::tt_metal::CircularBufferConfig reader_pattern_array_cb_config = - tt::tt_metal::CircularBufferConfig( - reader_pattern_array_size * 4, {{reader_pattern_array_cb_index, tt::DataFormat::Float16_b}}) - .set_page_size(reader_pattern_array_cb_index, 4); - auto reader_pattern_array_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, reader_pattern_array_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - reader_pattern_array_cb_index, - 4, - reader_pattern_array_size, - 4 * reader_pattern_array_size); - - // untilized CB has size - [32, full width] - uint32_t untilize_cb_index = tt::CB::c_intermed2; - uint32_t num_tiles_untilize_cb = num_input_tiles_in_row; - tt::tt_metal::CircularBufferConfig untilize_cb_config = - tt::tt_metal::CircularBufferConfig( - num_tiles_untilize_cb * untilized_single_tile_size, {{untilize_cb_index, untilized_cb_data_format}}) - .set_page_size(untilize_cb_index, untilized_single_tile_size); - auto untilize_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, untilize_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - untilize_cb_index, - untilized_single_tile_size, - num_tiles_untilize_cb, - untilized_single_tile_size * num_tiles_untilize_cb); - - uint32_t num_output_tiles = num_output_tiles_in_row * num_rows_of_output_tiles; - uint32_t untilize_downsampled_cb_index = tt::CB::c_intermed3; - uint32_t num_tiles_untilize_downsampled_cb = - num_output_tiles; // untilize downsampled cb size == output size per core - tt::tt_metal::CircularBufferConfig untilize_downsampled_cb_config = - tt::tt_metal::CircularBufferConfig( - num_tiles_untilize_downsampled_cb * untilized_single_tile_size, - {{untilize_downsampled_cb_index, untilized_cb_data_format}}) - .set_page_size(untilize_downsampled_cb_index, untilized_single_tile_size); - auto untilize_downsampled_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, untilize_downsampled_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - untilize_downsampled_cb_index, - untilized_single_tile_size, - num_tiles_untilize_downsampled_cb, - untilized_single_tile_size * num_tiles_untilize_downsampled_cb); - - uint32_t final_tilize_output_cb_index = tt::CB::c_out0; - uint32_t num_tiles_final_tilize_output_cb = num_output_tiles; // final output cb size == output size per core - tt::tt_metal::CircularBufferConfig final_tilize_output_cb_config = - tt::tt_metal::CircularBufferConfig( - num_tiles_final_tilize_output_cb * output_single_tile_size, - {{final_tilize_output_cb_index, output_cb_data_format}}) - .set_page_size(final_tilize_output_cb_index, output_single_tile_size); - final_tilize_output_cb_config = final_tilize_output_cb_config.set_globally_allocated_address(*output.buffer()); - auto final_tilize_output_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, final_tilize_output_cb_config); - log_debug( - tt::LogOp, - "CB {}: PS = {} NP = {} :: TOTAL = {}", - final_tilize_output_cb_index, - output_single_tile_size, - num_tiles_final_tilize_output_cb, - output_single_tile_size * num_tiles_final_tilize_output_cb); - - uint32_t log_base_2_of_conv_act_size_c_bytes = (uint32_t)std::log2((float)input_shard_width_bytes); - uint32_t stride_h_x_image_width = img_stride_h * img_width; - std::vector writer_compile_time_args = { - (std::uint32_t)untilize_cb_index, - (std::uint32_t)untilize_downsampled_cb_index, - (std::uint32_t)final_tilize_output_cb_index, - (std::uint32_t)reader_pattern_array_cb_index, - (std::uint32_t)datum_size(untilized_cb_data_format), - (std::uint32_t)input_shard_width_bytes, - (std::uint32_t)halo_prev_input_cb_index, - (std::uint32_t)halo_next_input_cb_index, - log_base_2_of_conv_act_size_c_bytes, - stride_h_x_image_width}; - - // Writer to downsample - drops rows from untilized cb - tt::tt_metal::KernelHandle downsample_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/downsample/device/kernels/downsample_writer_kernel.cpp", - core_range, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - vector compute_args = { - input_cb_index, - halo_prev_input_cb_index, - halo_next_input_cb_index, - untilize_cb_index, - untilize_downsampled_cb_index, - final_tilize_output_cb_index, - num_input_tiles_in_row, - num_rows_of_output_tiles, - num_output_tiles_in_row, - }; - string compute_kernel = "ttnn/cpp/ttnn/operations/data_movement/downsample/device/kernels/downsample_compute_kernel.cpp"; - if (num_input_tiles_in_row <= MAX_PACK_UNTILIZE_WIDTH) { - compute_kernel = - "ttnn/cpp/ttnn/operations/data_movement/downsample/device/kernels/downsample_fast_pack_untilize_compute_kernel.cpp"; - } - auto downsample_compute_kernel_id = tt::tt_metal::CreateKernel( - program, compute_kernel, core_range, tt::tt_metal::ComputeConfig{.compile_args = compute_args}); - - // track img h, img w, across cores - ImgTrackingVars v; - CoreCoord prev_core = {0, 0}; - bool input_flat_h_is_of_current_core = true; - const auto& cores = corerange_to_cores(all_cores, std::nullopt, true); - for (uint32_t i = 0; i < num_cores; i++) { - const CoreCoord& core = cores[i]; - // cout << "i=" << i << endl; - - uint32_t input_end_flat_h = input_shard_height - 1; - uint32_t output_end_flat_h = output_shard_height - 1; - - bool halo_prev_read_enabled = false; - DownsampleReadPatternParams halo_prev_read_pattern_params; - uint32_t halo_prev_noc_x = 0; - uint32_t halo_prev_noc_y = 0; - uint32_t halo_prev_start_addr = 0; - uint32_t halo_prev_addr_offset = 0; - uint32_t halo_prev_num_tiles = 0; - uint32_t halo_prev_size_bytes = 0; - uint32_t halo_prev_input_num_rows_of_tiles = 0; - uint32_t halo_prev_read_pattern_offset = 0; - - bool halo_next_read_enabled = false; - DownsampleReadPatternParams halo_next_read_pattern_params; - uint32_t halo_next_noc_x = 0; - uint32_t halo_next_noc_y = 0; - uint32_t halo_next_start_addr = 0; - uint32_t halo_next_addr_offset = 0; - uint32_t halo_next_num_tiles = 0; - uint32_t halo_next_size_bytes = 0; - uint32_t halo_next_input_num_rows_of_tiles = 0; - uint32_t halo_next_read_pattern_offset = 0; - uint32_t local_read_pattern_offset = 0; - uint32_t current_core_input_end_flat_h = input_end_flat_h; - uint32_t next_core_input_end_flat_h = input_end_flat_h; - if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - if (i % num_cores_x == 0) { - // first core in row - // reset - v.input_flat_h = 0; - v.output_flat_h = 0; - } else if (i % num_cores_x == num_cores_x - 1) { - // set unpadded height as end index for last core in row - current_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; - output_end_flat_h = last_core_output_shard_height_unpadded - 1; - } else if (i % num_cores_x == num_cores_x - 2) { - next_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; - } - } else if (i == num_cores - 1) { - // for height sharding, set unpadded height as end index for last core - current_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; - output_end_flat_h = last_core_output_shard_height_unpadded - 1; - } else if (i == num_cores - 2) { - next_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; - } - if (v.input_flat_h != 0 && !input_flat_h_is_of_current_core) { - // halo region of previous core - TT_ASSERT(i != 0); - if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - TT_ASSERT(prev_core.y == core.y); // for block sharding case, prev core is left core - } - halo_prev_read_enabled = true; - TT_ASSERT(v.input_flat_h < input_shard_height); - // get halo start tile address from height idx - uint32_t halo_prev_start_tile_id_h = v.input_flat_h / TILE_HEIGHT; - TT_ASSERT( - input_shard_height - v.input_flat_h <= - TILE_HEIGHT * - halo_prev_input_cb_max_rows_of_tiles); // halo input cb is hardcoded to store only 4 rows of tiles - // for now. TODO: allocate bigger CB or read in blocks - // get halo size - halo_prev_size_bytes = (input_shard_height - (halo_prev_start_tile_id_h * TILE_HEIGHT)) * - input_shard_width / TILE_HW * input_single_tile_size; - TT_ASSERT(halo_prev_size_bytes % input_single_tile_size == 0); - halo_prev_num_tiles = halo_prev_size_bytes / input_single_tile_size; - TT_ASSERT(halo_prev_num_tiles <= num_halo_prev_cb_input_tiles); - TT_ASSERT(halo_prev_num_tiles % num_input_tiles_in_row == 0); - halo_prev_input_num_rows_of_tiles = halo_prev_num_tiles / num_input_tiles_in_row; - halo_prev_addr_offset = num_input_tiles_in_row * halo_prev_start_tile_id_h * input_single_tile_size; - halo_prev_start_addr = GetCircularBufferConfig(program, input_cb).globally_allocated_address().value(); - - TT_ASSERT( - (halo_prev_start_addr + halo_prev_addr_offset) % 32 == 0); // read address should be 32 byte aligned - auto halo_noc_coords = device->worker_core_from_logical_core(prev_core); - halo_prev_noc_x = halo_noc_coords.x; - halo_prev_noc_y = halo_noc_coords.y; - TT_ASSERT(v.input_flat_h >= halo_prev_start_tile_id_h * TILE_HEIGHT); - halo_prev_read_pattern_offset = v.input_flat_h - (halo_prev_start_tile_id_h * TILE_HEIGHT); - local_read_pattern_offset = halo_prev_input_num_rows_of_tiles * TILE_HEIGHT; - halo_prev_read_pattern_params = generate_downsample_read_pattern( - v, img_height, img_width, img_stride_h, img_stride_w, input_end_flat_h, output_end_flat_h, true, false); - } - // local core - TT_ASSERT(v.output_flat_h < output_shard_height); - uint32_t local_start_h = v.input_flat_h; - DownsampleReadPatternParams local_read_pattern_params = generate_downsample_read_pattern( - v, - img_height, - img_width, - img_stride_h, - img_stride_w, - current_core_input_end_flat_h, - output_end_flat_h, - false, - false); - TT_ASSERT(v.output_flat_h <= output_shard_height); - uint32_t local_end_h_exclusive = v.input_flat_h == 0 ? input_shard_height : v.input_flat_h; - uint32_t local_num_rows = local_end_h_exclusive - local_start_h; - TT_ASSERT(local_num_rows > 0); - uint32_t local_input_num_rows_of_tiles = std::ceil((double)local_num_rows / (double)TILE_HEIGHT); - uint32_t local_input_offset_rows_of_tiles = local_start_h / TILE_HEIGHT; - if (local_start_h != 0) { - TT_ASSERT(local_read_pattern_offset == 0); - local_read_pattern_offset = local_start_h % TILE_HEIGHT; - } - if (v.input_flat_h != 0) { - input_flat_h_is_of_current_core = false; - } else { - input_flat_h_is_of_current_core = true; // updating flag for next core - } - TT_ASSERT(local_input_num_rows_of_tiles <= num_rows_of_input_tiles); - - if (v.output_flat_h != 0) { - // need to read halo from next core - TT_ASSERT(i != num_cores - 1); - TT_ASSERT(v.input_flat_h == 0); - const CoreCoord& next_core = cores[i + 1]; - if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - TT_ASSERT(next_core.y == core.y); // for block sharding case, next core is right core - } - halo_next_read_enabled = true; - halo_next_read_pattern_params = generate_downsample_read_pattern( - v, - img_height, - img_width, - img_stride_h, - img_stride_w, - next_core_input_end_flat_h, - output_end_flat_h, - false, - true); - TT_ASSERT(v.output_flat_h == 0); - TT_ASSERT(v.input_flat_h != 0 && v.input_flat_h < input_shard_height); - TT_ASSERT( - v.input_flat_h <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles, - "v.input_flat_h ({}) should be <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles ({})", - v.input_flat_h, - halo_next_input_cb_max_rows_of_tiles); // halo next input cb is hardcoded to store only 5 rows of tiles - // for now. TODO: allocate bigger CB or read in blocks - uint32_t halo_next_end_tile_id_h = v.input_flat_h / TILE_HEIGHT; - // get halo size - halo_next_size_bytes = - (halo_next_end_tile_id_h + 1) * TILE_HEIGHT * input_shard_width / TILE_HW * input_single_tile_size; - TT_ASSERT(halo_next_size_bytes % input_single_tile_size == 0); - halo_next_num_tiles = halo_next_size_bytes / input_single_tile_size; - TT_ASSERT(halo_next_num_tiles <= num_halo_next_cb_input_tiles); - TT_ASSERT(halo_next_num_tiles % num_input_tiles_in_row == 0); - halo_next_input_num_rows_of_tiles = halo_next_num_tiles / num_input_tiles_in_row; - halo_next_addr_offset = 0; - halo_next_start_addr = GetCircularBufferConfig(program, input_cb).globally_allocated_address().value(); - TT_ASSERT( - (halo_next_start_addr + halo_next_addr_offset) % 32 == 0); // read address should be 32 byte aligned - auto halo_noc_coords = device->worker_core_from_logical_core(next_core); - halo_next_noc_x = halo_noc_coords.x; - halo_next_noc_y = halo_noc_coords.y; - TT_ASSERT(halo_prev_input_num_rows_of_tiles == 0); - halo_next_read_pattern_offset = local_input_num_rows_of_tiles * TILE_HEIGHT; - } - TT_ASSERT(v.output_flat_h == 0); - - // Compile runtime args - vector compile_rt_kernel_args = { - local_input_num_rows_of_tiles, - local_input_offset_rows_of_tiles, - halo_prev_read_enabled, - halo_prev_input_num_rows_of_tiles, - halo_next_read_enabled, - halo_next_input_num_rows_of_tiles, - }; - - tt::tt_metal::SetRuntimeArgs(program, downsample_compute_kernel_id, core, compile_rt_kernel_args); - - // Writer runtime args - vector writer_kernel_args = { - (uint32_t)img_height, - (uint32_t)img_width, - (uint32_t)img_stride_h, - (uint32_t)img_stride_w, - - // halo prev args - halo_prev_read_enabled, - halo_prev_noc_x, - halo_prev_noc_y, - halo_prev_num_tiles, - halo_prev_start_addr, - halo_prev_addr_offset, - halo_prev_size_bytes, - - // halo prev read pattern args - halo_prev_read_pattern_offset, - halo_prev_read_pattern_params.top_partial_middle_aligned_row_width, - halo_prev_read_pattern_params.skip_top_partial_middle_aligned_row, - halo_prev_read_pattern_params.top_partial_right_aligned_row_width, - halo_prev_read_pattern_params.skip_top_partial_right_aligned_row, - halo_prev_read_pattern_params.num_rows_top_partial_image, - halo_prev_read_pattern_params.num_skip_rows_top_partial_image, - halo_prev_read_pattern_params.num_full_images, - halo_prev_read_pattern_params.num_rows_bottom_partial_image, - halo_prev_read_pattern_params.num_skip_rows_bottom_partial_image, - halo_prev_read_pattern_params.bottom_partial_left_aligned_row_width, - halo_prev_read_pattern_params.skip_bottom_partial_left_aligned_row, - - // local read pattern args - local_read_pattern_offset, - local_read_pattern_params.top_partial_middle_aligned_row_width, - local_read_pattern_params.skip_top_partial_middle_aligned_row, - local_read_pattern_params.top_partial_right_aligned_row_width, - local_read_pattern_params.skip_top_partial_right_aligned_row, - local_read_pattern_params.num_rows_top_partial_image, - local_read_pattern_params.num_skip_rows_top_partial_image, - local_read_pattern_params.num_full_images, - local_read_pattern_params.num_rows_bottom_partial_image, - local_read_pattern_params.num_skip_rows_bottom_partial_image, - local_read_pattern_params.bottom_partial_left_aligned_row_width, - local_read_pattern_params.skip_bottom_partial_left_aligned_row, - - // halo next core args - halo_next_read_enabled, - halo_next_noc_x, - halo_next_noc_y, - halo_next_num_tiles, - halo_next_start_addr, - halo_next_addr_offset, - halo_next_size_bytes, - - // halo next read pattern args - halo_next_read_pattern_offset, - halo_next_read_pattern_params.top_partial_middle_aligned_row_width, - halo_next_read_pattern_params.skip_top_partial_middle_aligned_row, - halo_next_read_pattern_params.top_partial_right_aligned_row_width, - halo_next_read_pattern_params.skip_top_partial_right_aligned_row, - halo_next_read_pattern_params.num_rows_top_partial_image, - halo_next_read_pattern_params.num_skip_rows_top_partial_image, - halo_next_read_pattern_params.num_full_images, - halo_next_read_pattern_params.num_rows_bottom_partial_image, - halo_next_read_pattern_params.num_skip_rows_bottom_partial_image, - halo_next_read_pattern_params.bottom_partial_left_aligned_row_width, - halo_next_read_pattern_params.skip_bottom_partial_left_aligned_row, - - halo_prev_input_num_rows_of_tiles + local_input_num_rows_of_tiles + halo_next_input_num_rows_of_tiles, - num_input_tiles_in_row, - num_output_tiles, - - (uint32_t) false}; - - tt::tt_metal::SetRuntimeArgs(program, downsample_writer_kernel_id, core, writer_kernel_args); - prev_core = core; - } - - auto override_runtime_args_callback = [input_cb = input_cb, - final_tilize_output_cb = final_tilize_output_cb, - downsample_writer_kernel_id = downsample_writer_kernel_id, - cores = cores]( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors) { - auto src_buffer = input_tensors.at(0).buffer(); - auto dst_buffer = output_tensors.at(0).buffer(); - - UpdateDynamicCircularBufferAddress(program, input_cb, *src_buffer); - UpdateDynamicCircularBufferAddress(program, final_tilize_output_cb, *dst_buffer); - - auto& writer_runtime_args_by_core = GetRuntimeArgs(program, downsample_writer_kernel_id); - for (const auto& core : cores) { - auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; - runtime_args[8] = src_buffer->address(); - runtime_args[39] = src_buffer->address(); - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; -} -} //name space data_movement -} // namespace operations -} // namespace ttnn +} //name space ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.cpp new file mode 100644 index 00000000000..5c6bb5a9e88 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.cpp @@ -0,0 +1,877 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "downsample_program_factory.hpp" + +#include + +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/untilize/untilize_op.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; + + +namespace ttnn::operations::data_movement::detail { + +std::pair get_num_cores_height_width_sliced( + CoreRangeSet all_cores, TensorMemoryLayout memory_layout, ShardOrientation shard_orientation) { + TT_ASSERT( + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + if (memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + TT_ASSERT(shard_orientation == ShardOrientation::ROW_MAJOR); + } else { + TT_ASSERT(shard_orientation == ShardOrientation::COL_MAJOR); + TT_ASSERT(all_cores.ranges().size() == 1); + } + uint32_t num_cores = all_cores.num_cores(); + auto first_core_range = *all_cores.ranges().begin(); + uint32_t num_cores_height_sliced = + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? num_cores : first_core_range.end_coord.x + 1; + uint32_t num_cores_width_sliced = memory_layout == TensorMemoryLayout::HEIGHT_SHARDED + ? 1 + : first_core_range.end_coord.y + 1; // width is not sliced when height sharded + return {num_cores_height_sliced, num_cores_width_sliced}; +} + + +struct DownsampleReadPatternParams { + uint32_t top_partial_middle_aligned_row_width; + uint32_t skip_top_partial_middle_aligned_row; + uint32_t top_partial_right_aligned_row_width; + uint32_t skip_top_partial_right_aligned_row; + uint32_t num_rows_top_partial_image; + uint32_t num_skip_rows_top_partial_image; + uint32_t num_full_images; + uint32_t num_rows_bottom_partial_image; + uint32_t num_skip_rows_bottom_partial_image; + uint32_t bottom_partial_left_aligned_row_width; + uint32_t skip_bottom_partial_left_aligned_row; +}; + + +struct ImgTrackingVars { + uint32_t img_h = 0; + uint32_t img_w = 0; + uint32_t next_img_h = 0; // img_h after stride + uint32_t next_img_w = 0; + uint32_t input_flat_h = 0; // index within sharded input + uint32_t output_flat_h = 0; // index within sharded output +}; + +DownsampleReadPatternParams generate_downsample_read_pattern( + ImgTrackingVars& v, + uint32_t img_height, + uint32_t img_width, + uint32_t img_stride_h, + uint32_t img_stride_w, + uint32_t input_end_flat_h, + uint32_t output_end_flat_h, + bool current_region_is_halo_prev_core, + bool current_region_is_halo_next_core) { + // Sanity checks at the start for local data + TT_ASSERT(v.next_img_h >= v.img_h); + TT_ASSERT(v.next_img_w == v.img_w); // assumption that the start is picked and not skipped by stride + TT_ASSERT(v.img_h < img_height); + TT_ASSERT(v.next_img_w < img_width); + if (current_region_is_halo_prev_core) { + // cout << "GENERATING READ PATTERN FOR HALO REGION FROM PREVIOUS CORE" << endl; + TT_ASSERT(!current_region_is_halo_next_core); + TT_ASSERT(v.input_flat_h != 0); + TT_ASSERT(v.output_flat_h == 0); + } else if (current_region_is_halo_next_core) { + // cout << "GENERATING READ PATTERN FOR HALO REGION FROM NEXT CORE" << endl; + TT_ASSERT(!current_region_is_halo_prev_core); + TT_ASSERT(v.input_flat_h == 0); + TT_ASSERT(v.output_flat_h != 0); + } else { + // cout << "GENERATING READ PATTERN FOR LOCAL REGION" << endl; + } + + // cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << + // v.img_w << endl; cout << "v.input_flat_h=" << v.input_flat_h << ", input_end_flat_h=" << input_end_flat_h << ", + // v.output_flat_h=" << v.output_flat_h << ", output_end_flat_h=" << output_end_flat_h << endl; + + TT_ASSERT(v.input_flat_h < input_end_flat_h); + TT_ASSERT(v.output_flat_h < output_end_flat_h); + + uint32_t output_img_height = std::ceil((double)img_height / (double)img_stride_h); + uint32_t output_img_width = std::ceil((double)img_width / (double)img_stride_w); + bool found_halo_for_next_core = false; + + uint32_t top_partial_middle_aligned_row_width = 0; + uint32_t skip_top_partial_middle_aligned_row = 1; + uint32_t top_partial_right_aligned_row_width = 0; + uint32_t skip_top_partial_right_aligned_row = 1; + uint32_t num_rows_top_partial_image = 0; + uint32_t num_skip_rows_top_partial_image = 0; + uint32_t num_full_images = 0; + uint32_t num_rows_bottom_partial_image = 0; + uint32_t num_skip_rows_bottom_partial_image = 0; + uint32_t bottom_partial_left_aligned_row_width = 0; + uint32_t skip_bottom_partial_left_aligned_row = 1; + if (v.img_w != 0) { + // Check if its right aligned or middle aligned (special corner case for halo) + if (v.input_flat_h + img_width - v.img_w <= input_end_flat_h + 1) { + // top partial right aligned + top_partial_right_aligned_row_width = img_width - v.img_w; + skip_top_partial_right_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; + v.input_flat_h += top_partial_right_aligned_row_width; + if (!skip_top_partial_right_aligned_row) { + v.output_flat_h += std::ceil((double)top_partial_right_aligned_row_width / (double)img_stride_w); + TT_ASSERT(v.output_flat_h <= output_end_flat_h); + } + v.img_w = 0; + v.next_img_w = 0; + if (v.img_h == img_height - 1) { + v.img_h = 0; + v.next_img_h = 0; + } else { + v.img_h += 1; + if (v.next_img_h < v.img_h) { + v.next_img_h += img_stride_h; + } + } + } else { + // special corner case for halo region + // middle aligned + TT_ASSERT(input_end_flat_h - v.input_flat_h + 1 < img_width); + TT_ASSERT(current_region_is_halo_prev_core || current_region_is_halo_next_core); + // top partial middle aligned + top_partial_middle_aligned_row_width = input_end_flat_h - v.input_flat_h + 1; + skip_top_partial_middle_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; + v.input_flat_h += top_partial_middle_aligned_row_width; + if (!skip_top_partial_middle_aligned_row) { + v.output_flat_h += std::ceil((double)top_partial_middle_aligned_row_width / (double)img_stride_w); + TT_ASSERT(v.output_flat_h <= output_end_flat_h); + } + uint32_t img_w_start = v.img_w; + while (v.img_w < img_w_start + top_partial_middle_aligned_row_width) { + v.img_w += 1; + if (v.next_img_w < v.img_w) { + v.next_img_w += img_stride_w; + } + } + TT_ASSERT(v.img_w < img_width - 1); + } + } + TT_ASSERT(v.next_img_w == v.img_w); + TT_ASSERT(v.output_flat_h <= output_end_flat_h); + TT_ASSERT(v.next_img_h >= v.img_h); + if (v.img_w != 0) { + // special case for halo + TT_ASSERT(v.input_flat_h == input_end_flat_h + 1); + } + TT_ASSERT(v.img_h < img_height && v.img_w < img_width); + + uint32_t num_rows_remaining_of_current_image = (v.img_h == 0) ? 0 : img_height - v.img_h; + if (num_rows_remaining_of_current_image > 0) { + uint32_t num_rows_to_skip = v.next_img_h - v.img_h; + uint32_t output_h_from_remaining_rows_of_current_image = + std::ceil((double)(num_rows_remaining_of_current_image - num_rows_to_skip) / (double)img_stride_h) * + output_img_width; + bool output_for_partial_top_image = + v.output_flat_h + output_h_from_remaining_rows_of_current_image <= output_end_flat_h + 1; + bool input_for_partial_top_image = + v.input_flat_h + (num_rows_remaining_of_current_image * img_width) <= input_end_flat_h + 1; + if (output_for_partial_top_image && input_for_partial_top_image) { + // Top partial image section + num_rows_top_partial_image = img_height - v.img_h; + num_skip_rows_top_partial_image = v.next_img_h - v.img_h; + // Sanity check + TT_ASSERT((v.img_h + num_rows_top_partial_image == img_height)); + v.img_h = 0; + v.next_img_h = 0; + v.input_flat_h += (num_rows_top_partial_image * img_width); + v.output_flat_h += output_h_from_remaining_rows_of_current_image; + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + } + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); + } + TT_ASSERT(v.img_h < img_height && v.img_w < img_width); + + if (v.img_h == 0 && v.img_w == 0) { + // Check for full images + while (1) { + bool output_for_current_full_image = + v.output_flat_h + (output_img_height * output_img_width) <= output_end_flat_h + 1; + bool input_for_current_full_image = v.input_flat_h + (img_height * img_width) <= input_end_flat_h + 1; + if (!output_for_current_full_image || !input_for_current_full_image) { + break; + } + v.input_flat_h += (img_height * img_width); + v.img_h = 0; + v.img_w = 0; + v.next_img_h = 0; + v.next_img_w = 0; + num_full_images += 1; + v.output_flat_h += (output_img_height * output_img_width); + } + TT_ASSERT(v.img_h == 0 && v.img_w == 0 && v.next_img_h == 0 && v.next_img_w == 0); + } + + // Sanity check + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); + + bool found_first_unskipped_row_in_bottom_partial_imgage = false; + // check for bottom partial image rows + while (1) { + bool output_for_bottom_partial_image_row = (v.next_img_h == v.img_h) + ? (v.output_flat_h + output_img_width <= output_end_flat_h + 1) + : true; // true for skipped row + bool input_for_bottom_partial_image_row = v.input_flat_h + img_width <= input_end_flat_h + 1; + if (!output_for_bottom_partial_image_row || !input_for_bottom_partial_image_row) { + break; + } + if (!found_first_unskipped_row_in_bottom_partial_imgage) { + if (v.next_img_h == v.img_h) { + found_first_unskipped_row_in_bottom_partial_imgage = true; + } else { + TT_ASSERT(v.next_img_h > v.img_h); + num_skip_rows_bottom_partial_image += 1; + } + } + v.input_flat_h += img_width; + if (v.next_img_h == v.img_h) { + v.output_flat_h += output_img_width; + } + v.img_w = 0; + v.next_img_w = 0; + TT_ASSERT(v.img_h < img_height - 1); // this is supposed to be a bottom partial image + v.img_h += 1; + if (v.next_img_h < v.img_h) { + v.next_img_h += img_stride_h; + TT_ASSERT(v.next_img_h <= img_height); // odd heights and odd size sharding with stride > 1 not supported + if (v.next_img_h == img_height && v.img_h == img_height) { + v.next_img_h = 0; + v.img_h = 0; + break; + } + } + num_rows_bottom_partial_image += 1; + } + + // Sanity check + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); + TT_ASSERT(v.img_h < img_height && v.img_w < img_width); + + // check if there is a bottom partial left aligned row + if (v.input_flat_h <= input_end_flat_h && v.output_flat_h <= output_end_flat_h) { + TT_ASSERT(v.img_w == 0 && v.next_img_w == 0); + // bottom partial left aligned row width can be split between 2 cores + uint32_t input_remaining = input_end_flat_h - v.input_flat_h + 1; + uint32_t output_remaining = output_end_flat_h - v.output_flat_h + 1; + TT_ASSERT( + output_remaining < output_img_width || + input_remaining < img_width); // there must be a partial width either on input side or output side + bottom_partial_left_aligned_row_width = input_remaining; + if (output_remaining < output_img_width) { + bottom_partial_left_aligned_row_width = std::min(input_remaining, output_remaining * img_stride_w); + } + // sanity + TT_ASSERT(bottom_partial_left_aligned_row_width < img_width); + TT_ASSERT(v.next_img_h >= v.img_h); + skip_bottom_partial_left_aligned_row = (v.next_img_h == v.img_h) ? 0 : 1; + while (v.img_w < bottom_partial_left_aligned_row_width) { + v.img_w += 1; + if (v.next_img_w < v.img_w) { + v.next_img_w += img_stride_w; + TT_ASSERT(v.next_img_w < img_width); // odd widths and odd size sharding with stride > 1 not supported + } + } + TT_ASSERT(v.img_w == bottom_partial_left_aligned_row_width && v.next_img_w >= v.img_w); + v.input_flat_h += bottom_partial_left_aligned_row_width; + if (!skip_bottom_partial_left_aligned_row) { + v.output_flat_h += std::ceil((double)bottom_partial_left_aligned_row_width / (double)img_stride_w); + } + } + TT_ASSERT(v.img_h < img_height && v.img_w < img_width); + + if (0) { + log_debug(tt::LogOp, " top_partial_middle_aligned_row_width: {}", top_partial_middle_aligned_row_width); + log_debug(tt::LogOp, " skip_top_partial_middle_aligned_row: {}", skip_top_partial_middle_aligned_row); + log_debug(tt::LogOp, " top_partial_right_aligned_row_width: {}", top_partial_right_aligned_row_width); + log_debug(tt::LogOp, " skip_top_partial_right_aligned_row: {}", skip_top_partial_right_aligned_row); + log_debug(tt::LogOp, " num_rows_top_partial_image: {}", num_rows_top_partial_image); + log_debug(tt::LogOp, " num_skip_rows_top_partial_image: {}", num_skip_rows_top_partial_image); + log_debug(tt::LogOp, " num_full_images: {}", num_full_images); + log_debug(tt::LogOp, " num_rows_bottom_partial_image: {}", num_rows_bottom_partial_image); + log_debug(tt::LogOp, " num_skip_rows_bottom_partial_image: {}", num_skip_rows_bottom_partial_image); + log_debug(tt::LogOp, " bottom_partial_left_aligned_row_width: {}", bottom_partial_left_aligned_row_width); + log_debug(tt::LogOp, " skip_bottom_partial_left_aligned_row: {}", skip_bottom_partial_left_aligned_row); + log_debug(tt::LogOp, " v.input_flat_h: {}", v.input_flat_h); + log_debug(tt::LogOp, " v.output_flat_h: {}", v.output_flat_h); + log_debug(tt::LogOp, " input_end_flat_h: {}", input_end_flat_h); + log_debug(tt::LogOp, " output_end_flat_h: {}", output_end_flat_h); + // cout << "img_h=" << v.img_h << ", img_w=" << v.img_w << ", next_img_h=" << v.next_img_h << ", next_img_w=" << + // v.img_w << endl; + } + + // Sanity check + TT_ASSERT(v.input_flat_h <= input_end_flat_h + 1); + TT_ASSERT(v.output_flat_h <= output_end_flat_h + 1); + + if (v.input_flat_h == input_end_flat_h + 1) { + v.input_flat_h = 0; + } + if (v.output_flat_h == output_end_flat_h + 1) { + v.output_flat_h = 0; + } + return DownsampleReadPatternParams{ + .top_partial_middle_aligned_row_width = top_partial_middle_aligned_row_width, + .skip_top_partial_middle_aligned_row = skip_top_partial_middle_aligned_row, + .top_partial_right_aligned_row_width = top_partial_right_aligned_row_width, + .skip_top_partial_right_aligned_row = skip_top_partial_right_aligned_row, + .num_rows_top_partial_image = num_rows_top_partial_image, + .num_skip_rows_top_partial_image = num_skip_rows_top_partial_image, + .num_full_images = num_full_images, + .num_rows_bottom_partial_image = num_rows_bottom_partial_image, + .num_skip_rows_bottom_partial_image = num_skip_rows_bottom_partial_image, + .bottom_partial_left_aligned_row_width = bottom_partial_left_aligned_row_width, + .skip_bottom_partial_left_aligned_row = skip_bottom_partial_left_aligned_row}; +} + + +operation::ProgramWithCallbacks downsample_single_core( + const Tensor& a, std::array downsample_params, Tensor& output) { + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + tt::DataFormat untilized_cb_data_format = tt::DataFormat::Float16_b; + uint32_t untilized_single_tile_size = tt::tt_metal::detail::TileSize(untilized_cb_data_format); + auto [img_batch_size, img_height, img_width, img_stride_h, img_stride_w] = downsample_params; + tt::tt_metal::Buffer* src0_buffer = a.buffer(); + + TT_ASSERT(a.get_legacy_shape()[0] == 1 && a.get_legacy_shape()[1] == 1); + TT_ASSERT(output.get_legacy_shape()[0] == 1 && output.get_legacy_shape()[1] == 1); + + tt::tt_metal::Device* device = a.device(); + + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + // Sanity check of output size + TT_ASSERT(output.volume() % TILE_HW == 0); + uint32_t unpadded_input_volume = img_batch_size * img_height * img_width; + TT_ASSERT(a.volume() >= unpadded_input_volume); + uint32_t unpadded_output_volume = ceil((double)unpadded_input_volume / (double)(img_stride_h * img_stride_w)); + TT_ASSERT(output.volume() >= unpadded_output_volume); + + uint32_t ncores_x_full_grid = device->compute_with_storage_grid_size().x; + auto [num_cores_height_sliced, num_cores_width_sliced] = get_num_cores_height_width_sliced( + a.shard_spec().value().grid, a.memory_config().memory_layout, a.shard_spec().value().orientation); + uint32_t num_cores = num_cores_height_sliced * num_cores_width_sliced; + auto all_cores = a.shard_spec().value().grid; + auto memory_layout = a.memory_config().memory_layout; + TT_ASSERT(all_cores == output.shard_spec().value().grid); + TT_ASSERT(memory_layout == output.memory_config().memory_layout); + TT_ASSERT( + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + TT_ASSERT(all_cores.ranges().size() == 1); + } else { + TT_ASSERT(num_cores_width_sliced == 1); + } + uint32_t num_cores_x = + memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ? ncores_x_full_grid : num_cores_height_sliced; + + auto core_range = all_cores; + + uint32_t input_height = + a.get_legacy_shape()[2]; // input height == flattened face of input image, multiple images are stacked in H dim + uint32_t input_width = a.get_legacy_shape()[3]; // input width == input image # of channels + uint32_t output_height = output.get_legacy_shape()[2]; // output height == flattened face of output image, multiple + // images are stacked in H dim + uint32_t output_width = output.get_legacy_shape()[3]; + TT_ASSERT(input_width == output_width); + + uint32_t input_height_unpadded = img_batch_size * img_height * img_width; + TT_ASSERT(input_height >= input_height_unpadded); + uint32_t output_height_unpadded = + img_batch_size * std::ceil((double)(img_height * img_width) / (double)(img_stride_h * img_stride_w)); + TT_ASSERT(output_height >= output_height_unpadded); + + uint32_t input_shard_height = a.shard_spec().value().shape[0]; + TT_ASSERT(input_shard_height * num_cores_height_sliced >= input_height); // last core shard size may be padded + uint32_t input_height_padded = input_shard_height * num_cores_height_sliced; + uint32_t input_shard_width = a.shard_spec().value().shape[1]; + TT_ASSERT(input_shard_width * num_cores_width_sliced == input_width); + + uint32_t input_height_padding = input_height_padded - input_height_unpadded; + TT_ASSERT(input_height_padding < input_shard_height); // last core has padding + uint32_t last_core_input_shard_height_unpadded = input_shard_height - input_height_padding; + + uint32_t output_shard_height = output.shard_spec().value().shape[0]; + uint32_t output_shard_width = output.shard_spec().value().shape[1]; + + TT_ASSERT(output_shard_width == input_shard_width); + TT_ASSERT(output_shard_height * num_cores_height_sliced >= output_height); // last core shard size may be padded + uint32_t output_height_padded = output_shard_height * num_cores_height_sliced; + uint32_t output_height_padding = output_height_padded - output_height_unpadded; + TT_ASSERT(output_height_padding < output_shard_height); + uint32_t last_core_output_shard_height_unpadded = output_shard_height - output_height_padding; + + uint32_t input_shard_width_bytes = input_shard_width * datum_size(untilized_cb_data_format); + + TT_ASSERT(input_shard_width % TILE_WIDTH == 0); + uint32_t num_input_tiles_in_row = input_shard_width / TILE_WIDTH; + TT_ASSERT(input_shard_height % TILE_HEIGHT == 0); + uint32_t num_rows_of_input_tiles = input_shard_height / TILE_HEIGHT; + + uint32_t num_output_tiles_in_row = num_input_tiles_in_row; + TT_ASSERT(output_shard_height % TILE_HEIGHT == 0); + uint32_t num_rows_of_output_tiles = output_shard_height / TILE_HEIGHT; + + uint32_t input_cb_index = tt::CB::c_in0; + uint32_t num_input_tiles = num_input_tiles_in_row * num_rows_of_input_tiles; + tt::tt_metal::CircularBufferConfig input_cb_config = + tt::tt_metal::CircularBufferConfig( + num_input_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}}) + .set_page_size(input_cb_index, input_single_tile_size); + input_cb_config = input_cb_config.set_globally_allocated_address(*a.buffer()); + auto input_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, input_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + input_cb_index, + input_single_tile_size, + num_input_tiles, + input_single_tile_size * num_input_tiles); + + // CB to store halo data + // hardcode to store 1 row of tiles + uint32_t halo_prev_input_cb_index = tt::CB::c_in1; + uint32_t halo_prev_input_cb_max_rows_of_tiles = 4; + uint32_t num_halo_prev_cb_input_tiles = num_input_tiles_in_row * halo_prev_input_cb_max_rows_of_tiles; + tt::tt_metal::CircularBufferConfig halo_prev_input_cb_config = + tt::tt_metal::CircularBufferConfig( + num_halo_prev_cb_input_tiles * input_single_tile_size, {{halo_prev_input_cb_index, input_cb_data_format}}) + .set_page_size(halo_prev_input_cb_index, input_single_tile_size); + auto halo_prev_input_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, halo_prev_input_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + halo_prev_input_cb_index, + input_single_tile_size, + num_halo_prev_cb_input_tiles, + input_single_tile_size * num_halo_prev_cb_input_tiles); + + uint32_t halo_next_input_cb_index = tt::CB::c_in2; + uint32_t halo_next_input_cb_max_rows_of_tiles = 33; // TODO: Remove hardcoding + uint32_t num_halo_next_cb_input_tiles = num_input_tiles_in_row * halo_next_input_cb_max_rows_of_tiles; + tt::tt_metal::CircularBufferConfig halo_next_input_cb_config = + tt::tt_metal::CircularBufferConfig( + num_halo_next_cb_input_tiles * input_single_tile_size, {{halo_next_input_cb_index, input_cb_data_format}}) + .set_page_size(halo_next_input_cb_index, input_single_tile_size); + auto halo_next_input_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, halo_next_input_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + halo_next_input_cb_index, + input_single_tile_size, + num_halo_next_cb_input_tiles, + input_single_tile_size * num_halo_next_cb_input_tiles); + + // CB to store reader pattern array + // read pattern array size == output_height + uint32_t reader_pattern_array_size = output_shard_height; + uint32_t reader_pattern_array_cb_index = tt::CB::c_intermed1; + tt::tt_metal::CircularBufferConfig reader_pattern_array_cb_config = + tt::tt_metal::CircularBufferConfig( + reader_pattern_array_size * 4, {{reader_pattern_array_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(reader_pattern_array_cb_index, 4); + auto reader_pattern_array_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, reader_pattern_array_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + reader_pattern_array_cb_index, + 4, + reader_pattern_array_size, + 4 * reader_pattern_array_size); + + // untilized CB has size - [32, full width] + uint32_t untilize_cb_index = tt::CB::c_intermed2; + uint32_t num_tiles_untilize_cb = num_input_tiles_in_row; + tt::tt_metal::CircularBufferConfig untilize_cb_config = + tt::tt_metal::CircularBufferConfig( + num_tiles_untilize_cb * untilized_single_tile_size, {{untilize_cb_index, untilized_cb_data_format}}) + .set_page_size(untilize_cb_index, untilized_single_tile_size); + auto untilize_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, untilize_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + untilize_cb_index, + untilized_single_tile_size, + num_tiles_untilize_cb, + untilized_single_tile_size * num_tiles_untilize_cb); + + uint32_t num_output_tiles = num_output_tiles_in_row * num_rows_of_output_tiles; + uint32_t untilize_downsampled_cb_index = tt::CB::c_intermed3; + uint32_t num_tiles_untilize_downsampled_cb = + num_output_tiles; // untilize downsampled cb size == output size per core + tt::tt_metal::CircularBufferConfig untilize_downsampled_cb_config = + tt::tt_metal::CircularBufferConfig( + num_tiles_untilize_downsampled_cb * untilized_single_tile_size, + {{untilize_downsampled_cb_index, untilized_cb_data_format}}) + .set_page_size(untilize_downsampled_cb_index, untilized_single_tile_size); + auto untilize_downsampled_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, untilize_downsampled_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + untilize_downsampled_cb_index, + untilized_single_tile_size, + num_tiles_untilize_downsampled_cb, + untilized_single_tile_size * num_tiles_untilize_downsampled_cb); + + uint32_t final_tilize_output_cb_index = tt::CB::c_out0; + uint32_t num_tiles_final_tilize_output_cb = num_output_tiles; // final output cb size == output size per core + tt::tt_metal::CircularBufferConfig final_tilize_output_cb_config = + tt::tt_metal::CircularBufferConfig( + num_tiles_final_tilize_output_cb * output_single_tile_size, + {{final_tilize_output_cb_index, output_cb_data_format}}) + .set_page_size(final_tilize_output_cb_index, output_single_tile_size); + final_tilize_output_cb_config = final_tilize_output_cb_config.set_globally_allocated_address(*output.buffer()); + auto final_tilize_output_cb = tt::tt_metal::CreateCircularBuffer(program, core_range, final_tilize_output_cb_config); + log_debug( + tt::LogOp, + "CB {}: PS = {} NP = {} :: TOTAL = {}", + final_tilize_output_cb_index, + output_single_tile_size, + num_tiles_final_tilize_output_cb, + output_single_tile_size * num_tiles_final_tilize_output_cb); + + uint32_t log_base_2_of_conv_act_size_c_bytes = (uint32_t)std::log2((float)input_shard_width_bytes); + uint32_t stride_h_x_image_width = img_stride_h * img_width; + std::vector writer_compile_time_args = { + (std::uint32_t)untilize_cb_index, + (std::uint32_t)untilize_downsampled_cb_index, + (std::uint32_t)final_tilize_output_cb_index, + (std::uint32_t)reader_pattern_array_cb_index, + (std::uint32_t)datum_size(untilized_cb_data_format), + (std::uint32_t)input_shard_width_bytes, + (std::uint32_t)halo_prev_input_cb_index, + (std::uint32_t)halo_next_input_cb_index, + log_base_2_of_conv_act_size_c_bytes, + stride_h_x_image_width}; + + // Writer to downsample - drops rows from untilized cb + tt::tt_metal::KernelHandle downsample_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/downsample/device/kernels/downsample_writer_kernel.cpp", + core_range, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + vector compute_args = { + input_cb_index, + halo_prev_input_cb_index, + halo_next_input_cb_index, + untilize_cb_index, + untilize_downsampled_cb_index, + final_tilize_output_cb_index, + num_input_tiles_in_row, + num_rows_of_output_tiles, + num_output_tiles_in_row, + }; + string compute_kernel = "ttnn/cpp/ttnn/operations/data_movement/downsample/device/kernels/downsample_compute_kernel.cpp"; + if (num_input_tiles_in_row <= MAX_PACK_UNTILIZE_WIDTH) { + compute_kernel = + "ttnn/cpp/ttnn/operations/data_movement/downsample/device/kernels/downsample_fast_pack_untilize_compute_kernel.cpp"; + } + auto downsample_compute_kernel_id = tt::tt_metal::CreateKernel( + program, compute_kernel, core_range, tt::tt_metal::ComputeConfig{.compile_args = compute_args}); + + // track img h, img w, across cores + ImgTrackingVars v; + CoreCoord prev_core = {0, 0}; + bool input_flat_h_is_of_current_core = true; + const auto& cores = corerange_to_cores(all_cores, std::nullopt, true); + for (uint32_t i = 0; i < num_cores; i++) { + const CoreCoord& core = cores[i]; + // cout << "i=" << i << endl; + + uint32_t input_end_flat_h = input_shard_height - 1; + uint32_t output_end_flat_h = output_shard_height - 1; + + bool halo_prev_read_enabled = false; + DownsampleReadPatternParams halo_prev_read_pattern_params; + uint32_t halo_prev_noc_x = 0; + uint32_t halo_prev_noc_y = 0; + uint32_t halo_prev_start_addr = 0; + uint32_t halo_prev_addr_offset = 0; + uint32_t halo_prev_num_tiles = 0; + uint32_t halo_prev_size_bytes = 0; + uint32_t halo_prev_input_num_rows_of_tiles = 0; + uint32_t halo_prev_read_pattern_offset = 0; + + bool halo_next_read_enabled = false; + DownsampleReadPatternParams halo_next_read_pattern_params; + uint32_t halo_next_noc_x = 0; + uint32_t halo_next_noc_y = 0; + uint32_t halo_next_start_addr = 0; + uint32_t halo_next_addr_offset = 0; + uint32_t halo_next_num_tiles = 0; + uint32_t halo_next_size_bytes = 0; + uint32_t halo_next_input_num_rows_of_tiles = 0; + uint32_t halo_next_read_pattern_offset = 0; + uint32_t local_read_pattern_offset = 0; + uint32_t current_core_input_end_flat_h = input_end_flat_h; + uint32_t next_core_input_end_flat_h = input_end_flat_h; + if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + if (i % num_cores_x == 0) { + // first core in row + // reset + v.input_flat_h = 0; + v.output_flat_h = 0; + } else if (i % num_cores_x == num_cores_x - 1) { + // set unpadded height as end index for last core in row + current_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; + output_end_flat_h = last_core_output_shard_height_unpadded - 1; + } else if (i % num_cores_x == num_cores_x - 2) { + next_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; + } + } else if (i == num_cores - 1) { + // for height sharding, set unpadded height as end index for last core + current_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; + output_end_flat_h = last_core_output_shard_height_unpadded - 1; + } else if (i == num_cores - 2) { + next_core_input_end_flat_h = last_core_input_shard_height_unpadded - 1; + } + if (v.input_flat_h != 0 && !input_flat_h_is_of_current_core) { + // halo region of previous core + TT_ASSERT(i != 0); + if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + TT_ASSERT(prev_core.y == core.y); // for block sharding case, prev core is left core + } + halo_prev_read_enabled = true; + TT_ASSERT(v.input_flat_h < input_shard_height); + // get halo start tile address from height idx + uint32_t halo_prev_start_tile_id_h = v.input_flat_h / TILE_HEIGHT; + TT_ASSERT( + input_shard_height - v.input_flat_h <= + TILE_HEIGHT * + halo_prev_input_cb_max_rows_of_tiles); // halo input cb is hardcoded to store only 4 rows of tiles + // for now. TODO: allocate bigger CB or read in blocks + // get halo size + halo_prev_size_bytes = (input_shard_height - (halo_prev_start_tile_id_h * TILE_HEIGHT)) * + input_shard_width / TILE_HW * input_single_tile_size; + TT_ASSERT(halo_prev_size_bytes % input_single_tile_size == 0); + halo_prev_num_tiles = halo_prev_size_bytes / input_single_tile_size; + TT_ASSERT(halo_prev_num_tiles <= num_halo_prev_cb_input_tiles); + TT_ASSERT(halo_prev_num_tiles % num_input_tiles_in_row == 0); + halo_prev_input_num_rows_of_tiles = halo_prev_num_tiles / num_input_tiles_in_row; + halo_prev_addr_offset = num_input_tiles_in_row * halo_prev_start_tile_id_h * input_single_tile_size; + halo_prev_start_addr = GetCircularBufferConfig(program, input_cb).globally_allocated_address().value(); + + TT_ASSERT( + (halo_prev_start_addr + halo_prev_addr_offset) % 32 == 0); // read address should be 32 byte aligned + auto halo_noc_coords = device->worker_core_from_logical_core(prev_core); + halo_prev_noc_x = halo_noc_coords.x; + halo_prev_noc_y = halo_noc_coords.y; + TT_ASSERT(v.input_flat_h >= halo_prev_start_tile_id_h * TILE_HEIGHT); + halo_prev_read_pattern_offset = v.input_flat_h - (halo_prev_start_tile_id_h * TILE_HEIGHT); + local_read_pattern_offset = halo_prev_input_num_rows_of_tiles * TILE_HEIGHT; + halo_prev_read_pattern_params = generate_downsample_read_pattern( + v, img_height, img_width, img_stride_h, img_stride_w, input_end_flat_h, output_end_flat_h, true, false); + } + // local core + TT_ASSERT(v.output_flat_h < output_shard_height); + uint32_t local_start_h = v.input_flat_h; + DownsampleReadPatternParams local_read_pattern_params = generate_downsample_read_pattern( + v, + img_height, + img_width, + img_stride_h, + img_stride_w, + current_core_input_end_flat_h, + output_end_flat_h, + false, + false); + TT_ASSERT(v.output_flat_h <= output_shard_height); + uint32_t local_end_h_exclusive = v.input_flat_h == 0 ? input_shard_height : v.input_flat_h; + uint32_t local_num_rows = local_end_h_exclusive - local_start_h; + TT_ASSERT(local_num_rows > 0); + uint32_t local_input_num_rows_of_tiles = std::ceil((double)local_num_rows / (double)TILE_HEIGHT); + uint32_t local_input_offset_rows_of_tiles = local_start_h / TILE_HEIGHT; + if (local_start_h != 0) { + TT_ASSERT(local_read_pattern_offset == 0); + local_read_pattern_offset = local_start_h % TILE_HEIGHT; + } + if (v.input_flat_h != 0) { + input_flat_h_is_of_current_core = false; + } else { + input_flat_h_is_of_current_core = true; // updating flag for next core + } + TT_ASSERT(local_input_num_rows_of_tiles <= num_rows_of_input_tiles); + + if (v.output_flat_h != 0) { + // need to read halo from next core + TT_ASSERT(i != num_cores - 1); + TT_ASSERT(v.input_flat_h == 0); + const CoreCoord& next_core = cores[i + 1]; + if (memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + TT_ASSERT(next_core.y == core.y); // for block sharding case, next core is right core + } + halo_next_read_enabled = true; + halo_next_read_pattern_params = generate_downsample_read_pattern( + v, + img_height, + img_width, + img_stride_h, + img_stride_w, + next_core_input_end_flat_h, + output_end_flat_h, + false, + true); + TT_ASSERT(v.output_flat_h == 0); + TT_ASSERT(v.input_flat_h != 0 && v.input_flat_h < input_shard_height); + TT_ASSERT( + v.input_flat_h <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles, + "v.input_flat_h ({}) should be <= TILE_HEIGHT * halo_next_input_cb_max_rows_of_tiles ({})", + v.input_flat_h, + halo_next_input_cb_max_rows_of_tiles); // halo next input cb is hardcoded to store only 5 rows of tiles + // for now. TODO: allocate bigger CB or read in blocks + uint32_t halo_next_end_tile_id_h = v.input_flat_h / TILE_HEIGHT; + // get halo size + halo_next_size_bytes = + (halo_next_end_tile_id_h + 1) * TILE_HEIGHT * input_shard_width / TILE_HW * input_single_tile_size; + TT_ASSERT(halo_next_size_bytes % input_single_tile_size == 0); + halo_next_num_tiles = halo_next_size_bytes / input_single_tile_size; + TT_ASSERT(halo_next_num_tiles <= num_halo_next_cb_input_tiles); + TT_ASSERT(halo_next_num_tiles % num_input_tiles_in_row == 0); + halo_next_input_num_rows_of_tiles = halo_next_num_tiles / num_input_tiles_in_row; + halo_next_addr_offset = 0; + halo_next_start_addr = GetCircularBufferConfig(program, input_cb).globally_allocated_address().value(); + TT_ASSERT( + (halo_next_start_addr + halo_next_addr_offset) % 32 == 0); // read address should be 32 byte aligned + auto halo_noc_coords = device->worker_core_from_logical_core(next_core); + halo_next_noc_x = halo_noc_coords.x; + halo_next_noc_y = halo_noc_coords.y; + TT_ASSERT(halo_prev_input_num_rows_of_tiles == 0); + halo_next_read_pattern_offset = local_input_num_rows_of_tiles * TILE_HEIGHT; + } + TT_ASSERT(v.output_flat_h == 0); + + // Compile runtime args + vector compile_rt_kernel_args = { + local_input_num_rows_of_tiles, + local_input_offset_rows_of_tiles, + halo_prev_read_enabled, + halo_prev_input_num_rows_of_tiles, + halo_next_read_enabled, + halo_next_input_num_rows_of_tiles, + }; + + tt::tt_metal::SetRuntimeArgs(program, downsample_compute_kernel_id, core, compile_rt_kernel_args); + + // Writer runtime args + vector writer_kernel_args = { + (uint32_t)img_height, + (uint32_t)img_width, + (uint32_t)img_stride_h, + (uint32_t)img_stride_w, + + // halo prev args + halo_prev_read_enabled, + halo_prev_noc_x, + halo_prev_noc_y, + halo_prev_num_tiles, + halo_prev_start_addr, + halo_prev_addr_offset, + halo_prev_size_bytes, + + // halo prev read pattern args + halo_prev_read_pattern_offset, + halo_prev_read_pattern_params.top_partial_middle_aligned_row_width, + halo_prev_read_pattern_params.skip_top_partial_middle_aligned_row, + halo_prev_read_pattern_params.top_partial_right_aligned_row_width, + halo_prev_read_pattern_params.skip_top_partial_right_aligned_row, + halo_prev_read_pattern_params.num_rows_top_partial_image, + halo_prev_read_pattern_params.num_skip_rows_top_partial_image, + halo_prev_read_pattern_params.num_full_images, + halo_prev_read_pattern_params.num_rows_bottom_partial_image, + halo_prev_read_pattern_params.num_skip_rows_bottom_partial_image, + halo_prev_read_pattern_params.bottom_partial_left_aligned_row_width, + halo_prev_read_pattern_params.skip_bottom_partial_left_aligned_row, + + // local read pattern args + local_read_pattern_offset, + local_read_pattern_params.top_partial_middle_aligned_row_width, + local_read_pattern_params.skip_top_partial_middle_aligned_row, + local_read_pattern_params.top_partial_right_aligned_row_width, + local_read_pattern_params.skip_top_partial_right_aligned_row, + local_read_pattern_params.num_rows_top_partial_image, + local_read_pattern_params.num_skip_rows_top_partial_image, + local_read_pattern_params.num_full_images, + local_read_pattern_params.num_rows_bottom_partial_image, + local_read_pattern_params.num_skip_rows_bottom_partial_image, + local_read_pattern_params.bottom_partial_left_aligned_row_width, + local_read_pattern_params.skip_bottom_partial_left_aligned_row, + + // halo next core args + halo_next_read_enabled, + halo_next_noc_x, + halo_next_noc_y, + halo_next_num_tiles, + halo_next_start_addr, + halo_next_addr_offset, + halo_next_size_bytes, + + // halo next read pattern args + halo_next_read_pattern_offset, + halo_next_read_pattern_params.top_partial_middle_aligned_row_width, + halo_next_read_pattern_params.skip_top_partial_middle_aligned_row, + halo_next_read_pattern_params.top_partial_right_aligned_row_width, + halo_next_read_pattern_params.skip_top_partial_right_aligned_row, + halo_next_read_pattern_params.num_rows_top_partial_image, + halo_next_read_pattern_params.num_skip_rows_top_partial_image, + halo_next_read_pattern_params.num_full_images, + halo_next_read_pattern_params.num_rows_bottom_partial_image, + halo_next_read_pattern_params.num_skip_rows_bottom_partial_image, + halo_next_read_pattern_params.bottom_partial_left_aligned_row_width, + halo_next_read_pattern_params.skip_bottom_partial_left_aligned_row, + + halo_prev_input_num_rows_of_tiles + local_input_num_rows_of_tiles + halo_next_input_num_rows_of_tiles, + num_input_tiles_in_row, + num_output_tiles, + + (uint32_t) false}; + + tt::tt_metal::SetRuntimeArgs(program, downsample_writer_kernel_id, core, writer_kernel_args); + prev_core = core; + } + + auto override_runtime_args_callback = [input_cb = input_cb, + final_tilize_output_cb = final_tilize_output_cb, + downsample_writer_kernel_id = downsample_writer_kernel_id, + cores = cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + auto src_buffer = input_tensors.at(0).buffer(); + auto dst_buffer = output_tensors.at(0).buffer(); + + UpdateDynamicCircularBufferAddress(program, input_cb, *src_buffer); + UpdateDynamicCircularBufferAddress(program, final_tilize_output_cb, *dst_buffer); + + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, downsample_writer_kernel_id); + for (const auto& core : cores) { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[8] = src_buffer->address(); + runtime_args[39] = src_buffer->address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; +} + +} // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.hpp new file mode 100644 index 00000000000..ef32eb54942 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/downsample/device/downsample_program_factory.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 +#pragma once + + +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; + +namespace ttnn::operations::data_movement::detail { + +std::pair get_num_cores_height_width_sliced( + CoreRangeSet all_cores, TensorMemoryLayout memory_layout, ShardOrientation shard_orientation); +operation::ProgramWithCallbacks downsample_single_core( + const Tensor& a, std::array downsample_params, Tensor& output); + +} // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp new file mode 100644 index 00000000000..44a2d903705 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp @@ -0,0 +1,1127 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/tensor/host_buffer/functions.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_log.h" +using namespace tt::constants; + +namespace ttnn::operations::data_movement::detail { + + +operation::ProgramWithCallbacks pad_rm_reader_writer(const Tensor &a, + Tensor &output, + const tt::tt_metal::Shape &output_tensor_shape, + const tt::tt_metal::Shape &input_tensor_start, + const float pad_value) { + Program program{}; + + auto output_shape = output_tensor_shape; + + uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size(); + uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size(); // Assuming output is same datatype as input + TT_ASSERT(unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size"); + + // construct const buffer with the pad_value + Device *device = a.device(); + uint32_t pad_value_const_buffer_size = 32; // noc transfers in chunks of 32 + uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size(); + auto pad_value_const_buffer = tt::tt_metal::owned_buffer::create(std::vector(pad_value_const_buffer_size, bfloat16(pad_value))); + const Tensor pad_value_const_tensor = + Tensor( + OwnedStorage{pad_value_const_buffer}, + Shape(std::array{1, 1, 1, pad_value_const_buffer_size}), + DataType::BFLOAT16, + Layout::ROW_MAJOR) + .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); + auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); + + CoreRange cores({0, 0}, {0, 0}); + uint32_t cb_id = tt::CB::c_in0; + uint32_t cb_npages = 16; // multibuffering + uint32_t cb_pagesize = tt::round_up(padded_row_size_nbytes, tt::constants::TILE_WIDTH); + tt::DataFormat in_df = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_config = tt::tt_metal::CircularBufferConfig(cb_npages * cb_pagesize, {{cb_id, in_df}}) + .set_page_size(cb_id, cb_pagesize); + auto cb = tt::tt_metal::CreateCircularBuffer(program, cores, cb_config); + + + Buffer *src0_buffer = a.buffer(); + Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes); + uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(unpadded_row_size_nbytes) : 0; + bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes); + uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(padded_row_size_nbytes) : 0; + std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, + (std::uint32_t) dst_is_dram, + (std::uint32_t) src_stick_size_is_power_of_two, + (std::uint32_t) src_log2_stick_size, + (std::uint32_t) dst_stick_size_is_power_of_two, + (std::uint32_t) dst_log2_stick_size}; + std::vector writer_ct_args = reader_ct_args; + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + bfloat16 bfloat_zero = bfloat16(0.0f); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value}); + + KernelHandle reader_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_interleaved.cpp", + cores, + tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); + KernelHandle writer_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_interleaved.cpp", + cores, + tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); + uint32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes; + + #if 0 + { + log_debug("src0_buffer_addr: {}", src0_buffer->address()); + log_debug("dst_buffer_addr: {}", dst_buffer->address()); + log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]); + log_debug("out.shape[0]: {}", output_shape[0]); + log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]); + log_debug("out.shape[1]: {}", output_shape[1]); + log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]); + log_debug("out.shape[2]: {}", output_shape[2]); + log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]); + log_debug("out.shape[3]: {}", output_shape[3]); + log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes); + log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes); + log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes); + log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr); + log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes); + log_debug("packed_pad_value: {}", packed_pad_value); + } + #endif + + uint32_t start_src_stick_id = 0; + uint32_t start_dst_stick_id = 0; + vector reader_rt_args = {src0_buffer->address(), + dst_buffer->address(), + a.get_legacy_shape()[0], + output_shape[0], + a.get_legacy_shape()[1], + output_shape[1], + a.get_legacy_shape()[2], + output_shape[2], + a.get_legacy_shape()[3], + output_shape[3], + unpadded_row_size_nbytes, + padded_row_size_nbytes, + padded_row_diff_size_nbytes, + pad_value_const_tensor_addr, + pad_value_const_buffer_nbytes, + packed_pad_value, + start_src_stick_id, + start_dst_stick_id, + 0, + 0, + 0, + output_shape[2], + a.get_legacy_shape()[2], + unpadded_row_size_nbytes, + padded_row_size_nbytes, + 0, + output.get_legacy_shape()[0] + }; + vector writer_rt_args = reader_rt_args; + tt::tt_metal::SetRuntimeArgs(program, + reader_kernel_id, + cores, + reader_rt_args); + tt::tt_metal::SetRuntimeArgs(program, + writer_kernel_id, + cores, + writer_rt_args); + + auto override_runtime_args_callback = + [reader_kernel_id=reader_kernel_id, writer_kernel_id=writer_kernel_id]( + const Program &program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + CoreCoord core = {0, 0}; + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks pad_rm_opt(const Tensor &a, + Tensor &output, + const Shape &output_tensor_shape, + const Shape &input_tensor_start, + const float pad_value) { + Program program{}; + + auto output_shape = output_tensor_shape; + + uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size(); + uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size(); // Assuming output is same datatype as input + TT_ASSERT(unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size"); + + Device *device = a.device(); + auto dst_buffer_l1 = Buffer(device, padded_row_size_nbytes, padded_row_size_nbytes, BufferType::L1); + + // construct const buffer with the pad_value + uint32_t pad_value_const_buffer_size = 32; // noc transfers in chunks of 32 + uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size(); + auto pad_value_const_buffer = owned_buffer::create(std::vector(pad_value_const_buffer_size, bfloat16(pad_value))); + const Tensor pad_value_const_tensor = + Tensor( + OwnedStorage{pad_value_const_buffer}, + Shape(std::array{1, 1, 1, pad_value_const_buffer_size}), + DataType::BFLOAT16, + Layout::ROW_MAJOR) + .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); + auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); + + Buffer *src0_buffer = a.buffer(); + Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes); + uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(unpadded_row_size_nbytes) : 0; + bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes); + uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(padded_row_size_nbytes) : 0; + std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, + (std::uint32_t) dst_is_dram, + (std::uint32_t) src_stick_size_is_power_of_two, + (std::uint32_t) src_log2_stick_size, + (std::uint32_t) dst_stick_size_is_power_of_two, + (std::uint32_t) dst_log2_stick_size}; + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + bfloat16 bfloat_zero = bfloat16(0.0f); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value}); + + CoreRange core({0, 0}, {0, 0}); + KernelHandle reader_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/pad_dims_rm_interleaved_opt.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); + uint32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes; + + #if 0 + { + tt::log_debug("src0_buffer_addr: {}", src0_buffer->address()); + tt::log_debug("dst_buffer_addr: {}", dst_buffer->address()); + tt::log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]); + tt::log_debug("out.shape[0]: {}", output_shape[0]); + tt::log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]); + tt::log_debug("out.shape[1]: {}", output_shape[1]); + tt::log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]); + tt::log_debug("out.shape[2]: {}", output_shape[2]); + tt::log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]); + tt::log_debug("out.shape[3]: {}", output_shape[3]); + tt::log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes); + tt::log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes); + tt::log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes); + tt::log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr); + tt::log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes); + tt::log_debug("packed_pad_value: {}", packed_pad_value); + tt::log_debug("dst_buffer_l1_addr: {}", dst_buffer_l1.address()); + } + #endif + + vector reader_rt_args = {src0_buffer->address(), + dst_buffer->address(), + a.get_legacy_shape()[0], + output_shape[0], + a.get_legacy_shape()[1], + output_shape[1], + a.get_legacy_shape()[2], + output_shape[2], + a.get_legacy_shape()[3], + output_shape[3], + unpadded_row_size_nbytes, + padded_row_size_nbytes, + padded_row_diff_size_nbytes, + pad_value_const_tensor_addr, + pad_value_const_buffer_nbytes, + packed_pad_value, + dst_buffer_l1.address()}; + tt::tt_metal::SetRuntimeArgs(program, + reader_kernel_id, + core, + reader_rt_args); + + auto override_runtime_args_callback = [kernel_id=reader_kernel_id](const Program &program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + CoreCoord core = {0, 0}; + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks pad_rm(const Tensor &a, Tensor &output, const Shape &output_tensor_shape, const Shape &input_tensor_start, const float pad_value) { + + tt::tt_metal::Program program{}; + + CoreRange core({0, 0}, {0, 0}); + + // This should allocate a DRAM buffer on the device + tt::tt_metal::Device *device = a.device(); + + auto output_shape = output_tensor_shape; + + tt::tt_metal::Buffer *src0_buffer = a.buffer(); + + uint32_t unpadded_row_size_bytes = a.get_legacy_shape()[3] * a.element_size(); + uint32_t padded_row_size_bytes = output_shape[3] * a.element_size(); + + tt::tt_metal::Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + uint32_t src_stick_size = unpadded_row_size_bytes; + uint32_t dst_stick_size = padded_row_size_bytes; + + uint32_t dst_buffer_size = dst_stick_size; + + tt::tt_metal::InterleavedBufferConfig buff_config{ + .device= device, + .size = dst_buffer_size, + .page_size = dst_buffer_size, + .buffer_type = tt::tt_metal::BufferType::L1 + }; + + auto dst_buffer_l1 = tt::tt_metal::CreateBuffer(buff_config); + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); + + vector reader_kernel_args = { + src0_buffer->address(), + dst_buffer->address(), + a.get_legacy_shape()[0], + output_shape[0], + a.get_legacy_shape()[1], + output_shape[1], + a.get_legacy_shape()[2], + output_shape[2], + a.get_legacy_shape()[3], + output_shape[3], + unpadded_row_size_bytes, + padded_row_size_bytes, + padded_row_size_bytes - unpadded_row_size_bytes, + packed_pad_value, + dst_buffer_l1->address() + }; + bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool src_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(src_stick_size); + uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(src_stick_size) : 0; + bool dst_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(dst_stick_size); + uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(dst_stick_size) : 0; + std::vector compile_time_args_vec = { + (std::uint32_t) src0_is_dram, + (std::uint32_t) dst_is_dram, + (std::uint32_t) src_stick_size_is_power_of_two, + (std::uint32_t) src_log2_stick_size, + (std::uint32_t) dst_stick_size_is_power_of_two, + (std::uint32_t) dst_log2_stick_size, + + }; + + // Tilized reader + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/pad_dims_rm_interleaved.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(compile_time_args_vec)); + + tt::tt_metal::SetRuntimeArgs( + program, + unary_reader_kernel_id, + core, + reader_kernel_args + ); + + auto override_runtime_args_callback = [kernel_id=unary_reader_kernel_id]( + const Program &program, + const std::vector& input_buffers, + const std::vector& output_buffers + ) { + + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + CoreCoord core = {0, 0}; + + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks pad_tile(const Tensor &a, Tensor& output, const tt::tt_metal::Shape &output_tensor_shape, const tt::tt_metal::Shape &input_tensor_start, const float pad_value) { + + tt::tt_metal::Program program{}; + + CoreRange core({0, 0}, {0, 0}); + + // This should allocate a DRAM buffer on the device + tt::tt_metal::Device *device = a.device(); + + auto output_shape = output_tensor_shape; + + tt::tt_metal::Buffer *src0_buffer = a.buffer(); + + tt::tt_metal::Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format); + + tt::log_debug("pad_tile"); + tt::log_debug("cb_data_format: {}", cb_data_format); + tt::log_debug("single_tile_size: {}", single_tile_size); + tt::log_debug("output_tensor_shape: {}", output_tensor_shape); + tt::log_debug("input_tensor_start: {}", input_tensor_start); + tt::log_debug("pad_value: {}", pad_value); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = 2; + tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, single_tile_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + + uint32_t src1_cb_index = 1; // For pad buffer + uint32_t num_pad_tiles = 1; + tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(num_pad_tiles * single_tile_size, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, single_tile_size); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); + + uint32_t num_unpadded_Xt = a.get_legacy_shape()[3] / TILE_WIDTH; + uint32_t num_total_Xt = output_shape[3] / TILE_WIDTH; + uint32_t num_padded_Xt = num_total_Xt - num_unpadded_Xt; + uint32_t num_unpadded_Yt = a.get_legacy_shape()[2] / TILE_HEIGHT; + uint32_t num_total_Yt = output_shape[2] / TILE_HEIGHT; + uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; + uint32_t num_unpadded_Z = a.get_legacy_shape()[1]; + uint32_t num_total_Z = output_shape[1]; + uint32_t num_padded_Zt = (num_total_Z - num_unpadded_Z) * num_total_Yt * num_total_Xt; + uint32_t num_unpadded_W = a.get_legacy_shape()[0]; + uint32_t num_total_W = output_shape[0]; + uint32_t num_padded_Wt = (num_total_W - num_unpadded_W) * num_total_Z * num_total_Yt * num_total_Xt; + + uint32_t num_unpadded_tiles = a.volume() / TILE_HW; + + vector reader_kernel_args = { + src0_buffer->address(), + num_unpadded_tiles, 0 + }; + vector writer_kernel_args = { + dst_buffer->address(), + num_unpadded_W, + num_padded_Wt, + num_unpadded_Z, + num_padded_Zt, + num_unpadded_Yt, + num_padded_Yt, + num_unpadded_Xt, + num_padded_Xt, + packed_pad_value, + }; + + // Reader compile-time args + // Data is 32 byte aligned + bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + // interleaved accessor args + (std::uint32_t) src0_is_dram + }; + std::vector writer_compile_time_args = { + // interleaved accessor args + (std::uint32_t) src0_cb_index, + (std::uint32_t) src1_cb_index, + (std::uint32_t) dst_is_dram + }; + // Tilized reader + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_unary_pad_dims_interleaved.cpp", + core, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + tt::tt_metal::SetRuntimeArgs( + program, + unary_reader_kernel_id, + core, + reader_kernel_args + ); + + tt::tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + writer_kernel_args + ); + + auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( + const Program &program, + const std::vector& input_buffers, + const std::vector& output_buffers + ) { + + auto src_dram_buffer = input_buffers.at(0); + + auto dst_dram_buffer = output_buffers.at(0); + + CoreCoord core = {0, 0}; + + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = src_dram_buffer->address(); + } + + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args[0] = dst_dram_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + + +inline void log_rt_args(const CoreCoord& core, vector& args) { + for (auto v : args) { + tt::log_debug(tt::LogOp, "{},{} :: {}", core.x, core.y, v); + } +} + +// This is currently mostly hardcoded for resnet shapes +inline std::tuple + split_across_cores(CoreCoord grid_size, uint32_t nbatch, uint32_t nchannel, uint32_t ntiles_h, uint32_t ntiles_w) { + + uint32_t ncores, ncores_h, ncores_w, ntiles_per_core_h, ntiles_per_core_w, nbatch_per_core_h, ncores_per_batch_h; + + ncores_h = 1; + + // each batch needs to be padded independently + switch (nbatch) { + case 1: + ncores_h = 1; + nbatch_per_core_h = 1; + ntiles_per_core_h = 1; + switch (ntiles_h) { + case 2: ncores_h = 2; ntiles_per_core_h = 1; break; + case 4: ncores_h = 4; ntiles_per_core_h = 1; break; + case 8: ncores_h = 8; ntiles_per_core_h = 1; break; + case 64: ncores_h = 8; ntiles_per_core_h = 8; break; + } + ncores_per_batch_h = ncores_h; + break; + + case 2: + ncores_h = 1; + ncores_per_batch_h = 1; + nbatch_per_core_h = 1; + ntiles_per_core_h = 1; + switch (ntiles_h) { + case 2: ncores_per_batch_h = 2; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 1; break; + case 4: ncores_per_batch_h = 4; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 1; break; + case 8: ncores_per_batch_h = 4; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 2; break; + case 64: ncores_per_batch_h = 4; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 16; break; + } + break; + + case 8: + ncores_h = 8; + ncores_per_batch_h = 1; + nbatch_per_core_h = 1; + ntiles_per_core_h = ntiles_h; + break; + + default: + TT_ASSERT(false, "unhandled nbatch. TODO"); + + // generic case -- TODO + + // one of the following will be 0 when grid_size.y != nbatch + uint32_t nbatch_per_core_h = nbatch / grid_size.y; // floor + uint32_t ncores_per_batch_h = grid_size.y / nbatch; // floor + if (nbatch == grid_size.y) { + nbatch_per_core_h = 1; + ncores_per_batch_h = 1; + } + + // currently uses hardcoded values for resnet50 + // TT_ASSERT(ntiles_h == 1 || ntiles_h == 2 || ntiles_h == 4 || ntiles_h == 16, "Only Resnet50 shapes are supported in multicore version for now."); + // TT_ASSERT(ntiles_w == 64, "Only Resnet50 shapes are supported in multicore version for now."); + + TT_ASSERT(nbatch <= grid_size.y, "Unsupported case with nbatch > grid_size.y!"); + + uint32_t ncores_h = 1; + uint32_t ntiles_per_core_h = ntiles_h / ncores_h; + if (nbatch_per_core_h == 0) { + // there are multiple cores along h per batch + nbatch_per_core_h = 1; + ncores_h = ncores_per_batch_h * nbatch; + ntiles_per_core_h = ntiles_h / ncores_per_batch_h; + } else if (ncores_per_batch_h == 0) { + // unsupported case. TODO. + TT_ASSERT(false); + // there are multiple batch per core along h + // ncores_per_batch_h = 1; + // ncores_h = (uint32_t) ceil((float) nbatch / nbatch_per_core_h); + // ntiles_per_core_h = nbatch_per_core_h * ntiles_h; + } else { + TT_ASSERT("Something went terribly wrong in splitting acrtoss cores"); + } + break; + } + + ncores_w = 1; + switch (ntiles_w) { + case 2: ncores_w = 2; break; + case 4: ncores_w = 4; break; + case 8: ncores_w = 8; break; + case 64: ncores_w = 8; break; + } + ncores = ncores_h * ncores_w; + ntiles_per_core_w = ntiles_w / ncores_w; + std::set all_cores; + std::set core_range; + + all_cores.insert(CoreRange(CoreCoord(0, 0), CoreCoord(ncores_w - 1, ncores_h - 1))); + core_range.insert(CoreRange(CoreCoord(0, 0), CoreCoord(ncores_w - 1, ncores_h - 1))); + + return std::make_tuple(ncores, ncores_h, ncores_w, all_cores, core_range, ntiles_per_core_h, ntiles_per_core_w, nbatch_per_core_h, ncores_per_batch_h); +} + +operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(const Tensor &a, + Tensor &output, + const tt::tt_metal::Shape &output_tensor_shape, + const tt::tt_metal::Shape &input_tensor_start, + const float pad_value) { + Program program{}; + + auto output_shape = output_tensor_shape; + + uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size(); + uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size(); // Assuming output is same datatype as input + TT_ASSERT(unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size"); + + Device *device = a.device(); + + // construct const buffer with the pad_value + uint32_t pad_value_const_buffer_size = 32; // noc transfers in chunks of 32 + uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size(); + auto pad_value_const_buffer = owned_buffer::create(std::vector(pad_value_const_buffer_size, bfloat16(pad_value))); + // NOTE: The const buffer is always in L1 + // TODO: make a local buffer for each core? + const Tensor pad_value_const_tensor = + Tensor( + OwnedStorage{pad_value_const_buffer}, + Shape(std::array{1, 1, 1, pad_value_const_buffer_size}), + DataType::BFLOAT16, + Layout::ROW_MAJOR) + .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); + auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); + + // uint32_t ntiles_h = output_tensor_shape[0] * output_tensor_shape[1] * output_tensor_shape[2] / TILE_HEIGHT; + uint32_t ntiles_h = output_tensor_shape[2] / TILE_HEIGHT; + uint32_t ntiles_w = output_tensor_shape[3] / TILE_WIDTH; + + auto grid_size = device->compute_with_storage_grid_size(); + uint32_t nbatch = output_tensor_shape[0]; + uint32_t nchannel = output_tensor_shape[1]; + // first the batch dim is distributed along H, and within each batch then the tiles are distributed. + auto [ncores, ncores_h, ncores_w, all_cores, core_range, ntiles_per_core_h, ntiles_per_core_w, nbatch_per_core_h, ncores_per_batch_h] = split_across_cores(grid_size, nbatch, nchannel, ntiles_h, ntiles_w); + + int32_t src_nbytes_per_core_w = ntiles_per_core_w * TILE_WIDTH * a.element_size(); + int32_t dst_nbytes_per_core_w = ntiles_per_core_w * TILE_WIDTH * output.element_size(); + + uint32_t cb_id = tt::CB::c_in0; + uint32_t cb_npages = 16; // multibuffering for perf + // uint32_t cb_npages = 1; // multibuffering for perf + uint32_t cb_pagesize = (uint32_t) ceil((float) dst_nbytes_per_core_w / tt::constants::TILE_WIDTH) * tt::constants::TILE_WIDTH; + tt::DataFormat in_df = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_config = tt::tt_metal::CircularBufferConfig(cb_npages * cb_pagesize, {{cb_id, in_df}}) + .set_page_size(cb_id, cb_pagesize); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_config); + + Buffer *src0_buffer = a.buffer(); + Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes); + uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(unpadded_row_size_nbytes) : 0; + bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes); + uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(padded_row_size_nbytes) : 0; + std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, + (std::uint32_t) dst_is_dram, + (std::uint32_t) src_stick_size_is_power_of_two, + (std::uint32_t) src_log2_stick_size, + (std::uint32_t) dst_stick_size_is_power_of_two, + (std::uint32_t) dst_log2_stick_size}; + std::vector writer_ct_args = reader_ct_args; + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + bfloat16 bfloat_zero = bfloat16(0.0f); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value}); + + KernelHandle reader_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_interleaved.cpp", + all_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); + KernelHandle writer_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_interleaved.cpp", + all_cores, + tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); + // int32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes; + log_rt_args(CoreCoord{0, 0}, reader_ct_args); + + #if 1 + { + tt::log_debug("ncores: {}", ncores); + tt::log_debug("ncores_h: {}", ncores_h); + tt::log_debug("ncores_w: {}", ncores_w); + tt::log_debug("ntiles_per_core_h: {}", ntiles_per_core_h); + tt::log_debug("ntiles_per_core_w: {}", ntiles_per_core_w); + tt::log_debug("src0_buffer_addr: {}", src0_buffer->address()); + tt::log_debug("dst_buffer_addr: {}", dst_buffer->address()); + tt::log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]); + tt::log_debug("out.shape[0]: {}", output_shape[0]); + tt::log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]); + tt::log_debug("out.shape[1]: {}", output_shape[1]); + tt::log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]); + tt::log_debug("out.shape[2]: {}", output_shape[2]); + tt::log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]); + tt::log_debug("out.shape[3]: {}", output_shape[3]); + tt::log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes); + tt::log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes); + // tt::log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes); + tt::log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr); + tt::log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes); + tt::log_debug("packed_pad_value: {}", packed_pad_value); + tt::log_debug("src_nbytes_per_core_w: {}", src_nbytes_per_core_w); + tt::log_debug("dst_nbytes_per_core_w: {}", dst_nbytes_per_core_w); + tt::log_debug("nbatch_per_core_h: {}", nbatch_per_core_h); + tt::log_debug("ncores_per_batch_h: {}", ncores_per_batch_h); + } + #endif + + uint32_t start_src_stick_id = 0; + uint32_t start_dst_stick_id = 0; + uint32_t start_src_stick_wi = 0; // start of stick segment for 2d decomp + uint32_t start_dst_stick_wi = 0; + int32_t local_nsticks = ntiles_per_core_h * TILE_HEIGHT; + int32_t rem_nbatch = nbatch; // per core h, there are ncores_per_batch_h cores, ie each batch ncores_h = ncores_per_batch_h + for (int32_t b = 0; b < nbatch; ++ b) { + int32_t rem_src_nsticks = a.get_legacy_shape()[2]; + for (uint32_t j = 0; j < ncores_per_batch_h; ++ j) { + uint32_t num_local_unpadded_nsticks = local_nsticks; + if (rem_src_nsticks - local_nsticks >= 0) { + // not reached padding sticks yet + rem_src_nsticks -= local_nsticks; + } else { + num_local_unpadded_nsticks = rem_src_nsticks; + rem_src_nsticks = 0; + } + start_src_stick_wi = 0; + start_dst_stick_wi = 0; + int32_t rem_src_stick_size_nbytes = unpadded_row_size_nbytes; + for (uint32_t i = 0; i < ncores_w; ++ i) { + CoreCoord core = {i, b * ncores_per_batch_h + j}; + uint32_t curr_stick_size_nbytes = 0; + int32_t curr_stick_diff_nbytes = 0; + if (rem_src_stick_size_nbytes - dst_nbytes_per_core_w >= 0) { + // no padding on this core + curr_stick_size_nbytes = dst_nbytes_per_core_w; + rem_src_stick_size_nbytes -= dst_nbytes_per_core_w; + } else { + // this core has padding + curr_stick_size_nbytes = rem_src_stick_size_nbytes; + curr_stick_diff_nbytes = dst_nbytes_per_core_w - curr_stick_size_nbytes; + rem_src_stick_size_nbytes = 0; + } + vector reader_rt_args = {src0_buffer->address(), + dst_buffer->address(), + a.get_legacy_shape()[0], + output_shape[0], + a.get_legacy_shape()[1], + output_shape[1], + a.get_legacy_shape()[2], + output_shape[2], + a.get_legacy_shape()[3], + output_shape[3], + curr_stick_size_nbytes, + (uint32_t) dst_nbytes_per_core_w, + (uint32_t) curr_stick_diff_nbytes, + pad_value_const_tensor_addr, + pad_value_const_buffer_nbytes, + packed_pad_value, + start_src_stick_id, + start_dst_stick_id, + start_src_stick_wi, + start_dst_stick_wi, + start_src_stick_wi * a.element_size(), + (uint32_t) local_nsticks, + num_local_unpadded_nsticks, + unpadded_row_size_nbytes, + padded_row_size_nbytes, + start_dst_stick_wi * output.element_size(), + nbatch_per_core_h + }; + // if (core.x == 0) log_rt_args(core, reader_rt_args); + // if (core.x == 0) { + // log_debug("{} :: start_src_stick_id: {}", core.y, start_src_stick_id); + // log_debug("{} :: start_dst_stick_id: {}", core.y, start_dst_stick_id); + // log_debug("{} :: local_nsticks: {}", core.y, local_nsticks); + // log_debug("{} :: num_local_unpadded_nsticks: {}", core.y, num_local_unpadded_nsticks); + // log_debug("{} :: nbatch_per_core_h: {}", core.y, nbatch_per_core_h); + // log_debug("{} :: ncores_per_batch_h: {}", core.y, ncores_per_batch_h); + // } + vector writer_rt_args = reader_rt_args; + tt::tt_metal::SetRuntimeArgs(program, + reader_kernel_id, + core, + reader_rt_args); + tt::tt_metal::SetRuntimeArgs(program, + writer_kernel_id, + core, + writer_rt_args); + start_src_stick_wi += ntiles_per_core_w * TILE_WIDTH; + start_dst_stick_wi += ntiles_per_core_w * TILE_WIDTH; + } // for ncores_w + start_src_stick_id += num_local_unpadded_nsticks; + start_dst_stick_id += local_nsticks; + } // for ncores_h + } + + auto override_runtime_args_callback = [reader_kernel_id = reader_kernel_id, + writer_kernel_id = writer_kernel_id, + ncores_h = ncores_h, + ncores_w = ncores_w]( + const Program &program, + const std::vector &input_buffers, + const std::vector &output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + for (uint32_t j = 0; j < ncores_h; ++ j) { + for (uint32_t i = 0; i < ncores_w; ++ i) { + CoreCoord core = {i, j}; + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + { + auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + + +std::vector, std::vector > > get_runtime_args_rm(const Tensor &input_tensor, + Tensor &output_tensor, + uint32_t num_cores_total, + uint32_t num_cores, + uint32_t num_cores_y, + CoreRangeSet core_group_1, + uint32_t num_w_sticks_per_core_group_1, + CoreRangeSet core_group_2, + uint32_t num_w_sticks_per_core_group_2 + ){ + + auto input_buffer = input_tensor.buffer(); + auto output_buffer = output_tensor.buffer(); + auto input_shape = input_tensor.get_legacy_shape(); + auto output_shape = output_tensor.get_legacy_shape(); + + uint32_t W = input_shape[3], H = input_shape[2], C = input_shape[1], N = input_shape[0]; + uint32_t W_bytes = W * input_tensor.element_size(); + + uint32_t W_padded = output_shape[3], H_padded = output_shape[2], C_padded = output_shape[1], N_padded = output_shape[0]; + uint32_t W_padded_bytes = W_padded * input_tensor.element_size(); + + std::uint32_t num_dims = static_cast(input_shape.rank()); + std::vector start_dim_offset(num_dims, 0); + + + std::vector, std::vector > > ret_val(num_cores_total); + + + uint32_t max_read_size = 2048; + uint32_t curr_c = 0, curr_h = 0, curr_n = 0; + for(uint32_t i = 0, curr_sticks_read = 0, curr_sticks_write = 0; i < num_cores_total; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + + + uint32_t num_sticks_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_sticks_per_core = num_w_sticks_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_sticks_per_core = num_w_sticks_per_core_group_2; + } else { + //no-op + num_sticks_per_core = 0; + } + + + // issue more reads before calling barrier + uint32_t num_sticks_per_core_read = 0, num_read_per_barrier = 0; + if (num_sticks_per_core != 0) { + num_sticks_per_core_read = merge_num_sticks_to_read(num_sticks_per_core, W_bytes, max_read_size); + num_read_per_barrier = num_sticks_per_core / num_sticks_per_core_read; + } + + // reader + std::vector reader_runtime_args = { + input_buffer->address(), + num_sticks_per_core_read, + num_read_per_barrier, + curr_sticks_read, + }; + reader_runtime_args.insert(reader_runtime_args.end(), start_dim_offset.begin(), start_dim_offset.end()); + + // writer + std::vector writer_runtime_args = { + output_buffer->address(), + num_sticks_per_core_read, + num_read_per_barrier, + curr_sticks_write + }; + + ret_val[i] = {reader_runtime_args, writer_runtime_args}; + + curr_sticks_write += num_sticks_per_core; + + for (uint32_t i = 0; i < num_sticks_per_core; ++i) { + + if (curr_h < H and curr_c < C and curr_n < N) { + curr_sticks_read++; + } + + curr_h++; + if (curr_h == H_padded) { + curr_c++; + curr_h = 0; + if (curr_c == C_padded) { + curr_n++; + curr_c = 0; + } + } + } + + start_dim_offset = {0, curr_h, curr_c, curr_n}; + + } + + return ret_val; +} + +operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(const Tensor &a, + Tensor &output, + const tt::tt_metal::Shape &output_tensor_shape, + const tt::tt_metal::Shape &input_tensor_start, + const float pad_value) { + Program program{}; + + auto output_shape = output_tensor_shape; + uint32_t W = a.shape()[3], H = a.shape()[2], C = a.shape()[1], N = a.shape()[0]; + uint32_t NCH = H * C * N; + uint32_t W_padded = output_tensor_shape[3], H_padded = output_tensor_shape[2], C_padded = output_tensor_shape[1], N_padded = output_tensor_shape[0]; + uint32_t NCH_padded = H_padded * C_padded * N_padded; + + auto stick_size = W * a.element_size(); + auto stick_size_padded = W_padded * a.element_size(); + auto rem_stick_size_padded = stick_size_padded - stick_size; + uint32_t row_major_min_bytes = 16; + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + + Device *device = a.device(); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + CoreRange total_cores({0, 0}, {num_cores_x-1, num_cores_y-1}); + + auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_padded_per_core_group_1, num_sticks_padded_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, NCH_padded); + + uint32_t src0_cb_index = 0; + auto num_sticks = num_sticks_padded_per_core_group_1 > num_sticks_padded_per_core_group_2 ? num_sticks_padded_per_core_group_1 : num_sticks_padded_per_core_group_2; + + tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_sticks * stick_size_padded, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, stick_size_padded); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + + // construct const buffer with the pad_value + bool not_pad_by_zero = pad_value != 0; + if (not_pad_by_zero) { + uint32_t src1_cb_index = 1; + tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(row_major_min_bytes, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, row_major_min_bytes); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); + } + + Buffer *src0_buffer = a.buffer(); + Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); + + bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); + uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(stick_size) : 0; + bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size_padded); + uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(stick_size_padded) : 0; + std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, + (std::uint32_t) N, + (std::uint32_t) H, + (std::uint32_t) C, + (std::uint32_t) stick_size, + (std::uint32_t) N_padded, + (std::uint32_t) H_padded, + (std::uint32_t) C_padded, + (std::uint32_t) stick_size_padded, + (std::uint32_t) (stick_size_padded - stick_size), + (std::uint32_t) not_pad_by_zero, + (std::uint32_t) packed_pad_value, + (std::uint32_t) row_major_min_bytes, + (std::uint32_t) (rem_stick_size_padded / row_major_min_bytes), + (std::uint32_t) (stick_size_padded / row_major_min_bytes), + (std::uint32_t) src_stick_size_is_power_of_two, + (std::uint32_t) src_stick_size_is_power_of_two ? src_log2_stick_size : stick_size}; + std::vector writer_ct_args = {(std::uint32_t) src0_cb_index, + (std::uint32_t) dst_is_dram, + (std::uint32_t) stick_size_padded, + (std::uint32_t) dst_stick_size_is_power_of_two, + (std::uint32_t) dst_stick_size_is_power_of_two ? dst_log2_stick_size : stick_size_padded}; + + KernelHandle reader_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_interleaved_v2.cpp", + total_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); + KernelHandle writer_kernel_id = CreateKernel(program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_interleaved_v2.cpp", + total_cores, + tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); + + auto all_runtime_args = get_runtime_args_rm(a, output, num_cores_total, num_cores, num_cores_y, core_group_1, num_sticks_padded_per_core_group_1, core_group_2, num_sticks_padded_per_core_group_2); + + for(uint32_t i = 0; i < num_cores_total; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + tt::tt_metal::SetRuntimeArgs( + program, + reader_kernel_id, + core, + all_runtime_args[i].first + ); + + tt::tt_metal::SetRuntimeArgs( + program, + writer_kernel_id, + core, + all_runtime_args[i].second + + ); + } + + auto override_runtime_args_callback = [ + reader_kernel_id, + writer_kernel_id, + compute_with_storage_grid_size + ] + ( + const void* operation, + const Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors + ) { + auto src_tensor = input_tensors.at(0); + + auto dst_tensor = output_tensors.at(0); + + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + uint32_t num_cores_total = num_cores_x * num_cores_y; + + auto output_tensor_shape = dst_tensor.shape(); + uint32_t W_padded = output_tensor_shape[3], H_padded = output_tensor_shape[2], C_padded = output_tensor_shape[1], N_padded = output_tensor_shape[0]; + uint32_t NCH_padded = H_padded * C_padded * N_padded; + + auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_padded_per_core_group_1, num_sticks_padded_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, NCH_padded); + auto all_runtime_args = get_runtime_args_rm(src_tensor, dst_tensor, num_cores_total, num_cores, num_cores_y, core_group_1, num_sticks_padded_per_core_group_1, core_group_2, num_sticks_padded_per_core_group_2); + + for(uint32_t i = 0; i < num_cores_total; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + + { + SetRuntimeArgs(program, reader_kernel_id, core, all_runtime_args[i].first); + } + + { + SetRuntimeArgs(program, writer_kernel_id, core, all_runtime_args[i].second); + } + } + }; + + return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; +} + + + +} // namespace ttnn::operations::reduction::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp index 5320e1cb98a..f28241392fb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp @@ -2,14 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/tensor/host_buffer/functions.hpp" #include "tt_metal/host_api.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" -#include "tt_metal/host_api.hpp" -#include "tt_log.h" namespace ttnn::operations::data_movement::detail { @@ -18,1109 +11,33 @@ operation::ProgramWithCallbacks pad_rm_reader_writer(const Tensor &a, Tensor &output, const tt::tt_metal::Shape &output_tensor_shape, const tt::tt_metal::Shape &input_tensor_start, - const float pad_value) { - Program program{}; - - auto output_shape = output_tensor_shape; - - uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size(); - uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size(); // Assuming output is same datatype as input - TT_ASSERT(unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size"); - - // construct const buffer with the pad_value - Device *device = a.device(); - uint32_t pad_value_const_buffer_size = 32; // noc transfers in chunks of 32 - uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size(); - auto pad_value_const_buffer = tt::tt_metal::owned_buffer::create(std::vector(pad_value_const_buffer_size, bfloat16(pad_value))); - const Tensor pad_value_const_tensor = - Tensor( - OwnedStorage{pad_value_const_buffer}, - Shape(std::array{1, 1, 1, pad_value_const_buffer_size}), - DataType::BFLOAT16, - Layout::ROW_MAJOR) - .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); - auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); - - CoreRange cores({0, 0}, {0, 0}); - uint32_t cb_id = tt::CB::c_in0; - uint32_t cb_npages = 16; // multibuffering - uint32_t cb_pagesize = tt::round_up(padded_row_size_nbytes, tt::constants::TILE_WIDTH); - tt::DataFormat in_df = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::tt_metal::CircularBufferConfig cb_config = tt::tt_metal::CircularBufferConfig(cb_npages * cb_pagesize, {{cb_id, in_df}}) - .set_page_size(cb_id, cb_pagesize); - auto cb = tt::tt_metal::CreateCircularBuffer(program, cores, cb_config); - - - Buffer *src0_buffer = a.buffer(); - Buffer *dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes); - uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(unpadded_row_size_nbytes) : 0; - bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes); - uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(padded_row_size_nbytes) : 0; - std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, - (std::uint32_t) dst_is_dram, - (std::uint32_t) src_stick_size_is_power_of_two, - (std::uint32_t) src_log2_stick_size, - (std::uint32_t) dst_stick_size_is_power_of_two, - (std::uint32_t) dst_log2_stick_size}; - std::vector writer_ct_args = reader_ct_args; - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - bfloat16 bfloat_zero = bfloat16(0.0f); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value}); - - KernelHandle reader_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_interleaved.cpp", - cores, - tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); - KernelHandle writer_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_interleaved.cpp", - cores, - tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); - uint32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes; + const float pad_value); - #if 0 - { - log_debug("src0_buffer_addr: {}", src0_buffer->address()); - log_debug("dst_buffer_addr: {}", dst_buffer->address()); - log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]); - log_debug("out.shape[0]: {}", output_shape[0]); - log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]); - log_debug("out.shape[1]: {}", output_shape[1]); - log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]); - log_debug("out.shape[2]: {}", output_shape[2]); - log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]); - log_debug("out.shape[3]: {}", output_shape[3]); - log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes); - log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes); - log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes); - log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr); - log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes); - log_debug("packed_pad_value: {}", packed_pad_value); - } - #endif - - uint32_t start_src_stick_id = 0; - uint32_t start_dst_stick_id = 0; - vector reader_rt_args = {src0_buffer->address(), - dst_buffer->address(), - a.get_legacy_shape()[0], - output_shape[0], - a.get_legacy_shape()[1], - output_shape[1], - a.get_legacy_shape()[2], - output_shape[2], - a.get_legacy_shape()[3], - output_shape[3], - unpadded_row_size_nbytes, - padded_row_size_nbytes, - padded_row_diff_size_nbytes, - pad_value_const_tensor_addr, - pad_value_const_buffer_nbytes, - packed_pad_value, - start_src_stick_id, - start_dst_stick_id, - 0, - 0, - 0, - output_shape[2], - a.get_legacy_shape()[2], - unpadded_row_size_nbytes, - padded_row_size_nbytes, - 0, - output.get_legacy_shape()[0] - }; - vector writer_rt_args = reader_rt_args; - tt::tt_metal::SetRuntimeArgs(program, - reader_kernel_id, - cores, - reader_rt_args); - tt::tt_metal::SetRuntimeArgs(program, - writer_kernel_id, - cores, - writer_rt_args); - - auto override_runtime_args_callback = - [reader_kernel_id=reader_kernel_id, writer_kernel_id=writer_kernel_id]( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - CoreCoord core = {0, 0}; - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); - } - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} operation::ProgramWithCallbacks pad_rm_opt(const Tensor &a, Tensor &output, const Shape &output_tensor_shape, const Shape &input_tensor_start, - const float pad_value) { - Program program{}; - - auto output_shape = output_tensor_shape; - - uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size(); - uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size(); // Assuming output is same datatype as input - TT_ASSERT(unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size"); - - Device *device = a.device(); - auto dst_buffer_l1 = Buffer(device, padded_row_size_nbytes, padded_row_size_nbytes, BufferType::L1); - - // construct const buffer with the pad_value - uint32_t pad_value_const_buffer_size = 32; // noc transfers in chunks of 32 - uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size(); - auto pad_value_const_buffer = owned_buffer::create(std::vector(pad_value_const_buffer_size, bfloat16(pad_value))); - const Tensor pad_value_const_tensor = - Tensor( - OwnedStorage{pad_value_const_buffer}, - Shape(std::array{1, 1, 1, pad_value_const_buffer_size}), - DataType::BFLOAT16, - Layout::ROW_MAJOR) - .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); - auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); - - Buffer *src0_buffer = a.buffer(); - Buffer *dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes); - uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(unpadded_row_size_nbytes) : 0; - bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes); - uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(padded_row_size_nbytes) : 0; - std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, - (std::uint32_t) dst_is_dram, - (std::uint32_t) src_stick_size_is_power_of_two, - (std::uint32_t) src_log2_stick_size, - (std::uint32_t) dst_stick_size_is_power_of_two, - (std::uint32_t) dst_log2_stick_size}; - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - bfloat16 bfloat_zero = bfloat16(0.0f); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value}); - - CoreRange core({0, 0}, {0, 0}); - KernelHandle reader_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/pad_dims_rm_interleaved_opt.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); - uint32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes; - - #if 0 - { - tt::log_debug("src0_buffer_addr: {}", src0_buffer->address()); - tt::log_debug("dst_buffer_addr: {}", dst_buffer->address()); - tt::log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]); - tt::log_debug("out.shape[0]: {}", output_shape[0]); - tt::log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]); - tt::log_debug("out.shape[1]: {}", output_shape[1]); - tt::log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]); - tt::log_debug("out.shape[2]: {}", output_shape[2]); - tt::log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]); - tt::log_debug("out.shape[3]: {}", output_shape[3]); - tt::log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes); - tt::log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes); - tt::log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes); - tt::log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr); - tt::log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes); - tt::log_debug("packed_pad_value: {}", packed_pad_value); - tt::log_debug("dst_buffer_l1_addr: {}", dst_buffer_l1.address()); - } - #endif - - vector reader_rt_args = {src0_buffer->address(), - dst_buffer->address(), - a.get_legacy_shape()[0], - output_shape[0], - a.get_legacy_shape()[1], - output_shape[1], - a.get_legacy_shape()[2], - output_shape[2], - a.get_legacy_shape()[3], - output_shape[3], - unpadded_row_size_nbytes, - padded_row_size_nbytes, - padded_row_diff_size_nbytes, - pad_value_const_tensor_addr, - pad_value_const_buffer_nbytes, - packed_pad_value, - dst_buffer_l1.address()}; - tt::tt_metal::SetRuntimeArgs(program, - reader_kernel_id, - core, - reader_rt_args); - - auto override_runtime_args_callback = [kernel_id=reader_kernel_id](const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - CoreCoord core = {0, 0}; - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks pad_rm(const Tensor &a, Tensor &output, const Shape &output_tensor_shape, const Shape &input_tensor_start, const float pad_value) { - - tt::tt_metal::Program program{}; - - CoreRange core({0, 0}, {0, 0}); - - // This should allocate a DRAM buffer on the device - tt::tt_metal::Device *device = a.device(); - - auto output_shape = output_tensor_shape; - - tt::tt_metal::Buffer *src0_buffer = a.buffer(); - - uint32_t unpadded_row_size_bytes = a.get_legacy_shape()[3] * a.element_size(); - uint32_t padded_row_size_bytes = output_shape[3] * a.element_size(); - - tt::tt_metal::Buffer *dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - uint32_t src_stick_size = unpadded_row_size_bytes; - uint32_t dst_stick_size = padded_row_size_bytes; - - uint32_t dst_buffer_size = dst_stick_size; - - tt::tt_metal::InterleavedBufferConfig buff_config{ - .device= device, - .size = dst_buffer_size, - .page_size = dst_buffer_size, - .buffer_type = tt::tt_metal::BufferType::L1 - }; - - auto dst_buffer_l1 = tt::tt_metal::CreateBuffer(buff_config); - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); - - vector reader_kernel_args = { - src0_buffer->address(), - dst_buffer->address(), - a.get_legacy_shape()[0], - output_shape[0], - a.get_legacy_shape()[1], - output_shape[1], - a.get_legacy_shape()[2], - output_shape[2], - a.get_legacy_shape()[3], - output_shape[3], - unpadded_row_size_bytes, - padded_row_size_bytes, - padded_row_size_bytes - unpadded_row_size_bytes, - packed_pad_value, - dst_buffer_l1->address() - }; - bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool src_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(src_stick_size); - uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(src_stick_size) : 0; - bool dst_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(dst_stick_size); - uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(dst_stick_size) : 0; - std::vector compile_time_args_vec = { - (std::uint32_t) src0_is_dram, - (std::uint32_t) dst_is_dram, - (std::uint32_t) src_stick_size_is_power_of_two, - (std::uint32_t) src_log2_stick_size, - (std::uint32_t) dst_stick_size_is_power_of_two, - (std::uint32_t) dst_log2_stick_size, - - }; - - // Tilized reader - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/pad_dims_rm_interleaved.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(compile_time_args_vec)); - - tt::tt_metal::SetRuntimeArgs( - program, - unary_reader_kernel_id, - core, - reader_kernel_args - ); - - auto override_runtime_args_callback = [kernel_id=unary_reader_kernel_id]( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks pad_tile(const Tensor &a, Tensor& output, const tt::tt_metal::Shape &output_tensor_shape, const tt::tt_metal::Shape &input_tensor_start, const float pad_value) { - - tt::tt_metal::Program program{}; - - CoreRange core({0, 0}, {0, 0}); - - // This should allocate a DRAM buffer on the device - tt::tt_metal::Device *device = a.device(); - - auto output_shape = output_tensor_shape; - - tt::tt_metal::Buffer *src0_buffer = a.buffer(); - - tt::tt_metal::Buffer *dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format); - - tt::log_debug("pad_tile"); - tt::log_debug("cb_data_format: {}", cb_data_format); - tt::log_debug("single_tile_size: {}", single_tile_size); - tt::log_debug("output_tensor_shape: {}", output_tensor_shape); - tt::log_debug("input_tensor_start: {}", input_tensor_start); - tt::log_debug("pad_value: {}", pad_value); - - uint32_t src0_cb_index = 0; - uint32_t num_input_tiles = 2; - tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, single_tile_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); - - uint32_t src1_cb_index = 1; // For pad buffer - uint32_t num_pad_tiles = 1; - tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(num_pad_tiles * single_tile_size, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, single_tile_size); - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); - - uint32_t num_unpadded_Xt = a.get_legacy_shape()[3] / TILE_WIDTH; - uint32_t num_total_Xt = output_shape[3] / TILE_WIDTH; - uint32_t num_padded_Xt = num_total_Xt - num_unpadded_Xt; - uint32_t num_unpadded_Yt = a.get_legacy_shape()[2] / TILE_HEIGHT; - uint32_t num_total_Yt = output_shape[2] / TILE_HEIGHT; - uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; - uint32_t num_unpadded_Z = a.get_legacy_shape()[1]; - uint32_t num_total_Z = output_shape[1]; - uint32_t num_padded_Zt = (num_total_Z - num_unpadded_Z) * num_total_Yt * num_total_Xt; - uint32_t num_unpadded_W = a.get_legacy_shape()[0]; - uint32_t num_total_W = output_shape[0]; - uint32_t num_padded_Wt = (num_total_W - num_unpadded_W) * num_total_Z * num_total_Yt * num_total_Xt; - - uint32_t num_unpadded_tiles = a.volume() / TILE_HW; - - vector reader_kernel_args = { - src0_buffer->address(), - num_unpadded_tiles, 0 - }; - vector writer_kernel_args = { - dst_buffer->address(), - num_unpadded_W, - num_padded_Wt, - num_unpadded_Z, - num_padded_Zt, - num_unpadded_Yt, - num_padded_Yt, - num_unpadded_Xt, - num_padded_Xt, - packed_pad_value, - }; - - // Reader compile-time args - // Data is 32 byte aligned - bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector reader_compile_time_args = { - // interleaved accessor args - (std::uint32_t) src0_is_dram - }; - std::vector writer_compile_time_args = { - // interleaved accessor args - (std::uint32_t) src0_cb_index, - (std::uint32_t) src1_cb_index, - (std::uint32_t) dst_is_dram - }; - // Tilized reader - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_unary_pad_dims_interleaved.cpp", - core, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - tt::tt_metal::SetRuntimeArgs( - program, - unary_reader_kernel_id, - core, - reader_kernel_args - ); - - tt::tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - writer_kernel_args - ); - - auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - - auto src_dram_buffer = input_buffers.at(0); - - auto dst_dram_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - - -inline void log_rt_args(const CoreCoord& core, vector& args) { - for (auto v : args) { - tt::log_debug(tt::LogOp, "{},{} :: {}", core.x, core.y, v); - } -} - -// This is currently mostly hardcoded for resnet shapes -inline std::tuple - split_across_cores(CoreCoord grid_size, uint32_t nbatch, uint32_t nchannel, uint32_t ntiles_h, uint32_t ntiles_w) { - - uint32_t ncores, ncores_h, ncores_w, ntiles_per_core_h, ntiles_per_core_w, nbatch_per_core_h, ncores_per_batch_h; + const float pad_value); - ncores_h = 1; +operation::ProgramWithCallbacks pad_rm(const Tensor &a, Tensor &output, const Shape &output_tensor_shape, const Shape &input_tensor_start, const float pad_value); - // each batch needs to be padded independently - switch (nbatch) { - case 1: - ncores_h = 1; - nbatch_per_core_h = 1; - ntiles_per_core_h = 1; - switch (ntiles_h) { - case 2: ncores_h = 2; ntiles_per_core_h = 1; break; - case 4: ncores_h = 4; ntiles_per_core_h = 1; break; - case 8: ncores_h = 8; ntiles_per_core_h = 1; break; - case 64: ncores_h = 8; ntiles_per_core_h = 8; break; - } - ncores_per_batch_h = ncores_h; - break; - - case 2: - ncores_h = 1; - ncores_per_batch_h = 1; - nbatch_per_core_h = 1; - ntiles_per_core_h = 1; - switch (ntiles_h) { - case 2: ncores_per_batch_h = 2; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 1; break; - case 4: ncores_per_batch_h = 4; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 1; break; - case 8: ncores_per_batch_h = 4; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 2; break; - case 64: ncores_per_batch_h = 4; ncores_h = ncores_per_batch_h * nbatch; ntiles_per_core_h = 16; break; - } - break; - - case 8: - ncores_h = 8; - ncores_per_batch_h = 1; - nbatch_per_core_h = 1; - ntiles_per_core_h = ntiles_h; - break; - - default: - TT_ASSERT(false, "unhandled nbatch. TODO"); - - // generic case -- TODO - - // one of the following will be 0 when grid_size.y != nbatch - uint32_t nbatch_per_core_h = nbatch / grid_size.y; // floor - uint32_t ncores_per_batch_h = grid_size.y / nbatch; // floor - if (nbatch == grid_size.y) { - nbatch_per_core_h = 1; - ncores_per_batch_h = 1; - } - - // currently uses hardcoded values for resnet50 - // TT_ASSERT(ntiles_h == 1 || ntiles_h == 2 || ntiles_h == 4 || ntiles_h == 16, "Only Resnet50 shapes are supported in multicore version for now."); - // TT_ASSERT(ntiles_w == 64, "Only Resnet50 shapes are supported in multicore version for now."); - - TT_ASSERT(nbatch <= grid_size.y, "Unsupported case with nbatch > grid_size.y!"); - - uint32_t ncores_h = 1; - uint32_t ntiles_per_core_h = ntiles_h / ncores_h; - if (nbatch_per_core_h == 0) { - // there are multiple cores along h per batch - nbatch_per_core_h = 1; - ncores_h = ncores_per_batch_h * nbatch; - ntiles_per_core_h = ntiles_h / ncores_per_batch_h; - } else if (ncores_per_batch_h == 0) { - // unsupported case. TODO. - TT_ASSERT(false); - // there are multiple batch per core along h - // ncores_per_batch_h = 1; - // ncores_h = (uint32_t) ceil((float) nbatch / nbatch_per_core_h); - // ntiles_per_core_h = nbatch_per_core_h * ntiles_h; - } else { - TT_ASSERT("Something went terribly wrong in splitting acrtoss cores"); - } - break; - } - - ncores_w = 1; - switch (ntiles_w) { - case 2: ncores_w = 2; break; - case 4: ncores_w = 4; break; - case 8: ncores_w = 8; break; - case 64: ncores_w = 8; break; - } - ncores = ncores_h * ncores_w; - ntiles_per_core_w = ntiles_w / ncores_w; - std::set all_cores; - std::set core_range; - - all_cores.insert(CoreRange(CoreCoord(0, 0), CoreCoord(ncores_w - 1, ncores_h - 1))); - core_range.insert(CoreRange(CoreCoord(0, 0), CoreCoord(ncores_w - 1, ncores_h - 1))); - - return std::make_tuple(ncores, ncores_h, ncores_w, all_cores, core_range, ntiles_per_core_h, ntiles_per_core_w, nbatch_per_core_h, ncores_per_batch_h); -} +operation::ProgramWithCallbacks pad_tile(const Tensor &a, Tensor& output, const tt::tt_metal::Shape &output_tensor_shape, const tt::tt_metal::Shape &input_tensor_start, const float pad_value); operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(const Tensor &a, Tensor &output, const tt::tt_metal::Shape &output_tensor_shape, const tt::tt_metal::Shape &input_tensor_start, - const float pad_value) { - Program program{}; - - auto output_shape = output_tensor_shape; - - uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size(); - uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size(); // Assuming output is same datatype as input - TT_ASSERT(unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size"); - - Device *device = a.device(); - - // construct const buffer with the pad_value - uint32_t pad_value_const_buffer_size = 32; // noc transfers in chunks of 32 - uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size(); - auto pad_value_const_buffer = owned_buffer::create(std::vector(pad_value_const_buffer_size, bfloat16(pad_value))); - // NOTE: The const buffer is always in L1 - // TODO: make a local buffer for each core? - const Tensor pad_value_const_tensor = - Tensor( - OwnedStorage{pad_value_const_buffer}, - Shape(std::array{1, 1, 1, pad_value_const_buffer_size}), - DataType::BFLOAT16, - Layout::ROW_MAJOR) - .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); - auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); - - // uint32_t ntiles_h = output_tensor_shape[0] * output_tensor_shape[1] * output_tensor_shape[2] / TILE_HEIGHT; - uint32_t ntiles_h = output_tensor_shape[2] / TILE_HEIGHT; - uint32_t ntiles_w = output_tensor_shape[3] / TILE_WIDTH; - - auto grid_size = device->compute_with_storage_grid_size(); - uint32_t nbatch = output_tensor_shape[0]; - uint32_t nchannel = output_tensor_shape[1]; - // first the batch dim is distributed along H, and within each batch then the tiles are distributed. - auto [ncores, ncores_h, ncores_w, all_cores, core_range, ntiles_per_core_h, ntiles_per_core_w, nbatch_per_core_h, ncores_per_batch_h] = split_across_cores(grid_size, nbatch, nchannel, ntiles_h, ntiles_w); - - int32_t src_nbytes_per_core_w = ntiles_per_core_w * TILE_WIDTH * a.element_size(); - int32_t dst_nbytes_per_core_w = ntiles_per_core_w * TILE_WIDTH * output.element_size(); - - uint32_t cb_id = tt::CB::c_in0; - uint32_t cb_npages = 16; // multibuffering for perf - // uint32_t cb_npages = 1; // multibuffering for perf - uint32_t cb_pagesize = (uint32_t) ceil((float) dst_nbytes_per_core_w / tt::constants::TILE_WIDTH) * tt::constants::TILE_WIDTH; - tt::DataFormat in_df = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::tt_metal::CircularBufferConfig cb_config = tt::tt_metal::CircularBufferConfig(cb_npages * cb_pagesize, {{cb_id, in_df}}) - .set_page_size(cb_id, cb_pagesize); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_config); - - Buffer *src0_buffer = a.buffer(); - Buffer *dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes); - uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(unpadded_row_size_nbytes) : 0; - bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes); - uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(padded_row_size_nbytes) : 0; - std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, - (std::uint32_t) dst_is_dram, - (std::uint32_t) src_stick_size_is_power_of_two, - (std::uint32_t) src_log2_stick_size, - (std::uint32_t) dst_stick_size_is_power_of_two, - (std::uint32_t) dst_log2_stick_size}; - std::vector writer_ct_args = reader_ct_args; - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - bfloat16 bfloat_zero = bfloat16(0.0f); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value}); - - KernelHandle reader_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_interleaved.cpp", - all_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); - KernelHandle writer_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_interleaved.cpp", - all_cores, - tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); - // int32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes; - log_rt_args(CoreCoord{0, 0}, reader_ct_args); - - #if 1 - { - tt::log_debug("ncores: {}", ncores); - tt::log_debug("ncores_h: {}", ncores_h); - tt::log_debug("ncores_w: {}", ncores_w); - tt::log_debug("ntiles_per_core_h: {}", ntiles_per_core_h); - tt::log_debug("ntiles_per_core_w: {}", ntiles_per_core_w); - tt::log_debug("src0_buffer_addr: {}", src0_buffer->address()); - tt::log_debug("dst_buffer_addr: {}", dst_buffer->address()); - tt::log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]); - tt::log_debug("out.shape[0]: {}", output_shape[0]); - tt::log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]); - tt::log_debug("out.shape[1]: {}", output_shape[1]); - tt::log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]); - tt::log_debug("out.shape[2]: {}", output_shape[2]); - tt::log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]); - tt::log_debug("out.shape[3]: {}", output_shape[3]); - tt::log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes); - tt::log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes); - // tt::log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes); - tt::log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr); - tt::log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes); - tt::log_debug("packed_pad_value: {}", packed_pad_value); - tt::log_debug("src_nbytes_per_core_w: {}", src_nbytes_per_core_w); - tt::log_debug("dst_nbytes_per_core_w: {}", dst_nbytes_per_core_w); - tt::log_debug("nbatch_per_core_h: {}", nbatch_per_core_h); - tt::log_debug("ncores_per_batch_h: {}", ncores_per_batch_h); - } - #endif - - uint32_t start_src_stick_id = 0; - uint32_t start_dst_stick_id = 0; - uint32_t start_src_stick_wi = 0; // start of stick segment for 2d decomp - uint32_t start_dst_stick_wi = 0; - int32_t local_nsticks = ntiles_per_core_h * TILE_HEIGHT; - int32_t rem_nbatch = nbatch; // per core h, there are ncores_per_batch_h cores, ie each batch ncores_h = ncores_per_batch_h - for (int32_t b = 0; b < nbatch; ++ b) { - int32_t rem_src_nsticks = a.get_legacy_shape()[2]; - for (uint32_t j = 0; j < ncores_per_batch_h; ++ j) { - uint32_t num_local_unpadded_nsticks = local_nsticks; - if (rem_src_nsticks - local_nsticks >= 0) { - // not reached padding sticks yet - rem_src_nsticks -= local_nsticks; - } else { - num_local_unpadded_nsticks = rem_src_nsticks; - rem_src_nsticks = 0; - } - start_src_stick_wi = 0; - start_dst_stick_wi = 0; - int32_t rem_src_stick_size_nbytes = unpadded_row_size_nbytes; - for (uint32_t i = 0; i < ncores_w; ++ i) { - CoreCoord core = {i, b * ncores_per_batch_h + j}; - uint32_t curr_stick_size_nbytes = 0; - int32_t curr_stick_diff_nbytes = 0; - if (rem_src_stick_size_nbytes - dst_nbytes_per_core_w >= 0) { - // no padding on this core - curr_stick_size_nbytes = dst_nbytes_per_core_w; - rem_src_stick_size_nbytes -= dst_nbytes_per_core_w; - } else { - // this core has padding - curr_stick_size_nbytes = rem_src_stick_size_nbytes; - curr_stick_diff_nbytes = dst_nbytes_per_core_w - curr_stick_size_nbytes; - rem_src_stick_size_nbytes = 0; - } - vector reader_rt_args = {src0_buffer->address(), - dst_buffer->address(), - a.get_legacy_shape()[0], - output_shape[0], - a.get_legacy_shape()[1], - output_shape[1], - a.get_legacy_shape()[2], - output_shape[2], - a.get_legacy_shape()[3], - output_shape[3], - curr_stick_size_nbytes, - (uint32_t) dst_nbytes_per_core_w, - (uint32_t) curr_stick_diff_nbytes, - pad_value_const_tensor_addr, - pad_value_const_buffer_nbytes, - packed_pad_value, - start_src_stick_id, - start_dst_stick_id, - start_src_stick_wi, - start_dst_stick_wi, - start_src_stick_wi * a.element_size(), - (uint32_t) local_nsticks, - num_local_unpadded_nsticks, - unpadded_row_size_nbytes, - padded_row_size_nbytes, - start_dst_stick_wi * output.element_size(), - nbatch_per_core_h - }; - // if (core.x == 0) log_rt_args(core, reader_rt_args); - // if (core.x == 0) { - // log_debug("{} :: start_src_stick_id: {}", core.y, start_src_stick_id); - // log_debug("{} :: start_dst_stick_id: {}", core.y, start_dst_stick_id); - // log_debug("{} :: local_nsticks: {}", core.y, local_nsticks); - // log_debug("{} :: num_local_unpadded_nsticks: {}", core.y, num_local_unpadded_nsticks); - // log_debug("{} :: nbatch_per_core_h: {}", core.y, nbatch_per_core_h); - // log_debug("{} :: ncores_per_batch_h: {}", core.y, ncores_per_batch_h); - // } - vector writer_rt_args = reader_rt_args; - tt::tt_metal::SetRuntimeArgs(program, - reader_kernel_id, - core, - reader_rt_args); - tt::tt_metal::SetRuntimeArgs(program, - writer_kernel_id, - core, - writer_rt_args); - start_src_stick_wi += ntiles_per_core_w * TILE_WIDTH; - start_dst_stick_wi += ntiles_per_core_w * TILE_WIDTH; - } // for ncores_w - start_src_stick_id += num_local_unpadded_nsticks; - start_dst_stick_id += local_nsticks; - } // for ncores_h - } - - auto override_runtime_args_callback = [reader_kernel_id = reader_kernel_id, - writer_kernel_id = writer_kernel_id, - ncores_h = ncores_h, - ncores_w = ncores_w]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - for (uint32_t j = 0; j < ncores_h; ++ j) { - for (uint32_t i = 0; i < ncores_w; ++ i) { - CoreCoord core = {i, j}; - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); - } - { - auto &runtime_args = tt::tt_metal::GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = dst_buffer->address(); - } - } - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - - -std::vector, std::vector > > get_runtime_args_rm(const Tensor &input_tensor, - Tensor &output_tensor, - uint32_t num_cores_total, - uint32_t num_cores, - uint32_t num_cores_y, - CoreRangeSet core_group_1, - uint32_t num_w_sticks_per_core_group_1, - CoreRangeSet core_group_2, - uint32_t num_w_sticks_per_core_group_2 - ){ - - auto input_buffer = input_tensor.buffer(); - auto output_buffer = output_tensor.buffer(); - auto input_shape = input_tensor.get_legacy_shape(); - auto output_shape = output_tensor.get_legacy_shape(); - - uint32_t W = input_shape[3], H = input_shape[2], C = input_shape[1], N = input_shape[0]; - uint32_t W_bytes = W * input_tensor.element_size(); - - uint32_t W_padded = output_shape[3], H_padded = output_shape[2], C_padded = output_shape[1], N_padded = output_shape[0]; - uint32_t W_padded_bytes = W_padded * input_tensor.element_size(); - - std::uint32_t num_dims = static_cast(input_shape.rank()); - std::vector start_dim_offset(num_dims, 0); - - - std::vector, std::vector > > ret_val(num_cores_total); - - - uint32_t max_read_size = 2048; - uint32_t curr_c = 0, curr_h = 0, curr_n = 0; - for(uint32_t i = 0, curr_sticks_read = 0, curr_sticks_write = 0; i < num_cores_total; i++) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; + const float pad_value); - uint32_t num_sticks_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_sticks_per_core = num_w_sticks_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_sticks_per_core = num_w_sticks_per_core_group_2; - } else { - //no-op - num_sticks_per_core = 0; - } - - - // issue more reads before calling barrier - uint32_t num_sticks_per_core_read = 0, num_read_per_barrier = 0; - if (num_sticks_per_core != 0) { - num_sticks_per_core_read = merge_num_sticks_to_read(num_sticks_per_core, W_bytes, max_read_size); - num_read_per_barrier = num_sticks_per_core / num_sticks_per_core_read; - } - - // reader - std::vector reader_runtime_args = { - input_buffer->address(), - num_sticks_per_core_read, - num_read_per_barrier, - curr_sticks_read, - }; - reader_runtime_args.insert(reader_runtime_args.end(), start_dim_offset.begin(), start_dim_offset.end()); - - // writer - std::vector writer_runtime_args = { - output_buffer->address(), - num_sticks_per_core_read, - num_read_per_barrier, - curr_sticks_write - }; - - ret_val[i] = {reader_runtime_args, writer_runtime_args}; - - curr_sticks_write += num_sticks_per_core; - - for (uint32_t i = 0; i < num_sticks_per_core; ++i) { - - if (curr_h < H and curr_c < C and curr_n < N) { - curr_sticks_read++; - } - - curr_h++; - if (curr_h == H_padded) { - curr_c++; - curr_h = 0; - if (curr_c == C_padded) { - curr_n++; - curr_c = 0; - } - } - } - - start_dim_offset = {0, curr_h, curr_c, curr_n}; - - } - - return ret_val; -} operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(const Tensor &a, Tensor &output, const tt::tt_metal::Shape &output_tensor_shape, const tt::tt_metal::Shape &input_tensor_start, - const float pad_value) { - Program program{}; - - auto output_shape = output_tensor_shape; - uint32_t W = a.shape()[3], H = a.shape()[2], C = a.shape()[1], N = a.shape()[0]; - uint32_t NCH = H * C * N; - uint32_t W_padded = output_tensor_shape[3], H_padded = output_tensor_shape[2], C_padded = output_tensor_shape[1], N_padded = output_tensor_shape[0]; - uint32_t NCH_padded = H_padded * C_padded * N_padded; - - auto stick_size = W * a.element_size(); - auto stick_size_padded = W_padded * a.element_size(); - auto rem_stick_size_padded = stick_size_padded - stick_size; - uint32_t row_major_min_bytes = 16; - - tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - - Device *device = a.device(); - - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores_total = num_cores_x * num_cores_y; - CoreRange total_cores({0, 0}, {num_cores_x-1, num_cores_y-1}); - - auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_padded_per_core_group_1, num_sticks_padded_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, NCH_padded); - - uint32_t src0_cb_index = 0; - auto num_sticks = num_sticks_padded_per_core_group_1 > num_sticks_padded_per_core_group_2 ? num_sticks_padded_per_core_group_1 : num_sticks_padded_per_core_group_2; - - tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_sticks * stick_size_padded, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, stick_size_padded); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); - - // construct const buffer with the pad_value - bool not_pad_by_zero = pad_value != 0; - if (not_pad_by_zero) { - uint32_t src1_cb_index = 1; - tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(row_major_min_bytes, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, row_major_min_bytes); - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); - } - - Buffer *src0_buffer = a.buffer(); - Buffer *dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); - - bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); - uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t) std::log2(stick_size) : 0; - bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size_padded); - uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t) std::log2(stick_size_padded) : 0; - std::vector reader_ct_args = {(std::uint32_t) src0_is_dram, - (std::uint32_t) N, - (std::uint32_t) H, - (std::uint32_t) C, - (std::uint32_t) stick_size, - (std::uint32_t) N_padded, - (std::uint32_t) H_padded, - (std::uint32_t) C_padded, - (std::uint32_t) stick_size_padded, - (std::uint32_t) (stick_size_padded - stick_size), - (std::uint32_t) not_pad_by_zero, - (std::uint32_t) packed_pad_value, - (std::uint32_t) row_major_min_bytes, - (std::uint32_t) (rem_stick_size_padded / row_major_min_bytes), - (std::uint32_t) (stick_size_padded / row_major_min_bytes), - (std::uint32_t) src_stick_size_is_power_of_two, - (std::uint32_t) src_stick_size_is_power_of_two ? src_log2_stick_size : stick_size}; - std::vector writer_ct_args = {(std::uint32_t) src0_cb_index, - (std::uint32_t) dst_is_dram, - (std::uint32_t) stick_size_padded, - (std::uint32_t) dst_stick_size_is_power_of_two, - (std::uint32_t) dst_stick_size_is_power_of_two ? dst_log2_stick_size : stick_size_padded}; - - KernelHandle reader_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_interleaved_v2.cpp", - total_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); - KernelHandle writer_kernel_id = CreateKernel(program, - "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_interleaved_v2.cpp", - total_cores, - tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); - - auto all_runtime_args = get_runtime_args_rm(a, output, num_cores_total, num_cores, num_cores_y, core_group_1, num_sticks_padded_per_core_group_1, core_group_2, num_sticks_padded_per_core_group_2); - - for(uint32_t i = 0; i < num_cores_total; i++) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - tt::tt_metal::SetRuntimeArgs( - program, - reader_kernel_id, - core, - all_runtime_args[i].first - ); - - tt::tt_metal::SetRuntimeArgs( - program, - writer_kernel_id, - core, - all_runtime_args[i].second - - ); - } - - auto override_runtime_args_callback = [ - reader_kernel_id, - writer_kernel_id, - compute_with_storage_grid_size - ] - ( - const void* operation, - const Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors - ) { - auto src_tensor = input_tensors.at(0); - - auto dst_tensor = output_tensors.at(0); - - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - - uint32_t num_cores_total = num_cores_x * num_cores_y; - - auto output_tensor_shape = dst_tensor.shape(); - uint32_t W_padded = output_tensor_shape[3], H_padded = output_tensor_shape[2], C_padded = output_tensor_shape[1], N_padded = output_tensor_shape[0]; - uint32_t NCH_padded = H_padded * C_padded * N_padded; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_padded_per_core_group_1, num_sticks_padded_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, NCH_padded); - auto all_runtime_args = get_runtime_args_rm(src_tensor, dst_tensor, num_cores_total, num_cores, num_cores_y, core_group_1, num_sticks_padded_per_core_group_1, core_group_2, num_sticks_padded_per_core_group_2); - - for(uint32_t i = 0; i < num_cores_total; i++) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - - { - SetRuntimeArgs(program, reader_kernel_id, core, all_runtime_args[i].first); - } - - { - SetRuntimeArgs(program, writer_kernel_id, core, all_runtime_args[i].second); - } - } - }; + const float pad_value); - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp index 18f310fa4e7..2f109e2ced5 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp @@ -120,7 +120,7 @@ operation::ProgramWithCallbacks Slice::create_program( const auto &input_tensor_a = input_tensors.at(0); auto &output_tensor = output_tensors.at(0); - return slice_multi_core(input_tensor_a, output_tensor, this->slice_start, this->slice_end); + return detail::slice_multi_core(input_tensor_a, output_tensor, this->slice_start, this->slice_end); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp new file mode 100644 index 00000000000..0c2c44a1ea3 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp @@ -0,0 +1,545 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "optional" +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +#include "slice_op.hpp" +using namespace tt::constants; + + +namespace ttnn::operations::data_movement::detail { + +inline std::vector, std::vector>> get_slice_runtime_args_rm( + const Tensor& input_tensor, + Tensor& output_tensor, + const tt::tt_metal::Shape& output_tensor_start, + uint32_t num_cores_total, + uint32_t num_cores, + uint32_t num_cores_y, + CoreRangeSet core_group_1, + CoreRangeSet core_group_2, + uint32_t num_sticks_per_core_group_1, + uint32_t num_sticks_per_core_group_2, + uint32_t max_read_size) { + tt::tt_metal::Device* device = input_tensor.device(); + + auto input_buffer = input_tensor.buffer(); + auto output_buffer = output_tensor.buffer(); + auto input_shape = input_tensor.get_legacy_shape(); + auto output_shape = output_tensor.get_legacy_shape(); + + uint32_t padded_row_size_bytes = input_shape[-1] * input_tensor.element_size(); + uint32_t unpadded_row_size_bytes = output_shape[-1] * input_tensor.element_size(); + + std::uint32_t num_dims = static_cast(input_shape.rank()); + std::vector num_unpadded_sticks_per_dim(num_dims); + std::vector num_padded_sticks_per_dim(num_dims); + std::vector id_per_dim(num_dims); + + std::vector accumulated_total_per_dim(num_dims); + + // TODO: Remove first element of these arrays and update kernel accordingly + // This currently just matches tile version where we iterate over the row as well + num_unpadded_sticks_per_dim[0] = 1; + num_padded_sticks_per_dim[0] = 0; + accumulated_total_per_dim[0] = 1; + + for (int32_t i = 1; i < num_dims; i++) { + uint32_t num_unpadded_dim = output_shape[-(i + 1)]; + uint32_t num_total_dim = input_shape[-(i + 1)]; + uint32_t num_padded_dim = (num_total_dim - num_unpadded_dim) * accumulated_total_per_dim[i - 1]; + num_unpadded_sticks_per_dim[i] = num_unpadded_dim; + num_padded_sticks_per_dim[i] = num_padded_dim; + accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1]; + } + + uint32_t unpadded_row_size_bytes_offset = tt::round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); + + vector common_reader_kernel_args = { + input_tensor.buffer()->address() + output_tensor_start[-1] * output_tensor.element_size(), + padded_row_size_bytes, + unpadded_row_size_bytes, + unpadded_row_size_bytes_offset, + num_dims, + 0, + 0, + 0, + 0}; + common_reader_kernel_args.insert( + common_reader_kernel_args.end(), num_unpadded_sticks_per_dim.begin(), num_unpadded_sticks_per_dim.end()); + common_reader_kernel_args.insert( + common_reader_kernel_args.end(), num_padded_sticks_per_dim.begin(), num_padded_sticks_per_dim.end()); + + std::vector, std::vector>> ret_val(num_cores_total); + + uint32_t start_offset = ttnn::operations::data_movement::get_rm_start_offset(input_tensor, ttnn::Shape(output_tensor_start)); + for (uint32_t i = 0, num_sticks_written = 0; i < num_cores_total; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t num_sticks_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_sticks_per_core = num_sticks_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_sticks_per_core = num_sticks_per_core_group_2; + } else { + // no-op + num_sticks_per_core = 0; + } + + // issue more reads before calling barrier + uint32_t num_sticks_per_core_read = 0, num_read_per_barrier = 0; + if (num_sticks_per_core != 0) { + auto num_sticks_per_core_pad32 = num_sticks_per_core + (32 - num_sticks_per_core % 32) % 32; + num_sticks_per_core_read = merge_num_sticks_to_read(num_sticks_per_core_pad32, unpadded_row_size_bytes_offset, max_read_size); + num_read_per_barrier = num_sticks_per_core_pad32 / num_sticks_per_core_read; + } + + id_per_dim[0] = num_sticks_written % num_unpadded_sticks_per_dim[0]; + uint32_t unpadded_written = num_sticks_written / num_unpadded_sticks_per_dim[0]; + uint32_t start_id = id_per_dim[0] + start_offset; + + for (uint32_t j = 1; j < num_dims; j++) { + id_per_dim[j] = unpadded_written % num_unpadded_sticks_per_dim[j]; + unpadded_written = unpadded_written / num_unpadded_sticks_per_dim[j]; + start_id += id_per_dim[j] * accumulated_total_per_dim[j - 1]; + } + vector reader_kernel_args = common_reader_kernel_args; + // + uint32_t addr_offset = 5; // input buffer addr, padded_row_size_bytes, unpadded_row_size_bytes, num_dims + reader_kernel_args[addr_offset++] = start_id; + reader_kernel_args[addr_offset++] = num_sticks_per_core; + reader_kernel_args[addr_offset++] = num_sticks_per_core_read; + reader_kernel_args[addr_offset] = num_read_per_barrier; + reader_kernel_args.insert(reader_kernel_args.end(), id_per_dim.begin(), id_per_dim.end()); + + vector writer_kernel_args = { + output_buffer->address(), unpadded_row_size_bytes, unpadded_row_size_bytes_offset, num_sticks_per_core, num_sticks_per_core_read, num_read_per_barrier, num_sticks_written, 0}; + num_sticks_written += num_sticks_per_core; + ret_val[i] = {reader_kernel_args, writer_kernel_args}; + } + + return ret_val; +} + +operation::ProgramWithCallbacks slice_rm_multi_core( + const Tensor& a, Tensor& output, const tt::tt_metal::Shape& output_tensor_start, const tt::tt_metal::Shape& output_tensor_end) { + const tt::tt_metal::Shape output_shape = output.get_legacy_shape(); + + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + // This should allocate a DRAM buffer on the device + tt::tt_metal::Device* device = a.device(); + + uint32_t num_unpadded_sticks = output.volume() / output.get_legacy_shape()[-1]; + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + uint32_t num_cores_total = num_cores_x * num_cores_y; + auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_per_core_group_1, num_sticks_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_unpadded_sticks); + + tt::tt_metal::Buffer* src0_buffer = a.buffer(); + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + + uint32_t padded_row_size_bytes = a.get_legacy_shape()[-1] * a.element_size(); + uint32_t unpadded_row_size_bytes = output_shape[-1] * a.element_size(); + + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args_vec = {(std::uint32_t)src0_is_dram}; + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + + uint32_t src_stick_size = padded_row_size_bytes; + uint32_t dst_stick_size = unpadded_row_size_bytes; + + uint32_t src0_cb_index = 0; + uint32_t max_read_size = 4096; + uint32_t cb_page_size = dst_is_dram ? tt::round_up(unpadded_row_size_bytes, TILE_WIDTH) : tt::round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); + uint32_t num_input_pages = num_sticks_per_core_group_1 > num_sticks_per_core_group_2 ? num_sticks_per_core_group_1 : num_sticks_per_core_group_2; + uint32_t num_sticks_per_core_read = 0, num_read_per_barrier = 0; + if (num_input_pages != 0) { + auto num_sticks_per_core_pad32 = num_input_pages + (32 - num_input_pages % 32) % 32; + num_sticks_per_core_read = merge_num_sticks_to_read(num_sticks_per_core_pad32, cb_page_size, max_read_size); + num_read_per_barrier = num_sticks_per_core_pad32 / num_sticks_per_core_read; + } + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(num_read_per_barrier * 2 * cb_page_size, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_page_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + + + std::vector writer_compile_time_args_vec = {(std::uint32_t)src0_cb_index, (std::uint32_t)dst_is_dram}; + + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/slice_reader_unary_unpad_dims_rm_interleaved_start_id.cpp", + total_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args_vec)); + + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/slice_writer_unary_stick_layout_interleaved_start_id.cpp", + total_cores, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args_vec)); + + auto all_runtime_args = get_slice_runtime_args_rm( + a, + output, + output_tensor_start, + num_cores_total, + num_cores, + num_cores_y, + core_group_1, + core_group_2, + num_sticks_per_core_group_1, + num_sticks_per_core_group_2, + max_read_size); + + for (uint32_t i = 0, num_sticks_written = 0; i < num_cores_total; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, all_runtime_args[i].first); + + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, all_runtime_args[i].second); + } + + auto override_runtime_args_callback = + [unary_reader_kernel_id, unary_writer_kernel_id, compute_with_storage_grid_size, max_read_size]( + const void* operation, + const Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) { + auto src_tensor = input_tensors.at(0); + auto dst_tensor = output_tensors.at(0); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + uint32_t num_unpadded_sticks = dst_tensor.volume() / dst_tensor.get_legacy_shape()[-1]; + auto + [num_cores, + all_cores, + core_group_1, + core_group_2, + num_sticks_per_core_group_1, + num_sticks_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_unpadded_sticks); + + const auto tensor_start = static_cast(operation)->slice_start; + auto all_runtime_args = get_slice_runtime_args_rm( + src_tensor, + dst_tensor, + tensor_start, + num_cores_total, + num_cores, + num_cores_y, + core_group_1, + core_group_2, + num_sticks_per_core_group_1, + num_sticks_per_core_group_2, + max_read_size); + + for (uint32_t i = 0, num_tiles_written = 0; i < num_cores_total; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + + { SetRuntimeArgs(program, unary_reader_kernel_id, core, all_runtime_args[i].first); } + + { SetRuntimeArgs(program, unary_writer_kernel_id, core, all_runtime_args[i].second); } + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; +} + +template +inline __attribute__((always_inline)) void set_slice_runtime_args_tile( + const Tensor& input_tensor, + const Tensor& output_tensor, + const tt::tt_metal::Shape& output_tensor_start, + const uint32_t& num_cores_total, + const uint32_t& num_cores, + const std::vector& cores, + const uint32_t& num_cores_group_1, + const uint32_t& num_cores_group_2, + const uint32_t& num_tiles_per_core_group_1, + const uint32_t& num_tiles_per_core_group_2, + const Program& program, + const tt::tt_metal::KernelHandle& unary_reader_kernel_id, + const tt::tt_metal::KernelHandle& unary_writer_kernel_id, + std::vector& accumulated_total_per_dim) { + const auto input_buffer = input_tensor.buffer(); + const auto output_buffer = output_tensor.buffer(); + const auto& input_shape = input_tensor.get_legacy_shape(); + const auto& output_shape = output_tensor.get_legacy_shape(); + + std::uint32_t num_dims = static_cast(input_shape.rank()); + + uint32_t num_unpadded_Xt = output_shape[-1] / TILE_WIDTH; + uint32_t num_total_Xt = input_shape[-1] / TILE_WIDTH; + uint32_t num_padded_Xt = num_total_Xt - num_unpadded_Xt; + uint32_t num_unpadded_Yt = output_shape[-2] / TILE_HEIGHT; + uint32_t num_total_Yt = input_shape[-2] / TILE_HEIGHT; + uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; + + const auto set_common_reader_args = [&]( + uint32_t* reader_common_args, + uint32_t* num_unpadded_tiles_per_dim, + uint32_t* num_padded_tiles_per_dim) __attribute__((always_inline)) { + reader_common_args[0] = input_buffer->address(); + num_unpadded_tiles_per_dim[0] = num_unpadded_Xt; + num_unpadded_tiles_per_dim[1] = num_unpadded_Yt; + num_padded_tiles_per_dim[0] = num_padded_Xt; + num_padded_tiles_per_dim[1] = num_padded_Yt; + accumulated_total_per_dim[0] = num_total_Xt; + accumulated_total_per_dim[1] = num_total_Yt * num_total_Xt; + for (int32_t i = 2; i < num_dims; ++i) { + uint32_t num_unpadded_dim = output_shape[-(i + 1)]; + uint32_t num_total_dim = input_shape[-(i + 1)]; + uint32_t num_padded_dim = (num_total_dim - num_unpadded_dim) * accumulated_total_per_dim[i - 1]; + num_unpadded_tiles_per_dim[i] = num_unpadded_dim; + num_padded_tiles_per_dim[i] = num_padded_dim; + accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1]; + } + }; + + const auto set_reader_rt_args = [&]( + uint32_t* reader_rt_args, + const uint32_t* num_unpadded_tiles_per_dim, + const uint32_t* num_padded_tiles_per_dim, + const uint32_t& num_tiles_per_core, + const uint32_t& start_offset, + const uint32_t& num_tiles_written) __attribute__((always_inline)) { + reader_rt_args[2] = num_tiles_written % num_unpadded_tiles_per_dim[0]; + uint32_t unpadded_written = num_tiles_written / num_unpadded_tiles_per_dim[0]; + uint32_t start_id = reader_rt_args[2] + start_offset; + for (uint32_t j = 1; j < num_dims; ++j) { + reader_rt_args[2 + j] = unpadded_written % num_unpadded_tiles_per_dim[j]; + unpadded_written = unpadded_written / num_unpadded_tiles_per_dim[j]; + start_id += reader_rt_args[2 + j] * accumulated_total_per_dim[j - 1]; + } + reader_rt_args[0] = start_id; + reader_rt_args[1] = num_tiles_per_core; + }; + + if constexpr (initialize_args) { + std::vector reader_common_args(1 + num_dims * 2); + uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; + uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); + } + auto& reader_common_args = GetCommonRuntimeArgs(program, unary_reader_kernel_id); + uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; + uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; + if constexpr (!initialize_args) { + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + } + + uint32_t start_offset = ttnn::operations::data_movement::get_tiled_start_offset(input_tensor, ttnn::Shape(output_tensor_start)); + + auto& reader_kernel_args_by_core = GetRuntimeArgs(program, unary_reader_kernel_id); + auto& writer_kernel_args_by_core = GetRuntimeArgs(program, unary_writer_kernel_id); + const uint32_t num_used_cores = num_cores_group_1 + num_cores_group_2; + for (uint32_t i = 0, num_tiles_written = 0; i < num_cores_total; ++i) { + const CoreCoord& core = cores[i]; + uint32_t num_tiles_per_core; + if (i < num_cores_group_1) { + num_tiles_per_core = num_tiles_per_core_group_1; + } else if (i < num_used_cores) { + num_tiles_per_core = num_tiles_per_core_group_2; + } else { + // no-op + if constexpr (initialize_args) { + std::vector reader_kernel_args(2 + num_dims, 0); + std::vector writer_kernel_args(3, 0); + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args); + } else { + auto& reader_kernel_args = reader_kernel_args_by_core[core.x][core.y]; + reader_kernel_args[1] = 0; + auto& writer_kernel_args = writer_kernel_args_by_core[core.x][core.y]; + writer_kernel_args[1] = 0; + } + continue; + } + + if constexpr (initialize_args) { + std::vector reader_kernel_args(2 + num_dims); + set_reader_rt_args( + reader_kernel_args.data(), + num_unpadded_tiles_per_dim, + num_padded_tiles_per_dim, + num_tiles_per_core, + start_offset, + num_tiles_written); + SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); + } else { + auto& reader_kernel_args = reader_kernel_args_by_core[core.x][core.y]; + set_reader_rt_args( + reader_kernel_args.data(), + num_unpadded_tiles_per_dim, + num_padded_tiles_per_dim, + num_tiles_per_core, + start_offset, + num_tiles_written); + } + + if constexpr (initialize_args) { + vector writer_kernel_args = {output_buffer->address(), num_tiles_per_core, num_tiles_written}; + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args); + } else { + auto& writer_kernel_args = writer_kernel_args_by_core[core.x][core.y]; + writer_kernel_args[0] = output_buffer->address(); + writer_kernel_args[1] = num_tiles_per_core; + writer_kernel_args[2] = num_tiles_written; + } + num_tiles_written += num_tiles_per_core; + } +} + +operation::ProgramWithCallbacks slice_tile_multi_core( + const Tensor& a, Tensor& output, const tt::tt_metal::Shape& output_tensor_start, const tt::tt_metal::Shape& output_tensor_end) { + const tt::tt_metal::Shape output_shape = output.get_legacy_shape(); + + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + // This should allocate a DRAM buffer on the device + tt::tt_metal::Device* device = a.device(); + + uint32_t num_unpadded_tiles = output.volume() / TILE_HW; + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + auto num_cores_total = num_cores_x * num_cores_y; + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_unpadded_tiles); + + tt::tt_metal::Buffer* src0_buffer = a.buffer(); + + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = 2; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, single_tile_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + + std::uint32_t num_dims = static_cast(a.get_legacy_shape().rank()); + + // Reader compile-time args + // Data is 32 byte aligned + bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + static_cast(src0_cb_index), + static_cast(num_dims), + static_cast(src0_is_dram), + }; + std::vector writer_compile_time_args = { + static_cast(src0_cb_index), static_cast(dst_is_dram)}; + + // Tilized reader + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/reader_unary_unpad_dims_interleaved_start_id.cpp", + total_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + total_cores, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + const auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, false); + + std::vector accumulated_total_per_dim(num_dims); + set_slice_runtime_args_tile( + a, + output, + output_tensor_start, + num_cores_total, + num_cores, + cores, + core_group_1.num_cores(), + core_group_2.num_cores(), + num_tiles_per_core_group_1, + num_tiles_per_core_group_2, + program, + unary_reader_kernel_id, + unary_writer_kernel_id, + accumulated_total_per_dim); + + auto override_runtime_args_callback = [unary_reader_kernel_id, + unary_writer_kernel_id, + compute_with_storage_grid_size, + cores, + accumulated_total_per_dim]( + const void* operation, + const Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) mutable { + const Tensor& src_tensor = input_tensors[0]; + const Tensor& dst_tensor = output_tensors[0]; + uint32_t num_unpadded_tiles = dst_tensor.volume() / TILE_HW; + + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = cores.size(); + + auto + [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_unpadded_tiles); + + const auto& tensor_start = static_cast(operation)->slice_start; + set_slice_runtime_args_tile( + src_tensor, + dst_tensor, + tensor_start, + num_cores_total, + num_cores, + cores, + core_group_1.num_cores(), + core_group_2.num_cores(), + num_tiles_per_core_group_1, + num_tiles_per_core_group_2, + program, + unary_reader_kernel_id, + unary_writer_kernel_id, + accumulated_total_per_dim); + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks slice_multi_core( + const Tensor& a, Tensor& output, const tt::tt_metal::Shape& output_tensor_start, const tt::tt_metal::Shape& output_tensor_end) { + switch (a.get_layout()) { + case Layout::ROW_MAJOR: return slice_rm_multi_core(a, output, output_tensor_start, output_tensor_end); + case Layout::TILE: return slice_tile_multi_core(a, output, output_tensor_start, output_tensor_end); + default: TT_ASSERT(false, "Unsupported Layout"); + } + return {}; +} + +} // namespace ttnn::operations::data_movement::detail + diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.hpp index 10821e644a1..e532e4e73f7 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.hpp @@ -3,546 +3,11 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once -#include "optional" -#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" #include "tt_metal/host_api.hpp" -#include "slice_op.hpp" -using namespace tt::constants; +namespace ttnn::operations::data_movement::detail { -namespace tt { +operation::ProgramWithCallbacks slice_multi_core(const Tensor& a, Tensor& output, const tt::tt_metal::Shape& output_tensor_start, const tt::tt_metal::Shape& output_tensor_end); -namespace tt_metal { +} // namespace ttnn::operations::data_movement::detail -inline std::vector, std::vector>> get_slice_runtime_args_rm( - const Tensor& input_tensor, - Tensor& output_tensor, - const Shape& output_tensor_start, - uint32_t num_cores_total, - uint32_t num_cores, - uint32_t num_cores_y, - CoreRangeSet core_group_1, - CoreRangeSet core_group_2, - uint32_t num_sticks_per_core_group_1, - uint32_t num_sticks_per_core_group_2, - uint32_t max_read_size) { - tt_metal::Device* device = input_tensor.device(); - - auto input_buffer = input_tensor.buffer(); - auto output_buffer = output_tensor.buffer(); - auto input_shape = input_tensor.get_legacy_shape(); - auto output_shape = output_tensor.get_legacy_shape(); - - uint32_t padded_row_size_bytes = input_shape[-1] * input_tensor.element_size(); - uint32_t unpadded_row_size_bytes = output_shape[-1] * input_tensor.element_size(); - - std::uint32_t num_dims = static_cast(input_shape.rank()); - std::vector num_unpadded_sticks_per_dim(num_dims); - std::vector num_padded_sticks_per_dim(num_dims); - std::vector id_per_dim(num_dims); - - std::vector accumulated_total_per_dim(num_dims); - - // TODO: Remove first element of these arrays and update kernel accordingly - // This currently just matches tile version where we iterate over the row as well - num_unpadded_sticks_per_dim[0] = 1; - num_padded_sticks_per_dim[0] = 0; - accumulated_total_per_dim[0] = 1; - - for (int32_t i = 1; i < num_dims; i++) { - uint32_t num_unpadded_dim = output_shape[-(i + 1)]; - uint32_t num_total_dim = input_shape[-(i + 1)]; - uint32_t num_padded_dim = (num_total_dim - num_unpadded_dim) * accumulated_total_per_dim[i - 1]; - num_unpadded_sticks_per_dim[i] = num_unpadded_dim; - num_padded_sticks_per_dim[i] = num_padded_dim; - accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1]; - } - - uint32_t unpadded_row_size_bytes_offset = round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); - - vector common_reader_kernel_args = { - input_tensor.buffer()->address() + output_tensor_start[-1] * output_tensor.element_size(), - padded_row_size_bytes, - unpadded_row_size_bytes, - unpadded_row_size_bytes_offset, - num_dims, - 0, - 0, - 0, - 0}; - common_reader_kernel_args.insert( - common_reader_kernel_args.end(), num_unpadded_sticks_per_dim.begin(), num_unpadded_sticks_per_dim.end()); - common_reader_kernel_args.insert( - common_reader_kernel_args.end(), num_padded_sticks_per_dim.begin(), num_padded_sticks_per_dim.end()); - - std::vector, std::vector>> ret_val(num_cores_total); - - uint32_t start_offset = ttnn::operations::data_movement::get_rm_start_offset(input_tensor, ttnn::Shape(output_tensor_start)); - for (uint32_t i = 0, num_sticks_written = 0; i < num_cores_total; i++) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - uint32_t num_sticks_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_sticks_per_core = num_sticks_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_sticks_per_core = num_sticks_per_core_group_2; - } else { - // no-op - num_sticks_per_core = 0; - } - - // issue more reads before calling barrier - uint32_t num_sticks_per_core_read = 0, num_read_per_barrier = 0; - if (num_sticks_per_core != 0) { - auto num_sticks_per_core_pad32 = num_sticks_per_core + (32 - num_sticks_per_core % 32) % 32; - num_sticks_per_core_read = merge_num_sticks_to_read(num_sticks_per_core_pad32, unpadded_row_size_bytes_offset, max_read_size); - num_read_per_barrier = num_sticks_per_core_pad32 / num_sticks_per_core_read; - } - - id_per_dim[0] = num_sticks_written % num_unpadded_sticks_per_dim[0]; - uint32_t unpadded_written = num_sticks_written / num_unpadded_sticks_per_dim[0]; - uint32_t start_id = id_per_dim[0] + start_offset; - - for (uint32_t j = 1; j < num_dims; j++) { - id_per_dim[j] = unpadded_written % num_unpadded_sticks_per_dim[j]; - unpadded_written = unpadded_written / num_unpadded_sticks_per_dim[j]; - start_id += id_per_dim[j] * accumulated_total_per_dim[j - 1]; - } - vector reader_kernel_args = common_reader_kernel_args; - // - uint32_t addr_offset = 5; // input buffer addr, padded_row_size_bytes, unpadded_row_size_bytes, num_dims - reader_kernel_args[addr_offset++] = start_id; - reader_kernel_args[addr_offset++] = num_sticks_per_core; - reader_kernel_args[addr_offset++] = num_sticks_per_core_read; - reader_kernel_args[addr_offset] = num_read_per_barrier; - reader_kernel_args.insert(reader_kernel_args.end(), id_per_dim.begin(), id_per_dim.end()); - - vector writer_kernel_args = { - output_buffer->address(), unpadded_row_size_bytes, unpadded_row_size_bytes_offset, num_sticks_per_core, num_sticks_per_core_read, num_read_per_barrier, num_sticks_written, 0}; - num_sticks_written += num_sticks_per_core; - ret_val[i] = {reader_kernel_args, writer_kernel_args}; - } - - return ret_val; -} - -operation::ProgramWithCallbacks slice_rm_multi_core( - const Tensor& a, Tensor& output, const Shape& output_tensor_start, const Shape& output_tensor_end) { - const Shape output_shape = output.get_legacy_shape(); - - tt_metal::Program program = tt_metal::CreateProgram(); - - // This should allocate a DRAM buffer on the device - tt_metal::Device* device = a.device(); - - uint32_t num_unpadded_sticks = output.volume() / output.get_legacy_shape()[-1]; - - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - - CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - uint32_t num_cores_total = num_cores_x * num_cores_y; - auto [num_cores, all_cores, core_group_1, core_group_2, num_sticks_per_core_group_1, num_sticks_per_core_group_2] = - split_work_to_cores(compute_with_storage_grid_size, num_unpadded_sticks); - - tt_metal::Buffer* src0_buffer = a.buffer(); - - tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - - uint32_t padded_row_size_bytes = a.get_legacy_shape()[-1] * a.element_size(); - uint32_t unpadded_row_size_bytes = output_shape[-1] * a.element_size(); - - tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - std::vector reader_compile_time_args_vec = {(std::uint32_t)src0_is_dram}; - bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - uint32_t src_stick_size = padded_row_size_bytes; - uint32_t dst_stick_size = unpadded_row_size_bytes; - - uint32_t src0_cb_index = 0; - uint32_t max_read_size = 4096; - uint32_t cb_page_size = dst_is_dram ? round_up(unpadded_row_size_bytes, TILE_WIDTH) : round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); - uint32_t num_input_pages = num_sticks_per_core_group_1 > num_sticks_per_core_group_2 ? num_sticks_per_core_group_1 : num_sticks_per_core_group_2; - uint32_t num_sticks_per_core_read = 0, num_read_per_barrier = 0; - if (num_input_pages != 0) { - auto num_sticks_per_core_pad32 = num_input_pages + (32 - num_input_pages % 32) % 32; - num_sticks_per_core_read = merge_num_sticks_to_read(num_sticks_per_core_pad32, cb_page_size, max_read_size); - num_read_per_barrier = num_sticks_per_core_pad32 / num_sticks_per_core_read; - } - tt_metal::CircularBufferConfig cb_src0_config = - tt_metal::CircularBufferConfig(num_read_per_barrier * 2 * cb_page_size, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, cb_page_size); - auto cb_src0 = tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); - - - std::vector writer_compile_time_args_vec = {(std::uint32_t)src0_cb_index, (std::uint32_t)dst_is_dram}; - - tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/slice_reader_unary_unpad_dims_rm_interleaved_start_id.cpp", - total_cores, - tt_metal::ReaderDataMovementConfig(reader_compile_time_args_vec)); - - tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/slice_writer_unary_stick_layout_interleaved_start_id.cpp", - total_cores, - tt_metal::WriterDataMovementConfig(writer_compile_time_args_vec)); - - auto all_runtime_args = get_slice_runtime_args_rm( - a, - output, - output_tensor_start, - num_cores_total, - num_cores, - num_cores_y, - core_group_1, - core_group_2, - num_sticks_per_core_group_1, - num_sticks_per_core_group_2, - max_read_size); - - for (uint32_t i = 0, num_sticks_written = 0; i < num_cores_total; i++) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, all_runtime_args[i].first); - - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, all_runtime_args[i].second); - } - - auto override_runtime_args_callback = - [unary_reader_kernel_id, unary_writer_kernel_id, compute_with_storage_grid_size, max_read_size]( - const void* operation, - const Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors) { - auto src_tensor = input_tensors.at(0); - auto dst_tensor = output_tensors.at(0); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores_total = num_cores_x * num_cores_y; - uint32_t num_unpadded_sticks = dst_tensor.volume() / dst_tensor.get_legacy_shape()[-1]; - auto - [num_cores, - all_cores, - core_group_1, - core_group_2, - num_sticks_per_core_group_1, - num_sticks_per_core_group_2] = - split_work_to_cores(compute_with_storage_grid_size, num_unpadded_sticks); - - const auto tensor_start = static_cast(operation)->slice_start; - auto all_runtime_args = get_slice_runtime_args_rm( - src_tensor, - dst_tensor, - tensor_start, - num_cores_total, - num_cores, - num_cores_y, - core_group_1, - core_group_2, - num_sticks_per_core_group_1, - num_sticks_per_core_group_2, - max_read_size); - - for (uint32_t i = 0, num_tiles_written = 0; i < num_cores_total; i++) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - - { SetRuntimeArgs(program, unary_reader_kernel_id, core, all_runtime_args[i].first); } - - { SetRuntimeArgs(program, unary_writer_kernel_id, core, all_runtime_args[i].second); } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; -} - -template -inline __attribute__((always_inline)) void set_slice_runtime_args_tile( - const Tensor& input_tensor, - const Tensor& output_tensor, - const Shape& output_tensor_start, - const uint32_t& num_cores_total, - const uint32_t& num_cores, - const std::vector& cores, - const uint32_t& num_cores_group_1, - const uint32_t& num_cores_group_2, - const uint32_t& num_tiles_per_core_group_1, - const uint32_t& num_tiles_per_core_group_2, - const Program& program, - const tt_metal::KernelHandle& unary_reader_kernel_id, - const tt_metal::KernelHandle& unary_writer_kernel_id, - std::vector& accumulated_total_per_dim) { - const auto input_buffer = input_tensor.buffer(); - const auto output_buffer = output_tensor.buffer(); - const auto& input_shape = input_tensor.get_legacy_shape(); - const auto& output_shape = output_tensor.get_legacy_shape(); - - std::uint32_t num_dims = static_cast(input_shape.rank()); - - uint32_t num_unpadded_Xt = output_shape[-1] / TILE_WIDTH; - uint32_t num_total_Xt = input_shape[-1] / TILE_WIDTH; - uint32_t num_padded_Xt = num_total_Xt - num_unpadded_Xt; - uint32_t num_unpadded_Yt = output_shape[-2] / TILE_HEIGHT; - uint32_t num_total_Yt = input_shape[-2] / TILE_HEIGHT; - uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; - - const auto set_common_reader_args = [&]( - uint32_t* reader_common_args, - uint32_t* num_unpadded_tiles_per_dim, - uint32_t* num_padded_tiles_per_dim) __attribute__((always_inline)) { - reader_common_args[0] = input_buffer->address(); - num_unpadded_tiles_per_dim[0] = num_unpadded_Xt; - num_unpadded_tiles_per_dim[1] = num_unpadded_Yt; - num_padded_tiles_per_dim[0] = num_padded_Xt; - num_padded_tiles_per_dim[1] = num_padded_Yt; - accumulated_total_per_dim[0] = num_total_Xt; - accumulated_total_per_dim[1] = num_total_Yt * num_total_Xt; - for (int32_t i = 2; i < num_dims; ++i) { - uint32_t num_unpadded_dim = output_shape[-(i + 1)]; - uint32_t num_total_dim = input_shape[-(i + 1)]; - uint32_t num_padded_dim = (num_total_dim - num_unpadded_dim) * accumulated_total_per_dim[i - 1]; - num_unpadded_tiles_per_dim[i] = num_unpadded_dim; - num_padded_tiles_per_dim[i] = num_padded_dim; - accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1]; - } - }; - - const auto set_reader_rt_args = [&]( - uint32_t* reader_rt_args, - const uint32_t* num_unpadded_tiles_per_dim, - const uint32_t* num_padded_tiles_per_dim, - const uint32_t& num_tiles_per_core, - const uint32_t& start_offset, - const uint32_t& num_tiles_written) __attribute__((always_inline)) { - reader_rt_args[2] = num_tiles_written % num_unpadded_tiles_per_dim[0]; - uint32_t unpadded_written = num_tiles_written / num_unpadded_tiles_per_dim[0]; - uint32_t start_id = reader_rt_args[2] + start_offset; - for (uint32_t j = 1; j < num_dims; ++j) { - reader_rt_args[2 + j] = unpadded_written % num_unpadded_tiles_per_dim[j]; - unpadded_written = unpadded_written / num_unpadded_tiles_per_dim[j]; - start_id += reader_rt_args[2 + j] * accumulated_total_per_dim[j - 1]; - } - reader_rt_args[0] = start_id; - reader_rt_args[1] = num_tiles_per_core; - }; - - if constexpr (initialize_args) { - std::vector reader_common_args(1 + num_dims * 2); - uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; - uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; - set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); - SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); - } - auto& reader_common_args = GetCommonRuntimeArgs(program, unary_reader_kernel_id); - uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; - uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; - if constexpr (!initialize_args) { - set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); - } - - uint32_t start_offset = ttnn::operations::data_movement::get_tiled_start_offset(input_tensor, ttnn::Shape(output_tensor_start)); - - auto& reader_kernel_args_by_core = GetRuntimeArgs(program, unary_reader_kernel_id); - auto& writer_kernel_args_by_core = GetRuntimeArgs(program, unary_writer_kernel_id); - const uint32_t num_used_cores = num_cores_group_1 + num_cores_group_2; - for (uint32_t i = 0, num_tiles_written = 0; i < num_cores_total; ++i) { - const CoreCoord& core = cores[i]; - uint32_t num_tiles_per_core; - if (i < num_cores_group_1) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (i < num_used_cores) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - // no-op - if constexpr (initialize_args) { - std::vector reader_kernel_args(2 + num_dims, 0); - std::vector writer_kernel_args(3, 0); - tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args); - } else { - auto& reader_kernel_args = reader_kernel_args_by_core[core.x][core.y]; - reader_kernel_args[1] = 0; - auto& writer_kernel_args = writer_kernel_args_by_core[core.x][core.y]; - writer_kernel_args[1] = 0; - } - continue; - } - - if constexpr (initialize_args) { - std::vector reader_kernel_args(2 + num_dims); - set_reader_rt_args( - reader_kernel_args.data(), - num_unpadded_tiles_per_dim, - num_padded_tiles_per_dim, - num_tiles_per_core, - start_offset, - num_tiles_written); - SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); - } else { - auto& reader_kernel_args = reader_kernel_args_by_core[core.x][core.y]; - set_reader_rt_args( - reader_kernel_args.data(), - num_unpadded_tiles_per_dim, - num_padded_tiles_per_dim, - num_tiles_per_core, - start_offset, - num_tiles_written); - } - - if constexpr (initialize_args) { - vector writer_kernel_args = {output_buffer->address(), num_tiles_per_core, num_tiles_written}; - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args); - } else { - auto& writer_kernel_args = writer_kernel_args_by_core[core.x][core.y]; - writer_kernel_args[0] = output_buffer->address(); - writer_kernel_args[1] = num_tiles_per_core; - writer_kernel_args[2] = num_tiles_written; - } - num_tiles_written += num_tiles_per_core; - } -} - -operation::ProgramWithCallbacks slice_tile_multi_core( - const Tensor& a, Tensor& output, const Shape& output_tensor_start, const Shape& output_tensor_end) { - const Shape output_shape = output.get_legacy_shape(); - - tt_metal::Program program = tt_metal::CreateProgram(); - - // This should allocate a DRAM buffer on the device - tt_metal::Device* device = a.device(); - - uint32_t num_unpadded_tiles = output.volume() / TILE_HW; - - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - auto num_cores_total = num_cores_x * num_cores_y; - CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(compute_with_storage_grid_size, num_unpadded_tiles); - - tt_metal::Buffer* src0_buffer = a.buffer(); - - tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format); - - uint32_t src0_cb_index = 0; - uint32_t num_input_tiles = 2; - tt_metal::CircularBufferConfig cb_src0_config = - tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, single_tile_size); - auto cb_src0 = tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); - - std::uint32_t num_dims = static_cast(a.get_legacy_shape().rank()); - - // Reader compile-time args - // Data is 32 byte aligned - bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - std::vector reader_compile_time_args = { - static_cast(src0_cb_index), - static_cast(num_dims), - static_cast(src0_is_dram), - }; - std::vector writer_compile_time_args = { - static_cast(src0_cb_index), static_cast(dst_is_dram)}; - - // Tilized reader - tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/reader_unary_unpad_dims_interleaved_start_id.cpp", - total_cores, - tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - total_cores, - tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - const auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, false); - - std::vector accumulated_total_per_dim(num_dims); - set_slice_runtime_args_tile( - a, - output, - output_tensor_start, - num_cores_total, - num_cores, - cores, - core_group_1.num_cores(), - core_group_2.num_cores(), - num_tiles_per_core_group_1, - num_tiles_per_core_group_2, - program, - unary_reader_kernel_id, - unary_writer_kernel_id, - accumulated_total_per_dim); - - auto override_runtime_args_callback = [unary_reader_kernel_id, - unary_writer_kernel_id, - compute_with_storage_grid_size, - cores, - accumulated_total_per_dim]( - const void* operation, - const Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors) mutable { - const Tensor& src_tensor = input_tensors[0]; - const Tensor& dst_tensor = output_tensors[0]; - uint32_t num_unpadded_tiles = dst_tensor.volume() / TILE_HW; - - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores_total = cores.size(); - - auto - [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(compute_with_storage_grid_size, num_unpadded_tiles); - - const auto& tensor_start = static_cast(operation)->slice_start; - set_slice_runtime_args_tile( - src_tensor, - dst_tensor, - tensor_start, - num_cores_total, - num_cores, - cores, - core_group_1.num_cores(), - core_group_2.num_cores(), - num_tiles_per_core_group_1, - num_tiles_per_core_group_2, - program, - unary_reader_kernel_id, - unary_writer_kernel_id, - accumulated_total_per_dim); - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks slice_multi_core( - const Tensor& a, Tensor& output, const Shape& output_tensor_start, const Shape& output_tensor_end) { - switch (a.get_layout()) { - case Layout::ROW_MAJOR: return slice_rm_multi_core(a, output, output_tensor_start, output_tensor_end); - case Layout::TILE: return slice_tile_multi_core(a, output, output_tensor_start, output_tensor_end); - default: TT_ASSERT(false, "Unsupported Layout"); - } - return {}; -} - -} // namespace tt_metal - -} // namespace tt diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.cpp new file mode 100644 index 00000000000..4beb6402a2c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.cpp @@ -0,0 +1,411 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + + +#include + +#include "ttnn/deprecated/tt_dnn/op_library/cb_utils.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" +#include "ttnn/operation.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/work_split_tilize.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; + +namespace ttnn::operations::data_movement::detail { + +operation::ProgramWithCallbacks tilize_single_core(const Tensor& a, Tensor& output) { + tt::tt_metal::Program program{}; + + CoreRange core({0, 0}, {0, 0}); + + tt::tt_metal::Buffer* src0_buffer = a.buffer(); + + // This should allocate a DRAM buffer on the device + tt::tt_metal::Device* device = a.device(); + auto output_shape = output.get_legacy_shape(); + + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + uint32_t num_tiles = a.volume() / TILE_HW; + + auto width = a.get_legacy_shape()[-1]; + uint32_t stick_s = width; + uint32_t num_sticks = a.volume() / width; + uint32_t stick_size = stick_s * a.element_size(); // Assuming bfloat16 dataformat + + uint32_t num_tiles_in_row = stick_s / TILE_WIDTH; + // Ensure we don't intrude into storage space + uint32_t max_l1_size = a.device()->l1_size_per_core() / 2 - L1_UNRESERVED_BASE; + uint32_t max_tiles = max_l1_size / (input_single_tile_size + output_single_tile_size); // 2 CBs + // Currently need the number of tiles in a row to be divisible by tiles in a block + uint32_t num_tiles_per_block = 1; + if (num_tiles_in_row <= max_tiles) { + num_tiles_per_block = num_tiles_in_row; + } else { + for (uint32_t n_t = max_tiles; n_t > 0; n_t--) { + if (num_tiles_in_row % n_t == 0) { + num_tiles_per_block = n_t; + break; + } + } + } + uint32_t block_width_size = num_tiles_per_block * TILE_WIDTH * a.element_size(); + uint32_t num_full_blocks_in_row = num_tiles_in_row / num_tiles_per_block; + uint32_t num_leftover_tiles = num_tiles_in_row % num_tiles_per_block; + uint32_t leftover_width_in_row = num_leftover_tiles * a.element_size(); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = num_tiles_per_block; + + auto src0_cb_config = tt::tt_metal::CircularBufferConfig( + num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, input_single_tile_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, src0_cb_config); + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = num_tiles_per_block; + auto cb_output_config = tt::tt_metal::CircularBufferConfig( + num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); + auto cb_output = tt::tt_metal::CreateCircularBuffer(program, core, cb_output_config); + + vector reader_kernel_args = { + src0_buffer->address(), + num_sticks, + stick_size, + num_tiles_per_block, + block_width_size, + num_full_blocks_in_row, + num_leftover_tiles, + leftover_width_in_row, + 0 // row_start_id + }; + + // Reader compile-time args + uint32_t src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); + uint32_t log2_stick_size = stick_size_is_power_of_two ? (uint32_t)log2(stick_size) : 0; + std::vector reader_compile_time_args = {src0_is_dram, stick_size_is_power_of_two, log2_stick_size}; + + uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = {output_cb_index, out_is_dram}; + + // Tilized reader + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/dataflow/reader_unary_stick_layout_split_rows_interleaved.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + // Tilized writer + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + core, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + vector compute_args = { + num_tiles / num_tiles_per_block, // per_core_block_cnt + num_tiles_per_block // per_core_block_tile_cnt + }; + + auto tilize_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + core, + tt::tt_metal::ComputeConfig{.compile_args = compute_args}); + + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); + + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles, 0}); + + auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, + writer_kernel_id = unary_writer_kernel_id]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + + auto dst_buffer = output_buffers.at(0); + + CoreCoord core = {0, 0}; + + { + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + } + + { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, Tensor& output) { + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + tt::DataFormat input_cb_data_format = datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + tt::DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + int32_t ntiles = a.volume() / TILE_HW; + uint32_t ntiles_per_block = a.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t nblocks = std::ceil((float)ntiles / ntiles_per_block); + uint32_t block_size_nbytes = a.get_legacy_shape()[-1] * a.element_size(); + + Device* device = a.device(); + auto grid_size = device->compute_with_storage_grid_size(); + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + split_blocks_for_tilize(grid_size, nblocks); + + create_cb(tt::CB::c_in0, program, all_cores, input_single_tile_size, ntiles_per_block, input_cb_data_format); + + auto [output_cb_index, _] = + create_cb(tt::CB::c_out0, program, all_cores, output_single_tile_size, ntiles_per_block, output_cb_data_format); + + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + /** reader + */ + uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(block_size_nbytes); + uint32_t log2_stick_size = stick_size_is_power_of_two ? (uint32_t)std::log2(block_size_nbytes) : 0; + std::vector reader_ct_args = {src0_is_dram, stick_size_is_power_of_two, log2_stick_size}; + KernelHandle unary_reader_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/dataflow/reader_unary_stick_layout_split_rows_interleaved.cpp", + all_cores, + ReaderDataMovementConfig(reader_ct_args)); + + /** writer + */ + uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_ct_args = {output_cb_index, out_is_dram}; + KernelHandle unary_writer_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + all_cores, + WriterDataMovementConfig(writer_ct_args)); + + /** compute + */ + vector compute_args = {nblocks_per_core, ntiles_per_block}; + vector compute_args_cliff = {nblocks_per_core_cliff, ntiles_per_block}; + + if (core_range.ranges().size() > 0) { + auto tilize_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + core_range, + ComputeConfig{.compile_args = compute_args}); + } + if (core_range_cliff.size() > 0) { + auto tilize_cliff_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + core_range_cliff, + ComputeConfig{.compile_args = compute_args_cliff}); + } + + // 1D distribution of blocks across cores + bool has_cliff = core_range_cliff.size() > 0; + + uint32_t ncores_full = ncores - has_cliff; + uint32_t ncores_x = grid_size.x; + uint32_t tile_start_id = 0; + uint32_t row_start_id = 0; + const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); + for (uint32_t i = 0; i < ncores_full; ++i) { + const CoreCoord& core = cores[i]; + + // reader runtime args + vector reader_rt_args = { + src0_buffer->address(), + nblocks_per_core * TILE_HEIGHT, + block_size_nbytes, + ntiles_per_block, + block_size_nbytes, + 1, // full blocks in row + 0, // num leftover tiles + 0, // leftover width in row + row_start_id}; + + // writer runtime args + vector writer_rt_args = { + dst_buffer->address(), + ntiles_per_block * nblocks_per_core, // ntiles per core + tile_start_id // start id + }; + + SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); + SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); + + tile_start_id += ntiles_per_block * nblocks_per_core; + row_start_id += TILE_HEIGHT * nblocks_per_core; + } + if (has_cliff) { + // the last core is a cliff core with nblocks_per_core_cliff blocks + const CoreCoord& core = cores.back(); + + // reader runtime args + vector reader_rt_args = { + src0_buffer->address(), + nblocks_per_core_cliff * TILE_HEIGHT, + block_size_nbytes, + ntiles_per_block, + block_size_nbytes, + 1, // full blocks in row + 0, // num leftover tiles + 0, // leftover width in row + row_start_id}; + + // writer runtime args + vector writer_rt_args = { + dst_buffer->address(), + ntiles_per_block * nblocks_per_core_cliff, // ntiles per core + tile_start_id // start id + }; + + SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); + SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); + } + + auto override_runtime_args_callback = + [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = dst_buffer->address(); + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks tilize_multi_core_sharded(const Tensor& input, Tensor& output) { + tt::tt_metal::Program program{}; + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + uint32_t num_tiles = input.volume() / TILE_HW; + + tt::tt_metal::Device* device = input.device(); + + auto shard_spec = input.shard_spec().value(); + uint32_t num_tiles_per_shard = shard_spec.shape[0] * shard_spec.shape[1] / TILE_HW; + uint32_t num_tiles_per_row = shard_spec.shape[1] / TILE_WIDTH; + auto all_cores = shard_spec.grid; + uint32_t num_cores_x = device->compute_with_storage_grid_size().x; + uint32_t num_cores = all_cores.num_cores(); + + auto [src0_cb_index, cb_src0] = create_cb( + tt::CB::c_in0, + program, + all_cores, + input_single_tile_size, + num_tiles_per_shard, + input_cb_data_format, + input.buffer()); + + auto [output_cb_index, cb_output] = create_cb( + tt::CB::c_out0, + program, + all_cores, + output_single_tile_size, + num_tiles_per_shard, + output_cb_data_format, + output.buffer()); + + auto src_buffer = input.buffer(); + + auto dst_buffer = output.buffer(); + + std::vector reader_compile_time_args = {(std::uint32_t)src0_cb_index}; + + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index}; + + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_sharded.cpp", + all_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded.cpp", + all_cores, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + vector compute_args = {uint32_t(num_tiles_per_shard / num_tiles_per_row), uint32_t(num_tiles_per_row)}; + + auto untilize_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + all_cores, + tt::tt_metal::ComputeConfig{.compile_args = compute_args}); + + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, all_cores, {num_tiles_per_shard}); + + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, {num_tiles_per_shard}); + + auto override_runtime_arguments_callback = + [unary_reader_kernel_id, unary_writer_kernel_id, cb_src0, cb_output]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + auto src_buffer = input_tensors.at(0).buffer(); + + auto dst_buffer = output_tensors.at(0).buffer(); + + UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); + + UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + }; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +operation::ProgramWithCallbacks tilize_multi_core(const Tensor& a, Tensor& output) { + if (a.memory_config().is_sharded()) { + return tilize_multi_core_sharded(a, output); + } else { + return tilize_multi_core_interleaved(a, output); + } +} + +} // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.hpp index 3388de10a99..1e061f399a0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/device/tilize_program_factory.hpp @@ -4,409 +4,13 @@ #pragma once -#include - -#include "ttnn/deprecated/tt_dnn/op_library/cb_utils.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" -#include "ttnn/operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/work_split_tilize.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" #include "tt_metal/host_api.hpp" -using namespace tt::constants; namespace ttnn::operations::data_movement::detail { -operation::ProgramWithCallbacks tilize_single_core(const Tensor& a, Tensor& output) { - tt::tt_metal::Program program{}; - - CoreRange core({0, 0}, {0, 0}); - - tt::tt_metal::Buffer* src0_buffer = a.buffer(); - - // This should allocate a DRAM buffer on the device - tt::tt_metal::Device* device = a.device(); - auto output_shape = output.get_legacy_shape(); - - tt::tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - - uint32_t num_tiles = a.volume() / TILE_HW; - - auto width = a.get_legacy_shape()[-1]; - uint32_t stick_s = width; - uint32_t num_sticks = a.volume() / width; - uint32_t stick_size = stick_s * a.element_size(); // Assuming bfloat16 dataformat - - uint32_t num_tiles_in_row = stick_s / TILE_WIDTH; - // Ensure we don't intrude into storage space - uint32_t max_l1_size = a.device()->l1_size_per_core() / 2 - L1_UNRESERVED_BASE; - uint32_t max_tiles = max_l1_size / (input_single_tile_size + output_single_tile_size); // 2 CBs - // Currently need the number of tiles in a row to be divisible by tiles in a block - uint32_t num_tiles_per_block = 1; - if (num_tiles_in_row <= max_tiles) { - num_tiles_per_block = num_tiles_in_row; - } else { - for (uint32_t n_t = max_tiles; n_t > 0; n_t--) { - if (num_tiles_in_row % n_t == 0) { - num_tiles_per_block = n_t; - break; - } - } - } - uint32_t block_width_size = num_tiles_per_block * TILE_WIDTH * a.element_size(); - uint32_t num_full_blocks_in_row = num_tiles_in_row / num_tiles_per_block; - uint32_t num_leftover_tiles = num_tiles_in_row % num_tiles_per_block; - uint32_t leftover_width_in_row = num_leftover_tiles * a.element_size(); - - uint32_t src0_cb_index = 0; - uint32_t num_input_tiles = num_tiles_per_block; - - auto src0_cb_config = tt::tt_metal::CircularBufferConfig( - num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) - .set_page_size(src0_cb_index, input_single_tile_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, src0_cb_config); - - uint32_t output_cb_index = 16; // output operands start at index 16 - uint32_t num_output_tiles = num_tiles_per_block; - auto cb_output_config = tt::tt_metal::CircularBufferConfig( - num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) - .set_page_size(output_cb_index, output_single_tile_size); - auto cb_output = tt::tt_metal::CreateCircularBuffer(program, core, cb_output_config); - - vector reader_kernel_args = { - src0_buffer->address(), - num_sticks, - stick_size, - num_tiles_per_block, - block_width_size, - num_full_blocks_in_row, - num_leftover_tiles, - leftover_width_in_row, - 0 // row_start_id - }; - - // Reader compile-time args - uint32_t src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); - uint32_t log2_stick_size = stick_size_is_power_of_two ? (uint32_t)log2(stick_size) : 0; - std::vector reader_compile_time_args = {src0_is_dram, stick_size_is_power_of_two, log2_stick_size}; - - uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_compile_time_args = {output_cb_index, out_is_dram}; - - // Tilized reader - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/dataflow/reader_unary_stick_layout_split_rows_interleaved.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - // Tilized writer - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - core, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - vector compute_args = { - num_tiles / num_tiles_per_block, // per_core_block_cnt - num_tiles_per_block // per_core_block_tile_cnt - }; - - auto tilize_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - core, - tt::tt_metal::ComputeConfig{.compile_args = compute_args}); - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); - - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles, 0}); - - auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, - writer_kernel_id = unary_writer_kernel_id]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - - auto dst_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - } - - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, Tensor& output) { - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - - tt::DataFormat input_cb_data_format = datatype_to_dataformat_converter(a.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - tt::DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - - int32_t ntiles = a.volume() / TILE_HW; - uint32_t ntiles_per_block = a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t nblocks = std::ceil((float)ntiles / ntiles_per_block); - uint32_t block_size_nbytes = a.get_legacy_shape()[-1] * a.element_size(); - - Device* device = a.device(); - auto grid_size = device->compute_with_storage_grid_size(); - auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = - split_blocks_for_tilize(grid_size, nblocks); - - create_cb(tt::CB::c_in0, program, all_cores, input_single_tile_size, ntiles_per_block, input_cb_data_format); - - auto [output_cb_index, _] = - create_cb(tt::CB::c_out0, program, all_cores, output_single_tile_size, ntiles_per_block, output_cb_data_format); - - Buffer* src0_buffer = a.buffer(); - Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - /** reader - */ - uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(block_size_nbytes); - uint32_t log2_stick_size = stick_size_is_power_of_two ? (uint32_t)std::log2(block_size_nbytes) : 0; - std::vector reader_ct_args = {src0_is_dram, stick_size_is_power_of_two, log2_stick_size}; - KernelHandle unary_reader_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/dataflow/reader_unary_stick_layout_split_rows_interleaved.cpp", - all_cores, - ReaderDataMovementConfig(reader_ct_args)); - - /** writer - */ - uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_ct_args = {output_cb_index, out_is_dram}; - KernelHandle unary_writer_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - all_cores, - WriterDataMovementConfig(writer_ct_args)); - - /** compute - */ - vector compute_args = {nblocks_per_core, ntiles_per_block}; - vector compute_args_cliff = {nblocks_per_core_cliff, ntiles_per_block}; - - if (core_range.ranges().size() > 0) { - auto tilize_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - core_range, - ComputeConfig{.compile_args = compute_args}); - } - if (core_range_cliff.size() > 0) { - auto tilize_cliff_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - core_range_cliff, - ComputeConfig{.compile_args = compute_args_cliff}); - } - - // 1D distribution of blocks across cores - bool has_cliff = core_range_cliff.size() > 0; - - uint32_t ncores_full = ncores - has_cliff; - uint32_t ncores_x = grid_size.x; - uint32_t tile_start_id = 0; - uint32_t row_start_id = 0; - const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); - for (uint32_t i = 0; i < ncores_full; ++i) { - const CoreCoord& core = cores[i]; - - // reader runtime args - vector reader_rt_args = { - src0_buffer->address(), - nblocks_per_core * TILE_HEIGHT, - block_size_nbytes, - ntiles_per_block, - block_size_nbytes, - 1, // full blocks in row - 0, // num leftover tiles - 0, // leftover width in row - row_start_id}; - - // writer runtime args - vector writer_rt_args = { - dst_buffer->address(), - ntiles_per_block * nblocks_per_core, // ntiles per core - tile_start_id // start id - }; - - SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); - SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); - - tile_start_id += ntiles_per_block * nblocks_per_core; - row_start_id += TILE_HEIGHT * nblocks_per_core; - } - if (has_cliff) { - // the last core is a cliff core with nblocks_per_core_cliff blocks - const CoreCoord& core = cores.back(); - - // reader runtime args - vector reader_rt_args = { - src0_buffer->address(), - nblocks_per_core_cliff * TILE_HEIGHT, - block_size_nbytes, - ntiles_per_block, - block_size_nbytes, - 1, // full blocks in row - 0, // num leftover tiles - 0, // leftover width in row - row_start_id}; - - // writer runtime args - vector writer_rt_args = { - dst_buffer->address(), - ntiles_per_block * nblocks_per_core_cliff, // ntiles per core - tile_start_id // start id - }; - - SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); - SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); - } - - auto override_runtime_args_callback = - [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); - auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); - for (const auto& core : cores) { - { - auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; - runtime_args[0] = src_buffer->address(); - } - { - auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; - runtime_args[0] = dst_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks tilize_multi_core_sharded(const Tensor& input, Tensor& output) { - tt::tt_metal::Program program{}; - - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - - uint32_t num_tiles = input.volume() / TILE_HW; - - tt::tt_metal::Device* device = input.device(); - - auto shard_spec = input.shard_spec().value(); - uint32_t num_tiles_per_shard = shard_spec.shape[0] * shard_spec.shape[1] / TILE_HW; - uint32_t num_tiles_per_row = shard_spec.shape[1] / TILE_WIDTH; - auto all_cores = shard_spec.grid; - uint32_t num_cores_x = device->compute_with_storage_grid_size().x; - uint32_t num_cores = all_cores.num_cores(); - - auto [src0_cb_index, cb_src0] = create_cb( - tt::CB::c_in0, - program, - all_cores, - input_single_tile_size, - num_tiles_per_shard, - input_cb_data_format, - input.buffer()); - - auto [output_cb_index, cb_output] = create_cb( - tt::CB::c_out0, - program, - all_cores, - output_single_tile_size, - num_tiles_per_shard, - output_cb_data_format, - output.buffer()); - - auto src_buffer = input.buffer(); - - auto dst_buffer = output.buffer(); - - std::vector reader_compile_time_args = {(std::uint32_t)src0_cb_index}; - - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index}; - - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_sharded.cpp", - all_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded.cpp", - all_cores, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - vector compute_args = {uint32_t(num_tiles_per_shard / num_tiles_per_row), uint32_t(num_tiles_per_row)}; - - auto untilize_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - all_cores, - tt::tt_metal::ComputeConfig{.compile_args = compute_args}); - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, all_cores, {num_tiles_per_shard}); - - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, {num_tiles_per_shard}); - - auto override_runtime_arguments_callback = - [unary_reader_kernel_id, unary_writer_kernel_id, cb_src0, cb_output]( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors) { - auto src_buffer = input_tensors.at(0).buffer(); - - auto dst_buffer = output_tensors.at(0).buffer(); - - UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); - - UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); - }; - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; -} +operation::ProgramWithCallbacks tilize_single_core(const Tensor& a, Tensor& output); +operation::ProgramWithCallbacks tilize_multi_core(const Tensor& a, Tensor& output); -operation::ProgramWithCallbacks tilize_multi_core(const Tensor& a, Tensor& output) { - if (a.memory_config().is_sharded()) { - return tilize_multi_core_sharded(a, output); - } else { - return tilize_multi_core_interleaved(a, output); - } -} } // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.cpp new file mode 100644 index 00000000000..c4535254a86 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.cpp @@ -0,0 +1,491 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + + +#include + +#include "ttnn/deprecated/tt_dnn/op_library/cb_utils.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" +#include "ttnn/operation.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/work_split_tilize.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; + +namespace ttnn::operations::data_movement::detail { + +operation::ProgramWithCallbacks tilize_with_val_padding_single_core( + const Tensor& a, Tensor& output, const float pad_value) { + auto output_shape = output.get_legacy_shape(); + + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + CoreRange core({0, 0}, {0, 0}); + + // This should allocate a DRAM buffer on the device + tt::tt_metal::Device* device = a.device(); + + tt::tt_metal::Buffer* src0_buffer = a.buffer(); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + int32_t num_tiles = output.volume() / TILE_HW; + + auto true_input_shape = a.get_legacy_shape(); + auto true_output_shape = output.get_legacy_shape(); + + auto input_w = true_input_shape.rank() >= 4 ? true_input_shape[-4] : 1; + auto input_z = true_input_shape.rank() >= 3 ? true_input_shape[-3] : 1; + auto input_y = true_input_shape.rank() >= 2 ? true_input_shape[-2] : 1; + auto input_x = true_input_shape[-1]; + + auto output_w = true_output_shape.rank() >= 4 ? true_output_shape[-4] : 1; + auto output_z = true_output_shape.rank() >= 3 ? true_output_shape[-3] : 1; + auto output_y = true_output_shape.rank() >= 2 ? true_output_shape[-2] : 1; + auto output_x = true_output_shape[-1]; + + uint32_t unpadded_row_size_bytes = input_x * a.element_size(); // Assuming bfloat16 dataformat + uint32_t padded_row_size_bytes = output_x * a.element_size(); // Assuming bfloat16 dataformat + + constexpr uint32_t alignment = 32; + + uint32_t num_tiles_in_row = output_x / TILE_WIDTH; + // Ensure we don't intrude into storage space + uint32_t max_l1_size = a.device()->l1_size_per_core() / 2 - L1_UNRESERVED_BASE; + // Memory usage is 2 CBs of width W, plus buffer of size alignment + (W * datum size) + uint32_t max_X = (max_l1_size - alignment) / (a.element_size() * TILE_HEIGHT * 2 + a.element_size()); + uint32_t max_tiles = max_X / TILE_WIDTH; + + // Currently need the number of tiles in a row to be divisible by tiles in a block + uint32_t num_tiles_per_block = 1; + if (num_tiles_in_row <= max_tiles) { + num_tiles_per_block = num_tiles_in_row; + } else { + for (uint32_t n_t = max_tiles; n_t > 0; n_t--) { + if (num_tiles_in_row % n_t == 0) { + num_tiles_per_block = n_t; + break; + } + } + } + + uint32_t block_width = num_tiles_per_block * TILE_WIDTH; + uint32_t block_row_size = block_width * a.element_size(); + uint32_t num_blocks_w_output = padded_row_size_bytes / block_row_size; + uint32_t num_blocks_w_input = unpadded_row_size_bytes / block_row_size; + + // Leftover size if input is not divisible by block size + uint32_t block_row_leftover_size = unpadded_row_size_bytes - num_blocks_w_input * block_row_size; + + // Number of blocks that differ between input and output + const uint32_t num_blocks_w_diff = num_blocks_w_output - num_blocks_w_input - (block_row_leftover_size > 0 ? 1 : 0); + + const uint32_t padded_Y_diff_blocks = (output_y - input_y) / TILE_HEIGHT * num_blocks_w_output; + const uint32_t padded_Z_diff_blocks = (output_z - input_z) * output_y / TILE_HEIGHT * num_blocks_w_output; + const uint32_t padded_W_diff_blocks = + (output_w - input_w) * output_z * output_y / TILE_HEIGHT * num_blocks_w_output; + const uint32_t num_leftover_Y = input_y - input_y / TILE_HEIGHT * TILE_HEIGHT; + + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = num_tiles_per_block; + assert(num_input_tiles > 0); + tt::tt_metal::CircularBufferConfig src0_cb_config = + tt::tt_metal::CircularBufferConfig( + num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, input_single_tile_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, src0_cb_config); + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = num_tiles_per_block; + tt::tt_metal::CircularBufferConfig cb_output_config = + tt::tt_metal::CircularBufferConfig( + num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); + auto cb_output = tt::tt_metal::CreateCircularBuffer(program, core, cb_output_config); + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); + + vector reader_kernel_args = { + src0_buffer->address(), + input_w, + padded_W_diff_blocks, + input_z, + padded_Z_diff_blocks, + input_y, + padded_Y_diff_blocks, + num_leftover_Y, + input_x, + unpadded_row_size_bytes, + padded_row_size_bytes, + packed_pad_value, + num_blocks_w_input, + num_blocks_w_output, + num_blocks_w_diff, + block_row_size, + block_row_leftover_size}; + + // Reader compile-time args + uint32_t src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + uint32_t stick_size = unpadded_row_size_bytes; + uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); + uint32_t log2_stick_size = stick_size_is_power_of_two ? (uint32_t)log2(stick_size) : 0; + std::vector reader_compile_time_args = {src0_is_dram, stick_size_is_power_of_two, log2_stick_size}; + + // Tilized reader + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/kernels/dataflow/reader_unary_pad_dims_split_rows.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + // Tilized writer + uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + core, + tt::tt_metal::WriterDataMovementConfig({output_cb_index, out_is_dram})); + + vector compute_kernel_args = {uint32_t(num_tiles / num_tiles_per_block), uint32_t(num_tiles_per_block)}; + + auto tilize_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + core, + tt::tt_metal::ComputeConfig{.compile_args = compute_kernel_args}); + + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); + + tt::tt_metal::SetRuntimeArgs( + program, unary_writer_kernel_id, core, {dst_buffer->address(), (uint32_t)num_tiles, 0}); + + auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, + writer_kernel_id = unary_writer_kernel_id]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + + auto dst_buffer = output_buffers.at(0); + + CoreCoord core = {0, 0}; + + { + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + } + + { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved( + const Tensor& a, Tensor& output, const float pad_value) { + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + Device* device = a.device(); + CoreCoord grid_size = device->compute_with_storage_grid_size(); + + uint32_t num_blocks = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; + uint32_t num_tiles_per_row = output.get_legacy_shape()[-1] / TILE_WIDTH; + + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + split_blocks_for_tilize(grid_size, num_blocks); + + bool has_cliff = core_range_cliff.size() > 0; + + uint32_t unpadded_row_size_bytes = a.get_legacy_shape()[-1] * a.element_size(); // Assuming bfloat16 dataformat + uint32_t padded_row_size_bytes = output.get_legacy_shape()[-1] * a.element_size(); // Assuming bfloat16 dataformat + + auto [src0_cb_index, cb_src0] = + create_cb(tt::CB::c_in0, program, all_cores, input_single_tile_size, num_tiles_per_row, input_cb_data_format); + + auto [output_cb_index, cb_output] = create_cb( + tt::CB::c_out0, program, all_cores, output_single_tile_size, num_tiles_per_row, output_cb_data_format); + + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + /** reader + */ + uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + uint32_t stick_size = unpadded_row_size_bytes; + uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); + uint32_t log2_stick_size = stick_size_is_power_of_two ? (std::uint32_t)std::log2(stick_size) : 0; + + KernelHandle unary_reader_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/kernels/dataflow/reader_unary_pad_dims_split_rows_multicore.cpp", + all_cores, + ReaderDataMovementConfig({src0_is_dram, stick_size_is_power_of_two, log2_stick_size})); + + /** writer + */ + uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + + KernelHandle unary_writer_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + all_cores, + WriterDataMovementConfig({output_cb_index, out_is_dram})); + + /** compute + */ + if (core_range.size() > 0) { + auto tilize_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + core_range, + ComputeConfig{.compile_args = {nblocks_per_core, num_tiles_per_row}}); + } + if (has_cliff) { + auto tilize_cliff_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + core_range_cliff, + ComputeConfig{.compile_args = {nblocks_per_core_cliff, num_tiles_per_row}}); + } + + /* RUNTIME ARGS */ + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); + + // 1D distribution of blocks across cores + auto core_assignments = distribute_work( + output.get_legacy_shape().without_padding(), + output.get_legacy_shape().padding(), + ncores, + nblocks_per_core, + has_cliff, + nblocks_per_core_cliff); + + uint32_t tile_start_id = 0; + uint32_t row_start_id = 0; + uint32_t ncores_x = grid_size.x; + + const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); + for (uint32_t i = 0; i < ncores; ++i) { + const auto& core = cores[i]; + const std::vector& assignment = core_assignments.at(i); + + // reader runtime args + vector reader_rt_args = { + src0_buffer->address(), + unpadded_row_size_bytes, + padded_row_size_bytes, + packed_pad_value, + row_start_id, + static_cast(assignment.size()), + }; + + uint32_t nblocks_per_core = 0; + + for (const auto& el : assignment) { + nblocks_per_core += el.block_count(); + row_start_id += el.data_row_count(); + reader_rt_args.push_back(el.n_data); + reader_rt_args.push_back(el.n_mixed); + reader_rt_args.push_back(el.n_pads); + reader_rt_args.push_back(el.times); + } + + uint32_t num_tiles_per_core = num_tiles_per_row * nblocks_per_core; + + // writer runtime args + vector writer_rt_args = {dst_buffer->address(), num_tiles_per_core, tile_start_id}; + + SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); + SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); + + tile_start_id += num_tiles_per_core; + } + + auto override_runtime_args_callback = + [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = dst_buffer->address(); + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +// This purely supports input width shard -> output width shard for now +operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_sharded( + const Tensor& a, Tensor& output, const float pad_value) { + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + + bool src_sharded = a.memory_config().is_sharded(); + bool out_sharded = output.memory_config().is_sharded(); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + Device* device = a.device(); + + auto input_shard_spec = a.shard_spec().value(); + auto output_shard_spec = output.shard_spec().value(); + + auto all_cores = output_shard_spec.grid; + + uint32_t num_batches = output.volume() / (output.get_legacy_shape()[-2] * output.get_legacy_shape()[-1]); + + uint32_t num_input_rows = input_shard_spec.shape[0]; + uint32_t input_shard_width_bytes = input_shard_spec.shape[1] * a.element_size(); + uint32_t ntiles_per_core = output_shard_spec.shape[0] * output_shard_spec.shape[1] / TILE_HW; + uint32_t ntiles_per_batch = ntiles_per_core / num_batches; + uint32_t ntiles_per_block = output_shard_spec.shape[1] / TILE_WIDTH; + uint32_t nblocks_per_core = output_shard_spec.shape[0] / TILE_HEIGHT; + uint32_t num_padded_rows = output.get_legacy_shape()[-2] - a.get_legacy_shape()[-2]; + + auto [src0_cb_index, cb_src0] = create_cb( + tt::CB::c_in1, + program, + all_cores, + input_shard_width_bytes, + num_input_rows, + input_cb_data_format, + src_sharded ? a.buffer() : nullptr); + + auto [src1_cb_index, cb_src1] = create_cb( + tt::CB::c_in0, program, all_cores, input_single_tile_size, ntiles_per_batch * 2, input_cb_data_format); + + auto [src2_cb_index, cb_src2] = + create_cb(tt::CB::c_in2, program, all_cores, input_shard_width_bytes, 1, input_cb_data_format); + + auto [output_cb_index, cb_output] = create_cb( + tt::CB::c_out0, + program, + all_cores, + output_single_tile_size, + ntiles_per_core, + output_cb_data_format, + out_sharded ? output.buffer() : nullptr); + + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + /** reader + */ + KernelHandle unary_reader_kernel_id; + std::vector reader_ct_args = { + (std::uint32_t)src0_cb_index, + (std::uint32_t)src1_cb_index, + (std::uint32_t)src2_cb_index, + }; + + unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/kernels/dataflow/reader_unary_pad_height_width_sharded.cpp", + all_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); + + /** writer + */ + KernelHandle unary_writer_kernel_id; + bool out_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + vector writer_ct_args = { + output_cb_index, + }; + unary_writer_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded.cpp", + all_cores, + WriterDataMovementConfig(writer_ct_args)); + + /** compute + */ + vector compute_args = { + (uint32_t)nblocks_per_core, // per_core_block_cnt + (uint32_t)ntiles_per_block, // per_block_ntiles + }; + + auto tilize_kernel_id = CreateKernel( + program, "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", all_cores, ComputeConfig{.compile_args = compute_args}); + + bfloat16 bfloat_pad_value = bfloat16(pad_value); + uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); + + vector reader_rt_args = { + num_input_rows, + input_shard_width_bytes, + (num_input_rows / num_batches) * input_shard_width_bytes, + ntiles_per_batch, + num_padded_rows, + num_batches, + packed_pad_value}; + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, all_cores, reader_rt_args); + + vector writer_rt_args = {ntiles_per_core}; + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, writer_rt_args); + + auto override_runtime_arguments_callback = [reader_kernel_id = unary_reader_kernel_id, + writer_kernel_id = unary_writer_kernel_id, + cb_src0 = cb_src0, + cb_output = cb_output]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) { + auto src_buffer = input_tensors.at(0).buffer(); + auto dst_buffer = output_tensors.at(0).buffer(); + + UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); + UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +operation::ProgramWithCallbacks tilize_with_val_padding_multi_core( + const Tensor& a, Tensor& output, const float pad_value) { + if (a.memory_config().is_sharded()) { + return tilize_with_val_padding_multi_core_sharded(a, output, pad_value); + } else { + return tilize_with_val_padding_multi_core_interleaved(a, output, pad_value); + } +} + +} // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.hpp index 7db439aa9df..be4cb3beccf 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_program_factory.hpp @@ -4,14 +4,6 @@ #pragma once -#include - -#include "ttnn/deprecated/tt_dnn/op_library/cb_utils.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" -#include "ttnn/operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/work_split_tilize.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" #include "tt_metal/host_api.hpp" using namespace tt::constants; @@ -19,474 +11,11 @@ using namespace tt::constants; namespace ttnn::operations::data_movement::detail { operation::ProgramWithCallbacks tilize_with_val_padding_single_core( - const Tensor& a, Tensor& output, const float pad_value) { - auto output_shape = output.get_legacy_shape(); - - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - - CoreRange core({0, 0}, {0, 0}); - - // This should allocate a DRAM buffer on the device - tt::tt_metal::Device* device = a.device(); - - tt::tt_metal::Buffer* src0_buffer = a.buffer(); - - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - - int32_t num_tiles = output.volume() / TILE_HW; - - auto true_input_shape = a.get_legacy_shape(); - auto true_output_shape = output.get_legacy_shape(); - - auto input_w = true_input_shape.rank() >= 4 ? true_input_shape[-4] : 1; - auto input_z = true_input_shape.rank() >= 3 ? true_input_shape[-3] : 1; - auto input_y = true_input_shape.rank() >= 2 ? true_input_shape[-2] : 1; - auto input_x = true_input_shape[-1]; - - auto output_w = true_output_shape.rank() >= 4 ? true_output_shape[-4] : 1; - auto output_z = true_output_shape.rank() >= 3 ? true_output_shape[-3] : 1; - auto output_y = true_output_shape.rank() >= 2 ? true_output_shape[-2] : 1; - auto output_x = true_output_shape[-1]; - - uint32_t unpadded_row_size_bytes = input_x * a.element_size(); // Assuming bfloat16 dataformat - uint32_t padded_row_size_bytes = output_x * a.element_size(); // Assuming bfloat16 dataformat - - constexpr uint32_t alignment = 32; - - uint32_t num_tiles_in_row = output_x / TILE_WIDTH; - // Ensure we don't intrude into storage space - uint32_t max_l1_size = a.device()->l1_size_per_core() / 2 - L1_UNRESERVED_BASE; - // Memory usage is 2 CBs of width W, plus buffer of size alignment + (W * datum size) - uint32_t max_X = (max_l1_size - alignment) / (a.element_size() * TILE_HEIGHT * 2 + a.element_size()); - uint32_t max_tiles = max_X / TILE_WIDTH; - - // Currently need the number of tiles in a row to be divisible by tiles in a block - uint32_t num_tiles_per_block = 1; - if (num_tiles_in_row <= max_tiles) { - num_tiles_per_block = num_tiles_in_row; - } else { - for (uint32_t n_t = max_tiles; n_t > 0; n_t--) { - if (num_tiles_in_row % n_t == 0) { - num_tiles_per_block = n_t; - break; - } - } - } - - uint32_t block_width = num_tiles_per_block * TILE_WIDTH; - uint32_t block_row_size = block_width * a.element_size(); - uint32_t num_blocks_w_output = padded_row_size_bytes / block_row_size; - uint32_t num_blocks_w_input = unpadded_row_size_bytes / block_row_size; - - // Leftover size if input is not divisible by block size - uint32_t block_row_leftover_size = unpadded_row_size_bytes - num_blocks_w_input * block_row_size; - - // Number of blocks that differ between input and output - const uint32_t num_blocks_w_diff = num_blocks_w_output - num_blocks_w_input - (block_row_leftover_size > 0 ? 1 : 0); - - const uint32_t padded_Y_diff_blocks = (output_y - input_y) / TILE_HEIGHT * num_blocks_w_output; - const uint32_t padded_Z_diff_blocks = (output_z - input_z) * output_y / TILE_HEIGHT * num_blocks_w_output; - const uint32_t padded_W_diff_blocks = - (output_w - input_w) * output_z * output_y / TILE_HEIGHT * num_blocks_w_output; - const uint32_t num_leftover_Y = input_y - input_y / TILE_HEIGHT * TILE_HEIGHT; - - tt::tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - uint32_t src0_cb_index = 0; - uint32_t num_input_tiles = num_tiles_per_block; - assert(num_input_tiles > 0); - tt::tt_metal::CircularBufferConfig src0_cb_config = - tt::tt_metal::CircularBufferConfig( - num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) - .set_page_size(src0_cb_index, input_single_tile_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, src0_cb_config); - - uint32_t output_cb_index = 16; // output operands start at index 16 - uint32_t num_output_tiles = num_tiles_per_block; - tt::tt_metal::CircularBufferConfig cb_output_config = - tt::tt_metal::CircularBufferConfig( - num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) - .set_page_size(output_cb_index, output_single_tile_size); - auto cb_output = tt::tt_metal::CreateCircularBuffer(program, core, cb_output_config); - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); - - vector reader_kernel_args = { - src0_buffer->address(), - input_w, - padded_W_diff_blocks, - input_z, - padded_Z_diff_blocks, - input_y, - padded_Y_diff_blocks, - num_leftover_Y, - input_x, - unpadded_row_size_bytes, - padded_row_size_bytes, - packed_pad_value, - num_blocks_w_input, - num_blocks_w_output, - num_blocks_w_diff, - block_row_size, - block_row_leftover_size}; - - // Reader compile-time args - uint32_t src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - uint32_t stick_size = unpadded_row_size_bytes; - uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); - uint32_t log2_stick_size = stick_size_is_power_of_two ? (uint32_t)log2(stick_size) : 0; - std::vector reader_compile_time_args = {src0_is_dram, stick_size_is_power_of_two, log2_stick_size}; - - // Tilized reader - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/kernels/dataflow/reader_unary_pad_dims_split_rows.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - // Tilized writer - uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - core, - tt::tt_metal::WriterDataMovementConfig({output_cb_index, out_is_dram})); - - vector compute_kernel_args = {uint32_t(num_tiles / num_tiles_per_block), uint32_t(num_tiles_per_block)}; - - auto tilize_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - core, - tt::tt_metal::ComputeConfig{.compile_args = compute_kernel_args}); - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); - - tt::tt_metal::SetRuntimeArgs( - program, unary_writer_kernel_id, core, {dst_buffer->address(), (uint32_t)num_tiles, 0}); - - auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, - writer_kernel_id = unary_writer_kernel_id]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - - auto dst_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - } - - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved( - const Tensor& a, Tensor& output, const float pad_value) { - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - - Device* device = a.device(); - CoreCoord grid_size = device->compute_with_storage_grid_size(); - - uint32_t num_blocks = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t num_tiles_per_row = output.get_legacy_shape()[-1] / TILE_WIDTH; - - auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = - split_blocks_for_tilize(grid_size, num_blocks); - - bool has_cliff = core_range_cliff.size() > 0; - - uint32_t unpadded_row_size_bytes = a.get_legacy_shape()[-1] * a.element_size(); // Assuming bfloat16 dataformat - uint32_t padded_row_size_bytes = output.get_legacy_shape()[-1] * a.element_size(); // Assuming bfloat16 dataformat - - auto [src0_cb_index, cb_src0] = - create_cb(tt::CB::c_in0, program, all_cores, input_single_tile_size, num_tiles_per_row, input_cb_data_format); - - auto [output_cb_index, cb_output] = create_cb( - tt::CB::c_out0, program, all_cores, output_single_tile_size, num_tiles_per_row, output_cb_data_format); - - Buffer* src0_buffer = a.buffer(); - Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - /** reader - */ - uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - uint32_t stick_size = unpadded_row_size_bytes; - uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); - uint32_t log2_stick_size = stick_size_is_power_of_two ? (std::uint32_t)std::log2(stick_size) : 0; - - KernelHandle unary_reader_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/kernels/dataflow/reader_unary_pad_dims_split_rows_multicore.cpp", - all_cores, - ReaderDataMovementConfig({src0_is_dram, stick_size_is_power_of_two, log2_stick_size})); - - /** writer - */ - uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - - KernelHandle unary_writer_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - all_cores, - WriterDataMovementConfig({output_cb_index, out_is_dram})); - - /** compute - */ - if (core_range.size() > 0) { - auto tilize_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - core_range, - ComputeConfig{.compile_args = {nblocks_per_core, num_tiles_per_row}}); - } - if (has_cliff) { - auto tilize_cliff_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", - core_range_cliff, - ComputeConfig{.compile_args = {nblocks_per_core_cliff, num_tiles_per_row}}); - } - - /* RUNTIME ARGS */ - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); - - // 1D distribution of blocks across cores - auto core_assignments = distribute_work( - output.get_legacy_shape().without_padding(), - output.get_legacy_shape().padding(), - ncores, - nblocks_per_core, - has_cliff, - nblocks_per_core_cliff); - - uint32_t tile_start_id = 0; - uint32_t row_start_id = 0; - uint32_t ncores_x = grid_size.x; - - const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); - for (uint32_t i = 0; i < ncores; ++i) { - const auto& core = cores[i]; - const std::vector& assignment = core_assignments.at(i); - - // reader runtime args - vector reader_rt_args = { - src0_buffer->address(), - unpadded_row_size_bytes, - padded_row_size_bytes, - packed_pad_value, - row_start_id, - static_cast(assignment.size()), - }; - - uint32_t nblocks_per_core = 0; - - for (const auto& el : assignment) { - nblocks_per_core += el.block_count(); - row_start_id += el.data_row_count(); - reader_rt_args.push_back(el.n_data); - reader_rt_args.push_back(el.n_mixed); - reader_rt_args.push_back(el.n_pads); - reader_rt_args.push_back(el.times); - } - - uint32_t num_tiles_per_core = num_tiles_per_row * nblocks_per_core; - - // writer runtime args - vector writer_rt_args = {dst_buffer->address(), num_tiles_per_core, tile_start_id}; - - SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); - SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); - - tile_start_id += num_tiles_per_core; - } - - auto override_runtime_args_callback = - [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); - auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); - for (const auto& core : cores) { - { - auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; - runtime_args[0] = src_buffer->address(); - } - { - auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; - runtime_args[0] = dst_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -// This purely supports input width shard -> output width shard for now -operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_sharded( - const Tensor& a, Tensor& output, const float pad_value) { - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - - bool src_sharded = a.memory_config().is_sharded(); - bool out_sharded = output.memory_config().is_sharded(); - - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format); - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); - - Device* device = a.device(); - - auto input_shard_spec = a.shard_spec().value(); - auto output_shard_spec = output.shard_spec().value(); - - auto all_cores = output_shard_spec.grid; - - uint32_t num_batches = output.volume() / (output.get_legacy_shape()[-2] * output.get_legacy_shape()[-1]); - - uint32_t num_input_rows = input_shard_spec.shape[0]; - uint32_t input_shard_width_bytes = input_shard_spec.shape[1] * a.element_size(); - uint32_t ntiles_per_core = output_shard_spec.shape[0] * output_shard_spec.shape[1] / TILE_HW; - uint32_t ntiles_per_batch = ntiles_per_core / num_batches; - uint32_t ntiles_per_block = output_shard_spec.shape[1] / TILE_WIDTH; - uint32_t nblocks_per_core = output_shard_spec.shape[0] / TILE_HEIGHT; - uint32_t num_padded_rows = output.get_legacy_shape()[-2] - a.get_legacy_shape()[-2]; - - auto [src0_cb_index, cb_src0] = create_cb( - tt::CB::c_in1, - program, - all_cores, - input_shard_width_bytes, - num_input_rows, - input_cb_data_format, - src_sharded ? a.buffer() : nullptr); - - auto [src1_cb_index, cb_src1] = create_cb( - tt::CB::c_in0, program, all_cores, input_single_tile_size, ntiles_per_batch * 2, input_cb_data_format); - - auto [src2_cb_index, cb_src2] = - create_cb(tt::CB::c_in2, program, all_cores, input_shard_width_bytes, 1, input_cb_data_format); - - auto [output_cb_index, cb_output] = create_cb( - tt::CB::c_out0, - program, - all_cores, - output_single_tile_size, - ntiles_per_core, - output_cb_data_format, - out_sharded ? output.buffer() : nullptr); - - Buffer* src0_buffer = a.buffer(); - Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - /** reader - */ - KernelHandle unary_reader_kernel_id; - std::vector reader_ct_args = { - (std::uint32_t)src0_cb_index, - (std::uint32_t)src1_cb_index, - (std::uint32_t)src2_cb_index, - }; - - unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/kernels/dataflow/reader_unary_pad_height_width_sharded.cpp", - all_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); - - /** writer - */ - KernelHandle unary_writer_kernel_id; - bool out_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - vector writer_ct_args = { - output_cb_index, - }; - unary_writer_kernel_id = CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded.cpp", - all_cores, - WriterDataMovementConfig(writer_ct_args)); - - /** compute - */ - vector compute_args = { - (uint32_t)nblocks_per_core, // per_core_block_cnt - (uint32_t)ntiles_per_block, // per_block_ntiles - }; - - auto tilize_kernel_id = CreateKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", all_cores, ComputeConfig{.compile_args = compute_args}); - - bfloat16 bfloat_pad_value = bfloat16(pad_value); - uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); - - vector reader_rt_args = { - num_input_rows, - input_shard_width_bytes, - (num_input_rows / num_batches) * input_shard_width_bytes, - ntiles_per_batch, - num_padded_rows, - num_batches, - packed_pad_value}; - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, all_cores, reader_rt_args); - - vector writer_rt_args = {ntiles_per_core}; - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, writer_rt_args); - - auto override_runtime_arguments_callback = [reader_kernel_id = unary_reader_kernel_id, - writer_kernel_id = unary_writer_kernel_id, - cb_src0 = cb_src0, - cb_output = cb_output]( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors) { - auto src_buffer = input_tensors.at(0).buffer(); - auto dst_buffer = output_tensors.at(0).buffer(); + const Tensor& a, Tensor& output, const float pad_value); - UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); - UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); - }; - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; -} operation::ProgramWithCallbacks tilize_with_val_padding_multi_core( - const Tensor& a, Tensor& output, const float pad_value) { - if (a.memory_config().is_sharded()) { - return tilize_with_val_padding_multi_core_sharded(a, output, pad_value); - } else { - return tilize_with_val_padding_multi_core_interleaved(a, output, pad_value); - } -} + const Tensor& a, Tensor& output, const float pad_value); } // namespace ttnn::operations::data_movement::detail