diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt
index 1ee19c5be45..3f397cf3f55 100644
--- a/ttnn/CMakeLists.txt
+++ b/ttnn/CMakeLists.txt
@@ -30,9 +30,6 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/conv_op_program_factory.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv/optimized_conv_op.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_width_sharded_v2.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
index c28fece4d00..fda2dba4297 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -16,7 +16,6 @@
 #include "tt_metal/common/math.hpp"
 #include "ttnn/operations/data_movement/pad/pad.hpp"
 #include "ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp"
-#include "ttnn/operations/conv/conv2d/device/conv_op.hpp"
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/operations/sliding_window/sliding_window.hpp"
 #include "ttnn/operations/sliding_window/halo/halo.hpp"
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv_op.hpp
deleted file mode 100644
index 2eb8557687c..00000000000
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv_op.hpp
+++ /dev/null
@@ -1,103 +0,0 @@
-// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/run_operation.hpp" - -namespace ttnn::operations::conv { -namespace conv2d { - -// TODO: Accept parallelization -enum class ConvOpParallelizationStrategy { - MULTI_CORE, MULTI_CORE_REUSE, MULTI_CORE_REUSE_MCAST, SINGLE_CORE -}; - -struct Conv { - // additional parameters - const std::vector conv_params; - const uint32_t act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, output_channels; - bool use_address_map, use_fast_reader, untilize_out, has_bias, fuse_relu; - MathFidelity math_fidelity; - Conv(uint32_t act_bh, uint32_t act_bw, uint32_t weight_bw, uint32_t out_sh, uint32_t out_sw, const std::vector&c_params, uint32_t output_channels, bool address_map, bool fast_reader, bool untile_out, bool has_bias, bool fuse_relu, MathFidelity mfidelity) - : act_block_h_ntiles(act_bh), - act_block_w_ntiles(act_bw), - weight_block_w_ntiles(weight_bw), - out_subblock_h_ntiles(out_sh), - out_subblock_w_ntiles(out_sw), - output_channels(output_channels), - conv_params(c_params), - use_address_map(address_map), - use_fast_reader(fast_reader), - untilize_out(untile_out), - has_bias(has_bias), - fuse_relu(fuse_relu), - math_fidelity(mfidelity) {} - - void validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector &output_tensors) const; - - static constexpr auto attribute_names = std::make_tuple( - "conv_params", - "act_block_h_ntiles", - "act_block_w_ntiles", - "weight_block_w_ntiles", - "out_subblock_h_ntiles", - "out_subblock_w_ntiles", - "output_channels", - "use_address_map", - "use_fast_reader", - "untilize_out", - "has_bias", - "fuse_relu", - "math_fidelity"); - const auto attribute_values() const { - return std::make_tuple( - std::cref(this->conv_params), - std::cref(this->act_block_h_ntiles), - std::cref(this->act_block_w_ntiles), - std::cref(this->weight_block_w_ntiles), - std::cref(this->out_subblock_h_ntiles), - std::cref(this->out_subblock_w_ntiles), - std::cref(this->output_channels), - std::cref(this->use_address_map), - std::cref(this->use_fast_reader), - std::cref(this->untilize_out), - std::cref(this->has_bias), - std::cref(this->fuse_relu), - std::cref(this->math_fidelity)); - } -}; - -Tensor conv(const Tensor& a, const Tensor &b, std::optional bias, const vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool has_bias); - -Tensor conv_with_fast_reader(const Tensor& a, const Tensor &b, std::optional bias, const vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, MathFidelity math_fidelity = MathFidelity::HiFi4); - -operation::ProgramWithCallbacks conv_single_core(const Tensor& A, const Tensor& B, std::optional bias, vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t 
out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool has_bias, MathFidelity math_fidelity, Tensor& output); // Tilizes a, untilizes b - -Tensor conv_with_address_map(const Tensor& a, const Tensor &b, std::optional bias, const vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels); -operation::ProgramWithCallbacks conv_with_address_map_single_core(const Tensor& A, const Tensor& B, vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, Tensor& output); // Tilizes a, untilizes b - - -} // namespace tt_metal - -} // namespace tt - -// TODO: Merge with optimized_conv_op_utils? -namespace conv_op_utils { -using namespace tt; -using namespace tt::tt_metal; - -pair compute_conv_output_face_shape(uint32_t conv_activation_h, uint32_t conv_activation_w, uint32_t filter_h, uint32_t filter_w, uint32_t stride_h, uint32_t stride_w, uint32_t pad_h, uint32_t pad_w); -pair, vector> compute_conv_activation_as_mm_shape(Shape conv_activation_shape, vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, bool use_fast_reader); - -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv_op_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv_op_program_factory.cpp deleted file mode 100644 index 320f717c6cf..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv_op_program_factory.cpp +++ /dev/null @@ -1,1475 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "conv_op.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/detail/tt_metal.hpp" -#include "tt_metal/common/constants.hpp" - -#include "ttnn/operations/experimental/auto_format/auto_format.hpp" -#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" - -using namespace tt::constants; - -namespace conv_op_utils { -using namespace tt; -using namespace tt::tt_metal; - -pair compute_conv_output_face_shape(uint32_t conv_activation_h, uint32_t conv_activation_w, uint32_t filter_h, uint32_t filter_w, uint32_t stride_h, uint32_t stride_w, uint32_t pad_h, uint32_t pad_w) { - uint32_t conv_output_h = ((conv_activation_h - filter_h + (2 * pad_h)) / stride_h) + 1; - uint32_t conv_output_w = ((conv_activation_w - filter_w + (2 * pad_w)) / stride_w) + 1; - return {conv_output_h, conv_output_w}; -} -pair, vector> compute_conv_activation_as_mm_shape(Shape conv_activation_shape, vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, bool use_fast_reader) { - uint32_t filter_h = (uint32_t) conv_params[0]; - uint32_t filter_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - auto [conv_output_h, conv_output_w] = compute_conv_output_face_shape(conv_activation_shape[1], conv_activation_shape[2], filter_h, filter_w, stride_h, stride_w, pad_h, pad_w); - // pad height - uint32_t num_rows = (uint32_t) conv_output_h*conv_output_w; - uint32_t act_block_h_datums = act_block_h_ntiles * TILE_HEIGHT; - uint32_t num_rows_padded = (uint32_t) (std::ceil((double) num_rows / (double) act_block_h_datums ) * act_block_h_datums); - uint32_t 
num_cols = conv_activation_shape[3] * filter_h * filter_w; - uint32_t act_block_w_datums = act_block_w_ntiles * TILE_WIDTH; - uint32_t num_cols_padded = (uint32_t) (std::ceil((double) num_cols / (double) act_block_w_datums ) * act_block_w_datums); - if(use_fast_reader) { - assert(act_block_w_datums >= conv_activation_shape[3] * filter_w); - num_cols_padded = act_block_w_datums * filter_h; - } - return {{1, num_rows_padded, num_cols_padded}, {1, num_rows, num_cols}}; -} - -} - -namespace ttnn::operations::conv { -namespace conv2d { - -using namespace tt; - -const uint32_t act_cb = CB::c_in0; -const uint32_t weight_cb = CB::c_in1; -const uint32_t bias_cb = CB::c_in2; -const uint32_t matmul_partials_cb = CB::c_intermed0; -const uint32_t tilize_mode_tilized_act_cb = CB::c_intermed1; -const uint32_t untilize_mode_final_matmul_partials_cb = CB::c_intermed2; -const uint32_t untilize_mode_reblock_cb = CB::c_intermed3; -const uint32_t out0_cb = CB::c_out0; - - -void create_CBs_for_fused_matmul_new_alloc(tt_metal::Program &program, - tt_metal::Device* device, - CoreRange core, - uint32_t num_cb0_tiles, - uint32_t num_cb1_tiles, - uint32_t num_cb0_tilized_tiles, - uint32_t num_output_tiles, - uint32_t num_reblock_cb_tiles, - uint32_t num_writer_output_tiles, - uint32_t num_bytes_for_df, - bool untilize_out, - uint32_t bias_ntiles = 0, - bool with_bias = false) { - - uint32_t single_tile_size = num_bytes_for_df * 1024; - - // Invariants - CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * single_tile_size, {{act_cb, tt::DataFormat::Float16_b}}) - .set_page_size(act_cb, single_tile_size); - auto cb_act = tt_metal::CreateCircularBuffer(program, core, cb_act_config); - - CircularBufferConfig cb_weight_config = CircularBufferConfig(num_cb1_tiles * single_tile_size, {{weight_cb, tt::DataFormat::Float16_b}}) - .set_page_size(weight_cb, single_tile_size); - auto cb_weight = tt_metal::CreateCircularBuffer(program, core, cb_weight_config); - - // Used for placing tilized activations - CircularBufferConfig cb_src0_tilized_config = CircularBufferConfig(num_cb0_tilized_tiles * single_tile_size, {{tilize_mode_tilized_act_cb, tt::DataFormat::Float16_b}}) - .set_page_size(tilize_mode_tilized_act_cb, single_tile_size); - auto cb_src0_tilized = CreateCircularBuffer(program, core, cb_src0_tilized_config); - - if (untilize_out) { - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * single_tile_size, {{matmul_partials_cb, tt::DataFormat::Float16_b}}) - .set_page_size(matmul_partials_cb, single_tile_size); - auto cb_matmul_partials = CreateCircularBuffer(program, core, cb_matmul_partials_config); - - // Shares same address space as matmul partials - CircularBufferConfig cb_final_matmul_partials_config = CircularBufferConfig(num_output_tiles * single_tile_size, {{untilize_mode_final_matmul_partials_cb, tt::DataFormat::Float16_b}}) - .set_page_size(untilize_mode_final_matmul_partials_cb, single_tile_size); - auto cb_final_matmul_partials = CreateCircularBuffer(program, core, cb_final_matmul_partials_config); - - // Supposed to be a small CB only responsible for reorganizing - // the output blocks to fill the whole "per core output block width" - CircularBufferConfig cb_reblock_config = CircularBufferConfig(num_reblock_cb_tiles * single_tile_size, {{untilize_mode_reblock_cb, tt::DataFormat::Float16_b}}) - .set_page_size(untilize_mode_reblock_cb, single_tile_size); - auto cb_reblock = CreateCircularBuffer(program, core, cb_reblock_config); - - CircularBufferConfig 
cb_output_config = CircularBufferConfig(num_writer_output_tiles * single_tile_size, {{out0_cb, tt::DataFormat::Float16_b}}) - .set_page_size(out0_cb, single_tile_size); - auto cb_output = CreateCircularBuffer(program, core, cb_output_config); - } else { - CoreRangeSet cores(std::set({core})); - std::map cb_output_data_format_spec = { - {out0_cb, tt::DataFormat::Float16_b}, - {matmul_partials_cb, tt::DataFormat::Float16_b} - }; - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * single_tile_size, cb_output_data_format_spec) - .set_page_size(out0_cb, single_tile_size) - .set_page_size(matmul_partials_cb, single_tile_size); - auto cb_output = CreateCircularBuffer(program, core, cb_matmul_partials_config); - } - - if (with_bias) { - // bias input - uint32_t bias_pagesize = single_tile_size; - CircularBufferConfig cb_bias_config = CircularBufferConfig(bias_ntiles * bias_pagesize, {{bias_cb, tt::DataFormat::Float16_b}}) - .set_page_size(bias_cb, bias_pagesize); - auto cb_bias = CreateCircularBuffer(program, core, cb_bias_config); - - log_debug("BIAS CBs: {} {} {}", bias_cb, bias_ntiles, bias_pagesize); - } -} - -operation::ProgramWithCallbacks conv_as_large_bmm_single_core_(const Tensor& a, const Tensor &b, std::optional bias, vector conv_params, - uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool use_fast_reader, bool untilize_out, bool has_bias, bool fuse_relu, const MathFidelity math_fidelity, Tensor &output) { - bool pass = true; - tt_metal::Device *device = a.device(); - TT_ASSERT(a.get_layout() == Layout::ROW_MAJOR, "Conv activation should be in row major layout"); - uint32_t act_batch_size = a.get_legacy_shape()[0]; - TT_ASSERT(act_batch_size == 1, "Only batch size 1 supported."); - TT_ASSERT(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. 
Incorrect weight tensor."); - uint32_t num_bytes_of_df = 2; // 2 bytes for bfloat16 - // Compute the 2d matrix shape - auto [act_matrix_shape, act_matrix_shape_unpadded] = conv_op_utils::compute_conv_activation_as_mm_shape(a.get_legacy_shape(), conv_params, act_block_h_ntiles, act_block_w_ntiles, use_fast_reader); - assert(act_matrix_shape.size() == 3); - assert(act_matrix_shape[0] == 1); - uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1]; - uint32_t act_matrix_width = (uint32_t) act_matrix_shape[2]; - - // Tensor b has weights and it should be tiled layout after converting conv weights into weight matrix - TT_ASSERT(b.get_layout() == Layout::TILE, "Conv weights should be in tiled layout"); - TT_ASSERT(b.get_legacy_shape()[0] == 1, "Conv weight matrix shape is invalid"); - TT_ASSERT(b.get_legacy_shape()[1] == 1, "Conv weight matrix shape is invalid"); - uint32_t weight_matrix_height = b.get_legacy_shape()[2]; - uint32_t weight_matrix_width = b.get_legacy_shape()[3]; - - if (has_bias) { - // Tensor bias is of shape {output_channels} - TT_ASSERT(bias.has_value()); - TT_ASSERT(bias.value().buffer() != nullptr); - auto bias_shape_without_padding = bias.value().get_legacy_shape().without_padding(); - TT_ASSERT(bias_shape_without_padding[0] == 1, "Bias should have batch == 1"); - TT_ASSERT(bias_shape_without_padding[1] == 1 && bias_shape_without_padding[2] == 1, "Bias should have H == W == 1"); - TT_ASSERT(bias_shape_without_padding[3] == output_channels, "Bias should have output_channels"); - } - - // Normal matrix shape check - TT_ASSERT(act_matrix_width == weight_matrix_height, "The width of tensor a needs to match the height of tensor b"); - - // Tile size divisibility checks - TT_ASSERT(act_matrix_height % TILE_HEIGHT == 0, "Height of activation matrix needs to be divisible by 32"); - TT_ASSERT(act_matrix_width % TILE_WIDTH == 0, "Width of activation matrix needs to be divisible by 32"); - TT_ASSERT(weight_matrix_height % TILE_HEIGHT == 0, "Height of weight matrix needs to be divisible by 32"); - TT_ASSERT(weight_matrix_width % TILE_WIDTH == 0, "Width of weight matrix needs to be divisible by 32"); - - // Device compatibility checks - TT_ASSERT(a.storage_type() == StorageType::DEVICE && - b.storage_type() == StorageType::DEVICE && - "Operands to large matmul need to be on device!"); - TT_ASSERT(a.device() == b.device(), "Operands to conv need to be on the same device!"); - TT_ASSERT(a.buffer() != nullptr && b.buffer() != nullptr, "Operands to conv need to be allocated in buffers on device!"); - if (has_bias) { - TT_ASSERT(bias.value().storage_type() == StorageType::DEVICE, "Bias should be on device"); - TT_ASSERT(bias.value().device() == a.device(), "Bias should be on the same device as act tensor"); - } - - // Convert tensor dims to tile dims - uint32_t act_matrix_height_ntiles = act_matrix_height / TILE_HEIGHT; - uint32_t act_matrix_width_ntiles = act_matrix_width / TILE_WIDTH; - uint32_t weight_matrix_height_ntiles = weight_matrix_height / TILE_HEIGHT; - uint32_t weight_matrix_width_ntiles = weight_matrix_width / TILE_WIDTH; - - assert(act_matrix_height_ntiles % act_block_h_ntiles == 0); - assert(act_matrix_width_ntiles % act_block_w_ntiles == 0); - assert(weight_matrix_width_ntiles % weight_block_w_ntiles == 0); - - uint32_t num_blocks_act_h = act_matrix_height_ntiles / act_block_h_ntiles; - uint32_t num_blocks_act_w = act_matrix_width_ntiles / act_block_w_ntiles; - uint32_t num_blocks_weight_w = weight_matrix_width_ntiles / weight_block_w_ntiles; - - // act block info - 
uint32_t act_block_w_datums = act_matrix_width / num_blocks_act_w; - uint32_t act_block_h_datums = act_matrix_height / num_blocks_act_h; - - // weight block info - uint32_t weight_block_w_datums = weight_matrix_width / num_blocks_weight_w; - assert(weight_block_w_ntiles % out_subblock_w_ntiles == 0); - uint32_t weight_num_subblocks = weight_block_w_ntiles / out_subblock_w_ntiles; - uint32_t weight_block_h_ntiles = act_block_w_ntiles; - uint32_t weight_block_num_tiles = weight_block_w_ntiles * weight_block_h_ntiles; - - uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w; - // writer of conv op partially removes padding on the width - // it removes the padding done for block width but it doesn't remove padding done for tiled width - uint32_t output_channels_padded_to_tile_width = round_up(output_channels, TILE_WIDTH); - assert(output_channels_padded_to_tile_width <= weight_matrix_width); - uint32_t output_width_num_tiles = output_channels_padded_to_tile_width / TILE_WIDTH; - uint32_t num_blocks_output_w = (uint32_t) std::ceil((double) output_channels_padded_to_tile_width / (double) weight_block_w_datums); - uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums == 0) ? weight_block_w_datums : (output_channels_padded_to_tile_width % weight_block_w_datums); - assert(last_block_width_datums % TILE_WIDTH == 0); - uint32_t output_row_size_bytes = output_channels_padded_to_tile_width * num_bytes_of_df; - uint32_t last_block_row_size_bytes = last_block_width_datums * num_bytes_of_df; - // sanity check - assert(num_blocks_output_w == num_blocks_weight_w); - - tt_metal::Program program = tt_metal::CreateProgram(); - CoreCoord core_coord = {0, 0}; // TODO: avoid another var here. Find a way to use core range instead. 
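For reference while reviewing the removal: the deleted conv_as_large_bmm_single_core_ path derives its blocking parameters from plain integer arithmetic on the padded matmul view of the convolution. The sketch below restates that arithmetic as standalone helpers so the divisibility assumptions are easier to see; it assumes 32x32 tiles (as in the constants the deleted code uses), and the names BlockCounts, compute_block_counts, and compute_last_block_width are illustrative, not part of the deleted sources.

// Illustrative restatement of the blocking math in the removed single-core conv.
// Assumes 32x32 tiles; helper names are hypothetical.
#include <cassert>
#include <cstdint>

constexpr uint32_t kTileDim = 32;

constexpr uint32_t round_up(uint32_t v, uint32_t multiple) {
    return ((v + multiple - 1) / multiple) * multiple;
}

// Number of activation/weight blocks along each dim, given matrix sizes already
// padded to tile multiples (the deleted code asserts exactly this).
struct BlockCounts { uint32_t act_h, act_w, weight_w; };

BlockCounts compute_block_counts(uint32_t act_matrix_height, uint32_t act_matrix_width,
                                 uint32_t weight_matrix_width,
                                 uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles,
                                 uint32_t weight_block_w_ntiles) {
    uint32_t act_h_ntiles = act_matrix_height / kTileDim;
    uint32_t act_w_ntiles = act_matrix_width / kTileDim;
    uint32_t weight_w_ntiles = weight_matrix_width / kTileDim;
    assert(act_h_ntiles % act_block_h_ntiles == 0);
    assert(act_w_ntiles % act_block_w_ntiles == 0);
    assert(weight_w_ntiles % weight_block_w_ntiles == 0);
    return {act_h_ntiles / act_block_h_ntiles,
            act_w_ntiles / act_block_w_ntiles,
            weight_w_ntiles / weight_block_w_ntiles};
}

// The writer strips block-width padding but keeps tile-width padding, so the
// last output block can be narrower than weight_block_w_datums.
// Example: output_channels = 100, weight_block_w_datums = 64
//          -> padded width 128, last block carries 64 datums.
uint32_t compute_last_block_width(uint32_t output_channels, uint32_t weight_block_w_datums) {
    uint32_t padded = round_up(output_channels, kTileDim);
    uint32_t rem = padded % weight_block_w_datums;
    return rem == 0 ? weight_block_w_datums : rem;
}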
- CoreRange core({0, 0}, {0, 0}); - - uint32_t single_tile_size = num_bytes_of_df * TILE_HEIGHT * TILE_WIDTH; - tt_metal::Buffer *src0_dram_buffer = a.buffer(); - tt_metal::Buffer *src1_dram_buffer = b.buffer(); - TT_ASSERT(src1_dram_buffer->size() % single_tile_size == 0, "Buffer size of tensor b must be divisible by single_tile_size (aka divisible by sizeof(df) * 1024)"); - - tt_metal::Buffer *dst_dram_buffer = output.buffer(); - TT_ASSERT(dst_dram_buffer != nullptr, "Output buffer should be allocated on device!"); - - // out - uint32_t out_dram_addr = dst_dram_buffer->address(); - uint32_t out_row_size = weight_matrix_width * num_bytes_of_df; - uint32_t out_subblock_num_tiles = out_subblock_h_ntiles * out_subblock_w_ntiles; - TT_ASSERT(out_subblock_num_tiles <= 8, "Need to ensure that matmul partials fit in dst"); - - // act - uint32_t act_dram_addr = src0_dram_buffer->address(); - auto act_dram_noc_xy = src0_dram_buffer->noc_coordinates(); - uint32_t act_noc_x = act_dram_noc_xy.x; - uint32_t act_noc_y = act_dram_noc_xy.y; - - assert(act_matrix_width_ntiles % act_block_w_ntiles == 0); - assert(act_block_h_ntiles % out_subblock_h_ntiles == 0); - uint32_t act_num_subblocks = act_block_h_ntiles / out_subblock_h_ntiles; - uint32_t act_block_num_tiles = act_block_h_ntiles * act_block_w_ntiles; - uint32_t act_subblock_h_ntiles = out_subblock_h_ntiles; - uint32_t act_subblock_num_tiles = act_subblock_h_ntiles * act_block_w_ntiles; - - // weight - uint32_t weight_dram_addr = src1_dram_buffer->address(); - auto weight_dram_noc_xy = src1_dram_buffer->noc_coordinates(); - uint32_t weight_noc_x = weight_dram_noc_xy.x; - uint32_t weight_noc_y = weight_dram_noc_xy.y; - - // bias - Buffer *bias_buffer = nullptr; - uint32_t bias_dram_addr = 0; - uint32_t bias_ntiles = 0, bias_tile_nbytes = 0, bias_log2_of_pagesize = 0; - if (has_bias) { - bias_buffer = bias.value().buffer(); - bias_dram_addr = bias_buffer->address(); - bias_ntiles = bias.value().get_legacy_shape()[3] / constants::TILE_WIDTH; // TODO: support non tile multiple sizes - bias_tile_nbytes = single_tile_size; - bias_log2_of_pagesize = (uint32_t) std::log2((float) bias_tile_nbytes); - } - - // more args for reader - uint32_t conv_act_size_h = a.get_legacy_shape()[1]; - uint32_t conv_act_size_w = a.get_legacy_shape()[2]; - uint32_t conv_act_size_c = a.get_legacy_shape()[3]; - uint32_t weight_size_h = (uint32_t) conv_params[0]; - uint32_t weight_size_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - uint32_t conv_output_size_h = ((conv_act_size_h - weight_size_h + (2 * pad_h)) / stride_h) + 1; - uint32_t conv_output_size_w = ((conv_act_size_w - weight_size_w + (2 * pad_w)) / stride_w) + 1; - std::map reader_defines; - if (use_fast_reader) { - if(conv_act_size_c * weight_size_w != act_block_w_datums) { - assert(act_block_w_datums > conv_act_size_c * weight_size_w); - uint32_t conv_act_block_width_padding_bytes = (act_block_w_datums - (conv_act_size_c * weight_size_w)) * num_bytes_of_df; - reader_defines["ACT_BLOCK_WIDTH_PADDING_BYTES"] = std::to_string(conv_act_block_width_padding_bytes); - } - if (conv_output_size_h * conv_output_size_w < act_block_h_datums * num_blocks_act_h) { - reader_defines["ACT_BLOCK_HEIGHT_PADDING"] = "1"; - } - } - uint32_t output_height_padded_to_tile_height = round_up(conv_output_size_h*conv_output_size_w, TILE_HEIGHT); - uint32_t 
output_height_num_tiles = output_height_padded_to_tile_height / TILE_HEIGHT; - assert(output_height_num_tiles <= act_matrix_height_ntiles); - - uint32_t act_matrix_height_unpadded = conv_output_size_h * conv_output_size_w; - uint32_t act_matrix_width_unpadded = conv_act_size_c * weight_size_h * weight_size_w; - uint32_t src_dram_act_buffer_size_bytes = src0_dram_buffer->size(); - uint32_t src_dram_weight_buffer_size_bytes = src1_dram_buffer->size(); - uint32_t dst_l1_act_buffer_size_bytes = act_block_h_ntiles * act_block_w_ntiles * single_tile_size; - uint32_t dst_l1_weight_buffer_size_bytes = weight_block_h_ntiles * weight_block_w_ntiles * single_tile_size; - - // more args for writer - uint32_t out_block_row_size_bytes = weight_block_w_ntiles*TILE_WIDTH*num_bytes_of_df; - uint32_t out_row_size_bytes = output_channels_padded_to_tile_width*num_bytes_of_df; - uint32_t batch_size = 1; - // output data format - const auto out_df = datatype_to_dataformat_converter(a.get_dtype()); - // For debug - { - log_debug(tt::LogOp, "act_matrix_height_ntiles: {}", act_matrix_height_ntiles); - log_debug(tt::LogOp, "act_matrix_width_ntiles: {}", act_matrix_width_ntiles); - log_debug(tt::LogOp, "weight_matrix_width_ntiles: {}", weight_matrix_width_ntiles); - log_debug(tt::LogOp, "num_blocks_act_h: {}", num_blocks_act_h); - log_debug(tt::LogOp, "num_blocks_act_w: {}", num_blocks_act_w); - log_debug(tt::LogOp, "num_blocks_weight_w: {}", num_blocks_weight_w); - log_debug(tt::LogOp, "act_dram_addr: {}", act_dram_addr); - log_debug(tt::LogOp, "act_block_h_ntiles: {}", act_block_h_ntiles); - log_debug(tt::LogOp, "act_block_h_datums: {}", act_block_h_datums); - log_debug(tt::LogOp, "act_block_w_ntiles: {}", act_block_w_ntiles); - log_debug(tt::LogOp, "act_block_w_datums: {}", act_block_w_datums); - log_debug(tt::LogOp, "act_num_subblocks: {}", act_num_subblocks); - log_debug(tt::LogOp, "act_block_num_tiles: {}", act_block_num_tiles); - log_debug(tt::LogOp, "act_subblock_h_ntiles: {}", act_subblock_h_ntiles); - log_debug(tt::LogOp, "act_subblock_num_tiles: {}", act_subblock_num_tiles); - log_debug(tt::LogOp, "weight_dram_addr: {}", weight_dram_addr); - log_debug(tt::LogOp, "weight_num_subblocks: {}", weight_num_subblocks); - log_debug(tt::LogOp, "weight_block_num_tiles: {}", weight_block_num_tiles); - log_debug(tt::LogOp, "weight_block_w_ntiles: {}", weight_block_w_ntiles); - log_debug(tt::LogOp, "weight_block_h_ntiles: {}", weight_block_h_ntiles); - log_debug(tt::LogOp, "has_bias: {}", has_bias); - log_debug(tt::LogOp, "bias_dram_addr: {}", bias_dram_addr); - log_debug(tt::LogOp, "bias_ntiles: {}", bias_ntiles); - log_debug(tt::LogOp, "out_dram_addr: {}", out_dram_addr); - log_debug(tt::LogOp, "out_row_size: {}", out_row_size); - log_debug(tt::LogOp, "out_subblock_h_ntiles: {}", out_subblock_h_ntiles); - log_debug(tt::LogOp, "out_subblock_w_ntiles: {}", out_subblock_w_ntiles); - log_debug(tt::LogOp, "out_subblock_num_tiles: {}", out_subblock_num_tiles); - log_debug(tt::LogOp, "num_groups: {}", num_groups); - } - - bool rn50_first_conv = (conv_act_size_h == 230 && conv_act_size_w == 231 && - conv_output_size_h == 112 && conv_output_size_w == 112 && - weight_size_h == 7 && weight_size_w == 8 && - stride_h == 2 && stride_w == 2 && - num_blocks_weight_w == 1); - - uint32_t num_weight_tiles_in_cb = weight_block_h_ntiles * weight_block_w_ntiles; - if (rn50_first_conv) { - num_weight_tiles_in_cb = weight_block_h_ntiles * weight_block_w_ntiles * num_blocks_weight_w * num_blocks_act_w; - } - 
create_CBs_for_fused_matmul_new_alloc( - program, - a.device(), - core, - act_block_h_ntiles * act_block_w_ntiles * 2, // row major act cb, double bufferred - num_weight_tiles_in_cb, // tiled weight cb - act_block_h_ntiles * act_block_w_ntiles, // tiled act cb - act_block_h_ntiles * weight_block_w_ntiles, // math output cb - weight_block_w_ntiles, // reblock cb - act_block_h_ntiles * weight_block_w_ntiles * 2, // writer output cb, double bufferred - num_bytes_of_df, - untilize_out, - bias_ntiles, - has_bias); - - // define for bias - std::map all_defines; - std::map compute_defines; - if (has_bias) { - all_defines["FUSE_BIAS"] = "1"; - compute_defines["FUSE_BIAS"] = "1"; - } - - if (fuse_relu) { - using ttnn::operations::unary::UnaryOpType; - using ttnn::operations::unary::utils::get_defines; - compute_defines.merge(get_defines(UnaryOpType::RELU, std::nullopt, "ACTIVATION", "i")); - if (has_bias) { - compute_defines["FUSE_BIAS"] = "1"; - } - } - - string reader_kernel; - vector reader_rt_args; - std::vector reader_compile_time_args; - string writer_kernel; - string compute_kernel; - if (use_fast_reader) { - TT_ASSERT(!(conv_act_size_c & (conv_act_size_c - 1))); // channel depth power of 2 is supported only - TT_ASSERT(!(out_row_size_bytes & (out_row_size_bytes - 1))); // output channels power of 2 is supported only - if (pad_h == 0 && pad_w == 0) { - if(rn50_first_conv) { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_resnet50_first_conv.cpp"; - compute_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/bmm_tilize_untilize_all_weights_in_l1_single_output_block_width_dim.cpp"; - } else { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_without_conv_padding.cpp"; - compute_kernel = "tt_eager/tt_dnn/kernels/compute/bmm_tilize_untilize.cpp"; - } - } else { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast.cpp"; - compute_kernel = "tt_eager/tt_dnn/kernels/compute/bmm_tilize_untilize.cpp"; - } - reader_compile_time_args = {(uint32_t) (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), - (uint32_t) stride_h, (uint32_t) stride_w, (uint32_t) conv_act_size_w, (uint32_t) conv_output_size_w, - (uint32_t) conv_act_size_c * num_bytes_of_df, (uint32_t) std::log2(conv_act_size_c * num_bytes_of_df)}; - } else { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations.cpp"; - reader_compile_time_args = {(uint32_t) (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0)}; - compute_kernel = "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/bmm_tilize_untilize.cpp"; - } - if (use_fast_reader && rn50_first_conv) { - assert(pad_h == 0 && pad_w == 0); - reader_rt_args = { - act_dram_addr, - conv_act_size_c, - conv_output_size_w, - weight_size_w, - num_blocks_act_h, - num_blocks_act_w, - act_block_h_datums, - act_block_num_tiles - }; - } else { - reader_rt_args = { - // arguments for act - act_dram_addr, - act_noc_x, - act_noc_y, - - conv_act_size_w, - conv_act_size_h, - conv_act_size_c, - weight_size_h, - weight_size_w, - stride_h, - stride_w, - pad_h, - pad_w, - conv_output_size_h, - conv_output_size_w, - num_blocks_act_h, - num_blocks_act_w, - num_blocks_weight_w, - num_groups, - - act_matrix_height_unpadded, - act_matrix_width_unpadded, - act_matrix_height, - act_matrix_width, - act_matrix_height_ntiles, - act_matrix_width_ntiles, - act_block_h_datums, - act_block_w_datums, - act_block_h_ntiles, - act_block_w_ntiles, - act_block_num_tiles, - - src_dram_act_buffer_size_bytes, - dst_l1_act_buffer_size_bytes, - }; - } - - vector writer_rt_args; - std::vector writer_compile_time_args; - if (untilize_out) { - if (rn50_first_conv) { - writer_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_untilize_out.cpp"; - } else if (use_fast_reader) { - writer_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_8bank_blocks_reader_weight_tile_with_pow2_addr_gen_fast.cpp"; - } else { - writer_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_layout_8bank_blocks_reader_weight_tile_layout.cpp"; - } - writer_compile_time_args = {(uint32_t) (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), out0_cb, weight_cb, (uint32_t) std::log2(out_row_size_bytes)}; - writer_rt_args = { - out_dram_addr, - weight_dram_addr, - - act_block_h_datums, - out_block_row_size_bytes, - 1, - num_blocks_act_h, - num_blocks_weight_w, - out_row_size_bytes, - last_block_row_size_bytes, - act_matrix_height_unpadded, - - num_blocks_act_w, // = number of blocks of weight in height dim - weight_block_num_tiles, - weight_block_h_ntiles, - weight_block_w_ntiles, - weight_matrix_width_ntiles, // weight_stride_h - weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, - weight_block_w_ntiles, // weight_next_block_stride_w - - }; - } else { - assert(use_fast_reader); // tiled out not tested for generic conv - if (rn50_first_conv) { - writer_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_tiled_out.cpp"; - } else { - writer_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled.cpp"; - } - writer_compile_time_args = { - (uint32_t) (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), - out0_cb, - weight_cb, - bias_cb, - bias_log2_of_pagesize, - bias_tile_nbytes, - (uint32_t) (bias_buffer == nullptr ? 0 : (bias_buffer->buffer_type() == BufferType::DRAM ? 
1 : 0))}; - writer_rt_args = { - out_dram_addr, - weight_dram_addr, - - output_width_num_tiles, // out_next_tile_stride_h - 1, // out_next_tile_stride_w - out_subblock_h_ntiles * output_width_num_tiles, // out_next_subblock_stride_h - out_subblock_w_ntiles, // out_next_subblock_stride_w - act_block_h_ntiles * output_width_num_tiles, // out_next_block_stride_h - weight_block_w_ntiles, // out_next_block_stride_w - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - act_block_h_ntiles / out_subblock_h_ntiles, // out_num_subblocks_h - weight_block_w_ntiles / out_subblock_w_ntiles, // out_num_subblocks_w - num_blocks_act_h, // out_num_blocks_h - num_blocks_weight_w, // out_num_blocks_w - act_block_h_ntiles, // out_block_height_num_tiles - output_height_num_tiles, // out_height_num_tiles without block shape padding - output_width_num_tiles, // out_width_num_tiles withoug block shape padding - - num_blocks_act_w, // = number of blocks of weight in height dim - weight_block_num_tiles, - weight_block_h_ntiles, - weight_block_w_ntiles, - weight_matrix_width_ntiles, // weight_stride_h - weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, - weight_block_w_ntiles, // weight_next_block_stride_w - - // bias - bias_dram_addr, - bias_ntiles - }; - } - tt::DataFormat cb_data_format = datatype_to_dataformat_converter(a.get_dtype()); - auto reader_id = CreateKernel( - program, - reader_kernel, - core, - ReaderDataMovementConfig( - reader_compile_time_args, - reader_defines)); - - auto writer_id = CreateKernel( - program, - writer_kernel, - core, - WriterDataMovementConfig( - writer_compile_time_args, - all_defines)); - - vector compute_kernel_args = { - act_block_w_ntiles, - act_num_subblocks, - act_block_num_tiles, - act_subblock_num_tiles, - act_subblock_h_ntiles, - - weight_num_subblocks, - weight_block_num_tiles, - weight_block_w_ntiles, - - num_blocks_act_h, - num_blocks_act_w, - num_blocks_weight_w, - - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - - true, - untilize_out, - - bias_ntiles - }; - - auto compute = CreateKernel( - program, - compute_kernel, - core, - ComputeConfig{ - .math_fidelity = math_fidelity, - .compile_args = compute_kernel_args, - .defines = compute_defines}); - - SetRuntimeArgs( - program, reader_id, core, - reader_rt_args - ); - - SetRuntimeArgs( - program, writer_id, core, - writer_rt_args - ); - - auto override_runtime_args_callback = [ - reader_kernel_id=reader_id, - writer_kernel_id=writer_id, - has_bias=has_bias - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - - TT_ASSERT(input_buffers.size() == 3); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer_a = input_buffers.at(0); - auto src_dram_buffer_b = input_buffers.at(1); - - auto dst_dram_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer_a->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - runtime_args[1] = src_dram_buffer_b->address(); - if (has_bias) { - auto src_dram_buffer_c = input_buffers.at(2); - TT_ASSERT(src_dram_buffer_c != nullptr); - runtime_args[25] = src_dram_buffer_c->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; - -} - -// generates address map for reader kernel which reads from dram buffer 
(tiled layout) into l1 buffer -std::pair, vector> generate_conv_weight_address_map( - const Shape& weight_shape, - uint32_t weight_block_h_datums, - uint32_t weight_block_w_datums, - uint32_t num_blocks_act_h, - uint32_t num_blocks_weight_h, - uint32_t num_blocks_weight_w, - uint32_t num_bytes_df) { - vector address_map; - vector address_map_metadata; - assert(weight_shape[0] == 1 && weight_shape[1] == 1); - uint32_t matrix_height = weight_shape[2]; - uint32_t matrix_width = weight_shape[3]; - assert(matrix_height % weight_block_h_datums == 0); - assert(matrix_width % weight_block_w_datums == 0); - uint32_t src_dram_buffer_size_bytes = matrix_height * matrix_width * num_bytes_df; - uint32_t dst_l1_buffer_size_bytes = weight_block_h_datums * weight_block_w_datums * num_bytes_df; - uint32_t num_groups = num_blocks_act_h * num_blocks_weight_h * num_blocks_weight_w; - assert(matrix_height % TILE_HEIGHT == 0); - uint32_t matrix_height_ntiles = matrix_height / TILE_HEIGHT; - assert(matrix_width % TILE_WIDTH == 0); - uint32_t matrix_width_ntiles = matrix_width / TILE_WIDTH; - assert(matrix_height_ntiles % num_blocks_weight_h == 0); - uint32_t block_height_ntiles = matrix_height_ntiles / num_blocks_weight_h; - assert(matrix_width_ntiles % num_blocks_weight_w == 0); - uint32_t block_width_ntiles = matrix_width_ntiles / num_blocks_weight_w; - uint32_t matrix_size_ntiles = matrix_height_ntiles * matrix_width_ntiles; - assert(weight_block_h_datums % TILE_HEIGHT == 0); - assert(weight_block_w_datums % TILE_WIDTH == 0); - assert(block_height_ntiles == weight_block_h_datums / TILE_HEIGHT); - assert(block_width_ntiles == weight_block_w_datums / TILE_WIDTH); - address_map_metadata.push_back(num_groups); - uint32_t address_map_current_group_dram_address_offset = 0; - for(uint32_t group_idx = 0; group_idx < num_groups; group_idx++) { - // Weight blocks are col major - uint32_t block_idx_h = (uint32_t) (group_idx % num_blocks_weight_h); - uint32_t block_idx_w = (uint32_t) (group_idx / num_blocks_weight_h) % (num_blocks_weight_w); - uint32_t block_idx = (block_idx_w * num_blocks_weight_h) + block_idx_h; - uint32_t start_block_tile_h_index = block_idx_h * block_height_ntiles; - uint32_t start_block_tile_w_index = block_idx_w * block_width_ntiles; - uint32_t single_tile_size_bytes = TILE_HEIGHT * TILE_WIDTH * num_bytes_df; - uint32_t address_map_current_group_size = 0; - // Weight tiles are in row major order within block - for(uint32_t tile_h_index_in_block = 0; tile_h_index_in_block < block_height_ntiles; tile_h_index_in_block++) { - for(uint32_t tile_w_index_in_block = 0; tile_w_index_in_block < block_width_ntiles; tile_w_index_in_block++) { - uint32_t tile_index_h_in_matrix = tile_h_index_in_block + start_block_tile_h_index; - uint32_t tile_index_w_in_matrix = tile_w_index_in_block + start_block_tile_w_index; - // Weight tiles are in row major order in weight matrix in dram - uint32_t tile_index_in_matrix = (tile_index_h_in_matrix * block_width_ntiles * num_blocks_weight_w) + tile_index_w_in_matrix; - assert(tile_index_in_matrix < matrix_size_ntiles); - // Weight tiles are in row major order in weight block in l1 - uint32_t tile_index_in_block = tile_h_index_in_block * block_width_ntiles + tile_w_index_in_block; - uint32_t src_address_offset_dram = tile_index_in_matrix * single_tile_size_bytes; - uint32_t read_size_bytes = single_tile_size_bytes; - uint32_t dst_address_offset_l1 = tile_index_in_block * single_tile_size_bytes; - uint32_t pad = 0; - assert(read_size_bytes > 0); - assert(pad == 0 || pad == 1); 
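Both deleted address-map generators append transfers to a flat vector as 4-tuples that the reader kernel walks group by group. A minimal sketch of that encoding follows, under the assumption (stated in the deleted comments) that a pad flag of 1 means the destination is zero-filled rather than read from DRAM; push_transfer is an illustrative helper name, not original code.

// Illustrative encoding of one address-map entry, mirroring the deleted
// generate_conv_weight_address_map / generate_conv_activation_address_map.
#include <cstdint>
#include <vector>

void push_transfer(std::vector<uint32_t>& address_map,
                   uint32_t src_offset_dram,   // byte offset into the DRAM tensor
                   uint32_t dst_offset_l1,     // byte offset into the L1 block buffer
                   uint32_t read_size_bytes,   // transfer length in bytes
                   uint32_t pad) {             // 1 => zero-fill dst, src offset unused
    address_map.push_back(src_offset_dram);
    address_map.push_back(dst_offset_l1);
    address_map.push_back(read_size_bytes);
    address_map.push_back(pad);
}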
- assert(src_address_offset_dram < src_dram_buffer_size_bytes); - assert(dst_address_offset_l1 < dst_l1_buffer_size_bytes); - address_map.push_back(src_address_offset_dram); - address_map.push_back(dst_address_offset_l1); - address_map.push_back(read_size_bytes); - address_map.push_back(pad); - address_map_current_group_size += 4; - } - } - // DRAM reads should be 32B aligned - assert(address_map_current_group_dram_address_offset%32 == 0); - address_map_metadata.push_back(address_map_current_group_dram_address_offset); - address_map_metadata.push_back(address_map_current_group_size); - // Pad 0s in address map buffer to ensure each read address is 32B aligned (32/sizeof(uint32_t) == 8 elements) - uint32_t address_map_current_group_size_padded = (uint32_t) (std::ceil((double) address_map_current_group_size / (double) 8) * 8); - if(address_map_current_group_size_padded != address_map_current_group_size) { - assert(address_map_current_group_size_padded > address_map_current_group_size); - address_map.insert(address_map.end(), address_map_current_group_size_padded - address_map_current_group_size, 0); - } - // update next group's dram read address offset (in bytes) - address_map_current_group_dram_address_offset += (address_map_current_group_size_padded*sizeof(uint32_t)); - } - return make_pair(std::move(address_map), std::move(address_map_metadata)); -} - -std::pair, vector> generate_conv_activation_address_map( - const Shape& activation_shape, - const vector& conv_params, - uint32_t act_block_h_datums, - uint32_t act_block_w_datums, - uint32_t weight_block_w_datums, - uint32_t num_blocks_act_h, - uint32_t num_blocks_act_w, - uint32_t num_blocks_weight_w, - uint32_t num_bytes_df) { - vector address_map; - vector address_map_metadata; - uint32_t conv_input_y = activation_shape[1]; - uint32_t conv_input_x = activation_shape[2]; - uint32_t conv_input_z = activation_shape[3]; - uint32_t R = conv_params[0]; - uint32_t S = conv_params[1]; - uint32_t U = conv_params[2]; - uint32_t V = conv_params[3]; - uint32_t Pad_H = conv_params[4]; - uint32_t Pad_W = conv_params[5]; - uint32_t src_dram_buffer_size_bytes = conv_input_x * conv_input_y * conv_input_z * num_bytes_df; - uint32_t dst_l1_buffer_size_bytes = act_block_h_datums * act_block_w_datums * num_bytes_df; - int conv_output_h = ((conv_input_x - R + (2 * Pad_H)) / U) + 1; - int conv_output_w = ((conv_input_y - S + (2 * Pad_W)) / V) + 1; - uint32_t matrix_height_unpadded = conv_output_h * conv_output_w; - uint32_t matrix_width_unpadded = conv_input_z * R * S; - uint32_t matrix_height = (uint32_t) (std::ceil((double) matrix_height_unpadded / (double) act_block_h_datums ) * act_block_h_datums); - uint32_t matrix_width = (uint32_t) (std::ceil((double) matrix_width_unpadded / (double) act_block_w_datums ) * act_block_w_datums); - - uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w; - uint32_t channel_stick_size = conv_input_z; - uint32_t address_map_current_group_dram_address_offset = 0; - address_map_metadata.push_back(num_groups); - for(uint32_t group_idx = 0; group_idx < num_groups; group_idx++) { - uint32_t block_idx_h = (uint32_t) (group_idx / num_blocks_act_w) / (num_blocks_weight_w); - uint32_t block_idx_w = (uint32_t) (group_idx % num_blocks_act_w); - uint32_t block_idx = (block_idx_h * num_blocks_act_w) + block_idx_w; - uint32_t start_block_2d_index_h = block_idx_h * act_block_h_datums; - uint32_t start_block_2d_index_w = block_idx_w * act_block_w_datums; - uint32_t start_block_2d_index = (start_block_2d_index_h * 
act_block_w_datums * num_blocks_act_w) + start_block_2d_index_w; - assert(start_block_2d_index_w < matrix_width_unpadded); - uint32_t address_map_current_group_size = 0; - for(uint32_t h_b = 0; h_b < act_block_h_datums; h_b++) { - uint32_t h = start_block_2d_index_h + h_b; - uint32_t dst_address_offset_l1 = h_b * act_block_w_datums * num_bytes_df; - if (h >= matrix_height_unpadded) { - // pad (block shape padding for height dim) - uint32_t pad_size_bytes = act_block_w_datums * num_bytes_df; - assert(dst_address_offset_l1 < dst_l1_buffer_size_bytes); - address_map.push_back(0); // src address not used - address_map.push_back(dst_address_offset_l1); - address_map.push_back(pad_size_bytes); - address_map.push_back(1); // pad = 1 - address_map_current_group_size += 4; - } - else { - uint32_t w = start_block_2d_index_w; - uint32_t end_block_2d_index_w = start_block_2d_index_w + act_block_w_datums - 1; - assert(end_block_2d_index_w < matrix_width); - while (w <= end_block_2d_index_w) { - uint32_t src_address_offset_dram = 0; - uint32_t read_size_bytes = 0; - uint32_t pad = 0; - if (w >= matrix_width_unpadded) { - // pad (block shape padding for width dim) - assert(end_block_2d_index_w == matrix_width-1); - read_size_bytes = (end_block_2d_index_w - w + 1) * num_bytes_df; - pad = 1; - } - else { - uint32_t channel_stick_offset = w % channel_stick_size; - uint32_t channel_stick_col_id = w / channel_stick_size; - uint32_t channel_stick_row_id = h; - assert(channel_stick_offset % (32/num_bytes_df) == 0); // DRAM read address must be aligned to 32 bytes - uint32_t channel_stick_row_id_x = channel_stick_row_id % conv_output_w; - uint32_t channel_stick_row_id_y = channel_stick_row_id / conv_output_w; - uint32_t act_tensor_start_x = channel_stick_row_id_x * V; - uint32_t act_tensor_start_y = channel_stick_row_id_y * U; - uint32_t act_tensor_padded_x = act_tensor_start_x + (channel_stick_col_id % S); - uint32_t act_tensor_padded_y = act_tensor_start_y + (channel_stick_col_id / S); - assert(w <= end_block_2d_index_w); - uint32_t read_size = std::min(channel_stick_size - channel_stick_offset, (end_block_2d_index_w+1)-w); - read_size_bytes = read_size * num_bytes_df; - if(act_tensor_padded_x < Pad_W || act_tensor_padded_x >= (Pad_W + conv_input_x) || act_tensor_padded_y < Pad_H || act_tensor_padded_y >= (Pad_H + conv_input_y)) { - // pad (conv padding) - pad = 1; - } - else { - uint32_t act_tensor_x = act_tensor_padded_x - Pad_W; - uint32_t act_tensor_y = act_tensor_padded_y - Pad_H; - assert(act_tensor_x < conv_input_x && act_tensor_x >= 0 && act_tensor_y < conv_input_y && act_tensor_y >= 0); - uint32_t act_tensor_channel_id = act_tensor_y * conv_input_x + act_tensor_x; - src_address_offset_dram = ((act_tensor_channel_id * channel_stick_size) + channel_stick_offset) * num_bytes_df; - assert(src_address_offset_dram % 32 == 0); // DRAM read address must be aligned to 32 bytes - } - } - assert(read_size_bytes > 0); - assert(pad == 0 || pad == 1); - assert(src_address_offset_dram < src_dram_buffer_size_bytes); - assert(dst_address_offset_l1 < dst_l1_buffer_size_bytes); - address_map.push_back(src_address_offset_dram); - address_map.push_back(dst_address_offset_l1); - address_map.push_back(read_size_bytes); - address_map.push_back(pad); - address_map_current_group_size += 4; - dst_address_offset_l1 += read_size_bytes; - w += (read_size_bytes/num_bytes_df); - assert(w <= end_block_2d_index_w+1); - } - } - } - // DRAM reads should be 32B aligned - assert(address_map_current_group_dram_address_offset%32 == 0); - 
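The deleted activation map is essentially an im2col indexer: each matmul (row, col) is decoded back into an input-tensor element, or flagged as conv padding. The sketch below mirrors that decode for an NHWC activation with channel-fastest column ordering, as the loop above uses; the struct and function names are illustrative only.

// Illustrative (row, col) -> input-element decode matching the deleted
// activation address-map loop. Columns are ordered channel-fastest:
//   col = kernel_tap * C + channel, with kernel_tap = kh * S + kw.
#include <cstdint>

struct ConvGeom {
    uint32_t C;                    // input channels (channel stick size)
    uint32_t S;                    // filter width
    uint32_t stride_h, stride_w;
    uint32_t pad_h, pad_w;
    uint32_t input_h, input_w;
    uint32_t output_w;             // ((input_w - S + 2 * pad_w) / stride_w) + 1
};

// Returns true and the element index in the NHWC activation tensor, or false
// if the sample lands in the zero padding around the input.
bool decode_mm_index(const ConvGeom& g, uint32_t row, uint32_t col, uint32_t& elem_index) {
    uint32_t channel = col % g.C;
    uint32_t tap = col / g.C;                       // which (kh, kw) filter tap
    uint32_t out_x = row % g.output_w;
    uint32_t out_y = row / g.output_w;
    uint32_t padded_x = out_x * g.stride_w + (tap % g.S);
    uint32_t padded_y = out_y * g.stride_h + (tap / g.S);
    if (padded_x < g.pad_w || padded_x >= g.pad_w + g.input_w ||
        padded_y < g.pad_h || padded_y >= g.pad_h + g.input_h) {
        return false;                               // conv padding: reader writes zeros
    }
    uint32_t x = padded_x - g.pad_w;
    uint32_t y = padded_y - g.pad_h;
    elem_index = (y * g.input_w + x) * g.C + channel;
    return true;
}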
address_map_metadata.push_back(address_map_current_group_dram_address_offset); - address_map_metadata.push_back(address_map_current_group_size); - // Pad 0s in address map buffer to ensure each read address is 32B aligned (32/sizeof(uint32_t) == 8 elements) - uint32_t address_map_current_group_size_padded = (uint32_t) (std::ceil((double) address_map_current_group_size / (double) 8) * 8); - if(address_map_current_group_size_padded != address_map_current_group_size) { - assert(address_map_current_group_size_padded > address_map_current_group_size); - address_map.insert(address_map.end(), address_map_current_group_size_padded - address_map_current_group_size, 0); - } - // update next group's dram read address offset (in bytes) - address_map_current_group_dram_address_offset += (address_map_current_group_size_padded*sizeof(uint32_t)); - } - return make_pair(std::move(address_map), std::move(address_map_metadata)); -} - -std::pair, vector> populate_address_map_vectors_for_reader_kernel(vector address_map_raw) { - // This function is called twice i.e., for activation and weight address maps - // "address_map_raw" is the DTX address map vector returned from DTX "conv_transform" function. - // "address_map_raw" contains metadata along with the address map data for all groups - // To keep the reader kernel simple, the metadata is separated into a different buffer - // So two buffers are created - - // First buffer is in DRAM containing the address map for all groups - // This DRAM buffer is big and is streamed into L1 scratchpad - // Second buffer contains the metadata and is copied to L1 from host - // It contains number of groups in its first index, followed by group info for each group - - // 1. dram read address offset of address map group in dram buffer (in bytes) - // 2. 
size of address map group in dram buffer (in datums, not bytes) - // TODO (nshanker), support for streaming the second buffer from dram if it does not fit in L1 - vector address_map; // will be in dram - vector address_map_metadata; // will be in l1 - - uint32_t num_address_map_fields_per_transfer = 4; // TODO (nshanker): remove hardcoded 4 and get this value from output of DTX - uint32_t num_dtx_groups = address_map_raw[0]; - address_map_metadata.push_back(address_map_raw[0]); - uint32_t address_map_raw_index = 1; - uint32_t current_group_dram_address_offset = 0; - for(uint32_t g = 0; g < num_dtx_groups; g++) { - // insert group's dram read address (in bytes) in metadata buffer - // Separate reads are issued for each "address map group" - // DRAM reads should be 32B aligned - assert(current_group_dram_address_offset%32 == 0); - address_map_metadata.push_back(current_group_dram_address_offset); - // insert group size (datums, not in bytes) into metadata buffer - uint32_t current_group_size = address_map_raw[address_map_raw_index]; - address_map_metadata.push_back(current_group_size); - address_map_raw_index += 1; - // insert address map for this group into the address map buffer - auto address_map_raw_current_group_start = address_map_raw.begin() + address_map_raw_index; - address_map.insert(address_map.end(), - address_map_raw_current_group_start, - address_map_raw_current_group_start + current_group_size); - address_map_raw_index += current_group_size; - // Pad 0s in address map buffer to ensure each read address is 32B aligned (32/sizeof(uint32_t) == 8 elements) - uint32_t current_group_size_padded = (uint32_t) (std::ceil((double) current_group_size / (double) 8) * 8); - if(current_group_size_padded != current_group_size) { - assert(current_group_size_padded > current_group_size); - address_map.insert(address_map.end(), current_group_size_padded - current_group_size, 0); - } - // update next group's dram read address offset (in bytes) - current_group_dram_address_offset += (current_group_size_padded*sizeof(uint32_t)); - } - return make_pair(std::move(address_map), std::move(address_map_metadata)); -} - -operation::ProgramWithCallbacks conv_as_large_bmm_with_address_map_single_core_(const Tensor& a, const Tensor &b, vector conv_params, - uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool untilize_out, Tensor &output) { - bool pass = true; - assert(untilize_out == true); - tt_metal::Device *device = a.device(); - TT_ASSERT(a.get_layout() == Layout::ROW_MAJOR, "Conv activation should be in row major layout"); - TT_ASSERT(a.get_legacy_shape()[0] == 1, "Only batch size 1 supported."); - TT_ASSERT(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. 
Incorrect weight tensor."); - - uint32_t num_bytes_of_df = 2; // 2 bytes for bfloat16 - // Compute the 2d matrix shape - auto [matrix_shape, matrix_shape_unpadded] = conv_op_utils::compute_conv_activation_as_mm_shape(a.get_legacy_shape(), conv_params, act_block_h_ntiles, act_block_w_ntiles, false); - assert(matrix_shape.size() == 3); - assert(matrix_shape[0] == 1); - uint32_t num_rows = (uint32_t) matrix_shape[1]; - uint32_t num_cols = (uint32_t) matrix_shape[2]; - - // More Checks - uint32_t Ba = 1; - uint32_t Ca = 1; - auto Ha = num_rows; - auto Wa = num_cols; - uint32_t Bb = b.get_legacy_shape()[0]; - uint32_t Cb = b.get_legacy_shape()[1]; - uint32_t Hb = b.get_legacy_shape()[2]; - uint32_t Wb = b.get_legacy_shape()[3]; - // Normal matrix shape checks - TT_ASSERT(Ba == 1, "So far, large matmul op has only been tested for batch one."); - TT_ASSERT(Ba == Bb, "Batch dimension needs to match"); - TT_ASSERT(Ca == Cb, "Channel dimension needs to match"); - TT_ASSERT(Wa == Hb, "The width of tensor a needs to match the height of tensor b"); - - // Tile size divisibility checks - TT_ASSERT(Ha % TILE_HEIGHT == 0, "Height of tensor a needs to be divisible by 32"); - TT_ASSERT(Wa % TILE_WIDTH == 0, "Width of tensor a needs to be divisible by 32"); - TT_ASSERT(Hb % TILE_HEIGHT == 0, "Height of tensor b needs to be divisible by 32"); - TT_ASSERT(Wb % TILE_WIDTH == 0, "Width of tensor b needs to be divisible by 32"); - - // Device compatibility checks - TT_ASSERT(a.storage_type() == StorageType::DEVICE and b.storage_type() == StorageType::DEVICE, "Operands to large matmul need to be on device!"); - TT_ASSERT(a.device() == b.device(), "Operands to large matmul need to be on the same device!"); - TT_ASSERT(a.buffer() != nullptr and b.buffer() != nullptr, "Operands to large matmul need to be allocated in buffers on device!"); - // Convert tensor dims to tile dims - uint32_t B = Ba; - uint32_t Hat = Ha / TILE_HEIGHT; - uint32_t Wat = Wa / TILE_WIDTH; - uint32_t Wbt = Wb / TILE_WIDTH; - log_debug(tt::LogOp, "Hat(MM Activation H in tiles): {}", Hat); - log_debug(tt::LogOp, "Wat(MM Activation W (MM Weight H) in tiles): {}", Wat); - log_debug(tt::LogOp, "Wbt(MM Weight W in tiles): {}", Wbt); - - assert(Hat % act_block_h_ntiles == 0); - assert(Wat % act_block_w_ntiles == 0); - assert(Wbt % weight_block_w_ntiles == 0); - - uint32_t num_blocks_act_h = Hat / act_block_h_ntiles; - uint32_t num_blocks_act_w = Wat / act_block_w_ntiles; - uint32_t num_blocks_weight_w = Wbt / weight_block_w_ntiles; - - // act block info - uint32_t act_block_w_datums = Wa / num_blocks_act_w; - uint32_t act_block_h_datums = Ha / num_blocks_act_h; - - // weight block info - uint32_t weight_block_w_datums = Wb / num_blocks_weight_w; - assert(weight_block_w_ntiles % out_subblock_w_ntiles == 0); - uint32_t weight_num_subblocks = weight_block_w_ntiles / out_subblock_w_ntiles; - uint32_t weight_block_h_ntiles = act_block_w_ntiles; - uint32_t weight_block_num_tiles = weight_block_w_ntiles * weight_block_h_ntiles; - // writer of conv op partially removes padding on the width - // it removes the padding done for block width but it doesn't remove padding done for tiled width - uint32_t output_channels_padded_to_tile_width = round_up(output_channels, TILE_WIDTH); - assert(output_channels_padded_to_tile_width <= Wb); - uint32_t num_blocks_output_w = (uint32_t) std::ceil((double) output_channels_padded_to_tile_width / (double) weight_block_w_datums); - uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums 
== 0) ? weight_block_w_datums : (output_channels_padded_to_tile_width % weight_block_w_datums); - assert(last_block_width_datums % TILE_WIDTH == 0); - uint32_t output_row_size_bytes = output_channels_padded_to_tile_width * num_bytes_of_df; - uint32_t last_block_row_size_bytes = last_block_width_datums * num_bytes_of_df; - // sanity check - assert(num_blocks_output_w == num_blocks_weight_w); - - // DTX conv activation transform data access pattern - auto [act_address_map, act_address_map_metadata] = generate_conv_activation_address_map(ttnn::Shape(a.get_legacy_shape()), conv_params, act_block_h_datums, act_block_w_datums, weight_block_w_datums, - num_blocks_act_h, num_blocks_act_w, num_blocks_weight_w, num_bytes_of_df); - - auto [weight_address_map, weight_address_map_metadata] = generate_conv_weight_address_map(ttnn::Shape(b.get_legacy_shape()), act_block_w_datums, weight_block_w_datums, - num_blocks_act_h, num_blocks_act_w, num_blocks_weight_w, num_bytes_of_df); - - // sanity check - uint32_t num_dtx_groups = act_address_map_metadata[0]; - assert(weight_address_map_metadata[0] == num_dtx_groups); - - // debug prints - int detailed_debug = 1; - if(detailed_debug > 0) { - log_debug(tt::LogOp, "Printing activation and weight address maps."); - log_debug(tt::LogOp, "DTX groups: {}", num_dtx_groups); - uint32_t act_metadata_index = 1; - uint32_t weight_metadata_index = 1; - uint32_t act_addr_map_index = 0; - uint32_t weight_addr_map_index = 0; - for(uint32_t g = 0; g < num_dtx_groups; g++) { - log_debug(tt::LogOp, " DTX group: {}", g); - uint32_t act_current_group_address = act_address_map_metadata[act_metadata_index]; - act_metadata_index += 1; - uint32_t act_current_group_size = act_address_map_metadata[act_metadata_index]; - act_metadata_index += 1; - log_debug(tt::LogOp, " act_current_group_address: {}", act_current_group_address); - log_debug(tt::LogOp, " act_current_group_size: {}", act_current_group_size); - if(detailed_debug > 1) { - uint32_t act_current_group_index = act_current_group_address/sizeof(uint32_t); - for(uint32_t i = act_current_group_index; i < act_current_group_index + act_current_group_size; i+=4) { - log_debug(tt::LogOp, " act_addr_map[0]: {}", act_address_map[i]); - log_debug(tt::LogOp, " act_addr_map[1]: {}", act_address_map[i+1]); - log_debug(tt::LogOp, " act_addr_map[2]: {}", act_address_map[i+2]); - log_debug(tt::LogOp, " act_addr_map[3]: {}", act_address_map[i+3]); - } - } - uint32_t weight_current_group_address = weight_address_map_metadata[weight_metadata_index]; - weight_metadata_index += 1; - uint32_t weight_current_group_size = weight_address_map_metadata[weight_metadata_index]; - weight_metadata_index += 1; - log_debug(tt::LogOp, " weight_current_group_address: {}", weight_current_group_address); - log_debug(tt::LogOp, " weight_current_group_size: {}", weight_current_group_size); - if(detailed_debug > 1) { - uint32_t weight_current_group_index = weight_current_group_address/sizeof(uint32_t); - for(uint32_t i = weight_current_group_index; i < weight_current_group_index + weight_current_group_size; i+=4) { - log_debug(tt::LogOp, " weight_addr_map[0]: {}", weight_address_map[i]); - log_debug(tt::LogOp, " weight_addr_map[1]: {}", weight_address_map[i+1]); - log_debug(tt::LogOp, " weight_addr_map[2]: {}", weight_address_map[i+2]); - log_debug(tt::LogOp, " weight_addr_map[3]: {}", weight_address_map[i+3]); - } - } - } - } - - uint32_t dram_bank_id = 0; - auto act_address_map_buffer_size_in_dram = act_address_map.size() * sizeof(uint32_t); - 
tt_metal::InterleavedBufferConfig act_config{ - .device= device, - .size = act_address_map_buffer_size_in_dram, - .page_size = act_address_map_buffer_size_in_dram, - .buffer_type = tt_metal::BufferType::DRAM - }; - - auto weight_address_map_buffer_size_in_dram = weight_address_map.size() * sizeof(uint32_t); - tt_metal::InterleavedBufferConfig weight_config{ - .device= device, - .size = weight_address_map_buffer_size_in_dram, - .page_size = weight_address_map_buffer_size_in_dram, - .buffer_type = tt_metal::BufferType::DRAM - }; - - - auto act_address_map_dram_buffer = CreateBuffer(act_config); - auto weight_address_map_dram_buffer = CreateBuffer(weight_config); - uint32_t act_address_map_dram_addr = act_address_map_dram_buffer->address(); - // DRAM to L1 writes should 32B aligned - assert(act_address_map_dram_addr%32 == 0); - auto act_address_map_dram_noc_xy = act_address_map_dram_buffer->noc_coordinates(); - uint32_t act_address_map_dram_noc_x = act_address_map_dram_noc_xy.x; - uint32_t act_address_map_dram_noc_y = act_address_map_dram_noc_xy.y; - uint32_t weight_address_map_dram_addr = weight_address_map_dram_buffer->address(); - // DRAM to L1 writes should 32B aligned - assert(weight_address_map_dram_addr%32 == 0); - auto weight_address_map_dram_noc_xy = weight_address_map_dram_buffer->noc_coordinates(); - uint32_t weight_address_map_dram_noc_x = weight_address_map_dram_noc_xy.x; - uint32_t weight_address_map_dram_noc_y = weight_address_map_dram_noc_xy.y; - - // Write address maps to DRAM - detail::WriteToDeviceDRAMChannel(device, dram_bank_id, act_address_map_dram_addr, act_address_map); - detail::WriteToDeviceDRAMChannel(device, dram_bank_id, weight_address_map_dram_addr, weight_address_map); - - tt_metal::Program program = tt_metal::CreateProgram(); - CoreCoord core_coord = {0, 0}; // TODO: avoid another var here. Find a way to use core range instead. - CoreRange core({0, 0}, {0, 0}); - - uint32_t single_tile_size = num_bytes_of_df * TILE_HEIGHT * TILE_WIDTH; - tt_metal::Buffer *src0_dram_buffer = a.buffer(); - tt_metal::Buffer *src1_dram_buffer = b.buffer(); - TT_ASSERT(src1_dram_buffer->size() % single_tile_size == 0, "Buffer size of tensor b must be divisible by single_tile_size (aka divisible by sizeof(df) * 1024)"); - - tt_metal::Buffer *dst_dram_buffer = output.buffer(); - TT_ASSERT(dst_dram_buffer != nullptr, "Output buffer should be allocated on device!"); - - // L1 buffers - // Create scratchpad buffer in L1 to stream in dtx address map from dram - // One scratchpad buffer is used for both activation and weight address maps - uint32_t num_address_map_fields_per_transfer = 4; // TODO: (nshanker): remove hardcoded 4 and get this value from output of DTX - // Scratchpad buffer size must be a multiple of 32B to ensure DRAM->L1 addresses align 32B - auto scratch_pad_for_address_map_in_l1_b0_size_bytes = 32; - // Scratchpad buffer size must also be a multiple of address map fields per transfer. We need all address map fields for a transfer in scratchpad. 
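The scratchpad must satisfy two constraints at once: 32 B alignment for DRAM-to-L1 transfers and divisibility by one address-map record (4 uint32_t fields). A small sketch that derives the smallest size meeting both constraints, which works out to the 32 B value hard-coded above:

#include <cassert>
#include <cstdint>
#include <numeric>

// Pick the smallest scratchpad size satisfying both constraints stated in the comments above.
int main() {
    constexpr uint32_t dram_l1_alignment_bytes = 32;                                   // DRAM->L1 transfers must be 32B aligned
    constexpr uint32_t fields_per_transfer = 4;                                        // {src_offset, dst_offset, size, pad_flag}
    constexpr uint32_t transfer_record_bytes = fields_per_transfer * sizeof(uint32_t); // 16 bytes per record
    uint32_t scratchpad_bytes = std::lcm(dram_l1_alignment_bytes, transfer_record_bytes);
    assert(scratchpad_bytes % dram_l1_alignment_bytes == 0);
    assert(scratchpad_bytes % transfer_record_bytes == 0);
    return scratchpad_bytes == 32 ? 0 : 1;                                             // 32B, matching the hard-coded value
}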
- assert(scratch_pad_for_address_map_in_l1_b0_size_bytes % (num_address_map_fields_per_transfer*sizeof(uint32_t)) == 0); - - tt_metal::InterleavedBufferConfig scratchpad_l1_config{ - .device= device, - .size = (uint64_t)scratch_pad_for_address_map_in_l1_b0_size_bytes, - .page_size = (uint64_t)scratch_pad_for_address_map_in_l1_b0_size_bytes, - .buffer_type = tt_metal::BufferType::L1 - }; - - auto scratch_pad_for_address_map_l1_buffer = CreateBuffer(scratchpad_l1_config); - uint32_t scratch_pad_for_address_map_l1_address = scratch_pad_for_address_map_l1_buffer->address(); - // DRAM to L1 writes should 32B aligned - assert(scratch_pad_for_address_map_l1_address%32 == 0); - // Create address map metadata buffers in L1 - // Metadata vectors are copied to L1 buffers from host before calling detail::LaunchProgram - auto act_address_map_metadata_l1_b0_size = act_address_map_metadata.size() * sizeof(uint32_t); - - tt_metal::InterleavedBufferConfig act_l1_config{ - .device= device, - .size = (uint64_t)act_address_map_metadata_l1_b0_size, - .page_size = (uint64_t)act_address_map_metadata_l1_b0_size, - .buffer_type = tt_metal::BufferType::L1 - }; - auto act_address_map_metadata_l1_buffer = CreateBuffer(act_l1_config); - uint32_t act_address_map_metadata_l1_address = act_address_map_metadata_l1_buffer->address(); - auto weight_address_map_metadata_l1_b0_size = weight_address_map_metadata.size() * sizeof(uint32_t); - - tt_metal::InterleavedBufferConfig weight_l1_config{ - .device= device, - .size = (uint64_t)weight_address_map_metadata_l1_b0_size, - .page_size = (uint64_t)weight_address_map_metadata_l1_b0_size, - .buffer_type = tt_metal::BufferType::L1 - }; - - - auto weight_address_map_metadata_l1_buffer = CreateBuffer(weight_l1_config); - uint32_t weight_address_map_metadata_l1_address = weight_address_map_metadata_l1_buffer->address(); - - // out - uint32_t out_dram_addr = dst_dram_buffer->address(); - uint32_t out_row_size = Wb * num_bytes_of_df; - uint32_t out_subblock_num_tiles = out_subblock_h_ntiles * out_subblock_w_ntiles; - - TT_ASSERT(out_subblock_num_tiles <= 8, "Need to ensure that matmul partials fit in dst"); - - // act - uint32_t act_dram_addr = src0_dram_buffer->address(); - auto act_dram_noc_xy = src0_dram_buffer->noc_coordinates(); - uint32_t act_noc_x = act_dram_noc_xy.x; - uint32_t act_noc_y = act_dram_noc_xy.y; - - assert(Wat % act_block_w_ntiles == 0); - assert(act_block_h_ntiles % out_subblock_h_ntiles == 0); - uint32_t act_num_subblocks = act_block_h_ntiles / out_subblock_h_ntiles; - uint32_t act_block_num_tiles = act_block_h_ntiles * act_block_w_ntiles; - uint32_t act_subblock_h_ntiles = out_subblock_h_ntiles; - uint32_t act_subblock_num_tiles = act_subblock_h_ntiles * act_block_w_ntiles; - - // weight - uint32_t weight_dram_addr = src1_dram_buffer->address(); - auto weight_dram_noc_xy = src1_dram_buffer->noc_coordinates(); - uint32_t weight_noc_x = weight_dram_noc_xy.x; - uint32_t weight_noc_y = weight_dram_noc_xy.y; - - // output data format - const auto out_df = datatype_to_dataformat_converter(a.get_dtype()); - // For debug - { - log_debug(tt::LogOp, "Hat (activation height in tiles): {}", Hat); - log_debug(tt::LogOp, "Wat (activation width in tiles): {}", Wat); - log_debug(tt::LogOp, "Wbt (weight width in tiles): {}", Wbt); - log_debug(tt::LogOp, "num_blocks_act_h: {}", num_blocks_act_h); - log_debug(tt::LogOp, "num_blocks_act_w: {}", num_blocks_act_w); - log_debug(tt::LogOp, "num_blocks_weight_w: {}", num_blocks_weight_w); - log_debug(tt::LogOp, "act_dram_addr: {}", 
act_dram_addr); - log_debug(tt::LogOp, "act_block_h_ntiles: {}", act_block_h_ntiles); - log_debug(tt::LogOp, "act_block_h_datums: {}", act_block_h_datums); - log_debug(tt::LogOp, "act_block_w_ntiles: {}", act_block_w_ntiles); - log_debug(tt::LogOp, "act_block_w_datums: {}", act_block_w_datums); - log_debug(tt::LogOp, "act_num_subblocks: {}", act_num_subblocks); - log_debug(tt::LogOp, "act_block_num_tiles: {}", act_block_num_tiles); - log_debug(tt::LogOp, "act_address_map_dram_addr: {}", act_address_map_dram_addr); - log_debug(tt::LogOp, "act_address_map_metadata_l1_address: {}", act_address_map_metadata_l1_address); - log_debug(tt::LogOp, "act_subblock_h_ntiles: {}", act_subblock_h_ntiles); - log_debug(tt::LogOp, "act_subblock_num_tiles: {}", act_subblock_num_tiles); - log_debug(tt::LogOp, "weight_dram_addr: {}", weight_dram_addr); - log_debug(tt::LogOp, "weight_num_subblocks: {}", weight_num_subblocks); - log_debug(tt::LogOp, "weight_block_num_tiles: {}", weight_block_num_tiles); - log_debug(tt::LogOp, "weight_address_map_dram_addr: {}", weight_address_map_dram_addr); - log_debug(tt::LogOp, "weight_address_map_metadata_l1_address: {}", weight_address_map_metadata_l1_address); - log_debug(tt::LogOp, "weight_block_w_ntiles: {}", weight_block_w_ntiles); - log_debug(tt::LogOp, "weight_block_h_ntiles: {}", weight_block_h_ntiles); - log_debug(tt::LogOp, "out_dram_addr: {}", out_dram_addr); - log_debug(tt::LogOp, "out_row_size: {}", out_row_size); - log_debug(tt::LogOp, "out_subblock_h_ntiles: {}", out_subblock_h_ntiles); - log_debug(tt::LogOp, "out_subblock_w_ntiles: {}", out_subblock_w_ntiles); - log_debug(tt::LogOp, "out_subblock_num_tiles: {}", out_subblock_num_tiles); - log_debug(tt::LogOp, "num_dtx_groups: {}", num_dtx_groups); - log_debug(tt::LogOp, "scratch_pad_for_address_map_l1_address: {}", scratch_pad_for_address_map_l1_address); - } - - create_CBs_for_fused_matmul_new_alloc( - program, - a.device(), - core, - act_block_h_ntiles * act_block_w_ntiles, // row major act cb - weight_block_h_ntiles * weight_block_w_ntiles, // tiled weight cb - act_block_h_ntiles * act_block_w_ntiles, // tiled act cb - act_block_h_ntiles * weight_block_w_ntiles, // math output cb - weight_block_w_ntiles, // reblock cb - act_block_h_ntiles * weight_block_w_ntiles, // writer output cb - num_bytes_of_df, - untilize_out); - - string reader_kernel; - vector reader_rt_args; - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_binary_dtx.cpp"; - reader_rt_args = { - // arguments for act - act_dram_addr, - act_noc_x, - act_noc_y, - act_address_map_dram_addr, - act_address_map_dram_noc_x, - act_address_map_dram_noc_y, - act_address_map_metadata_l1_address, - act_block_num_tiles, - - // arguments for weight - weight_dram_addr, - weight_noc_x, - weight_noc_y, - weight_address_map_dram_addr, - weight_address_map_dram_noc_x, - weight_address_map_dram_noc_y, - weight_address_map_metadata_l1_address, - weight_block_num_tiles, - - scratch_pad_for_address_map_l1_address, - }; - - string writer_kernel; - vector writer_rt_args; - if (untilize_out) { - writer_kernel = "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_blocks.cpp"; - writer_rt_args = { - out_dram_addr, - act_block_h_datums, - weight_block_w_ntiles*TILE_WIDTH*num_bytes_of_df, - 1, - num_blocks_act_h, - num_blocks_weight_w, - output_channels_padded_to_tile_width*num_bytes_of_df, - last_block_row_size_bytes, - matrix_shape_unpadded[1], - 0, - 0 - }; - } else { - assert(false && "Tiled output 
unsupported"); - writer_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_matmul_tile_layout.cpp"; - writer_rt_args = { - out_dram_addr, - 0, - 1, - Wbt, - out_subblock_w_ntiles, - out_subblock_h_ntiles * Wbt, - - out_subblock_w_ntiles, - out_subblock_h_ntiles, - out_subblock_w_ntiles * out_subblock_h_ntiles, - Wbt / out_subblock_w_ntiles, - Hat / out_subblock_h_ntiles - }; - } - auto reader_id = tt_metal::CreateKernel( - program, - reader_kernel, - core, - tt_metal::ReaderDataMovementConfig{}); - std::vector writer_compile_time_args = {(uint32_t) (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0)}; - auto writer_id = tt_metal::CreateKernel( - program, - writer_kernel, - core, - tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - vector compute_kernel_args = { - act_block_w_ntiles, - act_num_subblocks, - act_block_num_tiles, - act_subblock_num_tiles, - act_subblock_h_ntiles, - - weight_num_subblocks, - weight_block_num_tiles, - weight_block_w_ntiles, - - num_blocks_act_h, - num_blocks_act_w, - num_blocks_weight_w, - - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - - true, - untilize_out - }; - - auto eltwise_binary_kernel = tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/bmm_tilize_untilize.cpp", - core, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args} - ); - - tt_metal::SetRuntimeArgs( - program, reader_id, core, - reader_rt_args - ); - - tt_metal::SetRuntimeArgs( - program, writer_id, core, - writer_rt_args - ); - - tt_metal::detail::WriteToDeviceL1(device, core_coord, act_address_map_metadata_l1_address, act_address_map_metadata); - tt_metal::detail::WriteToDeviceL1(device, core_coord, weight_address_map_metadata_l1_address, weight_address_map_metadata); - - auto override_runtime_args_callback = [ - reader_kernel_id=reader_id, - writer_kernel_id=writer_id - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - - auto src_dram_buffer_a = input_buffers.at(0); - auto src_dram_buffer_b = input_buffers.at(1); - - auto dst_dram_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer_a->address(); - runtime_args[1] = src_dram_buffer_a->noc_coordinates().x; - runtime_args[2] = src_dram_buffer_a->noc_coordinates().y; - runtime_args[8] = src_dram_buffer_b->address(); - runtime_args[9] = src_dram_buffer_b->noc_coordinates().x; - runtime_args[10] = src_dram_buffer_b->noc_coordinates().y; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -inline Tensor conv_(const Tensor& a, const Tensor &b, std::optional bias, const vector conv_params, - uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, - bool use_address_map, bool use_fast_reader, bool untilize_out, bool has_bias = false, bool fuse_relu = false, MathFidelity math_fidelity = MathFidelity::HiFi4) { - TT_ASSERT(b.get_layout() == Layout::TILE); // Weights should already be formatted - auto padded_a_shape = Shape(std::vector{a.get_legacy_shape()[0], a.get_legacy_shape()[1], a.get_legacy_shape()[2], round_up(a.get_legacy_shape()[3], 16)}); - 
ttnn::operations::experimental::auto_format::FormatParams input_a_format_params = {.pad_shape=padded_a_shape.value, .pad_value=0.0, .target_layout=Layout::ROW_MAJOR}; - ttnn::operations::experimental::auto_format::FormatParams input_b_format_params = {.pad_shape=b.get_legacy_shape(), .pad_value=0.0, .target_layout=Layout::TILE}; - ttnn::operations::experimental::auto_format::FormatParams input_bias_format_params = {}; - if (has_bias) { - input_bias_format_params = {.pad_shape=bias.value().get_legacy_shape(), .pad_value=0, .target_layout=Layout::TILE}; - } - auto output_layout = untilize_out ? Layout::ROW_MAJOR : Layout::TILE; - return operation::run_without_autoformat( - Conv(act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, conv_params, output_channels, use_address_map, use_fast_reader, untilize_out, has_bias, fuse_relu, math_fidelity), - {a, b}, - {bias}).at(0); -} - -Tensor conv(const Tensor& a, const Tensor &b, std::optional bias, const vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool has_bias) { - return conv_(a, b, bias, conv_params, act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, output_channels, false, false, true, has_bias); -} - - -operation::ProgramWithCallbacks conv_single_core(const Tensor& a, const Tensor &b, std::optional bias, const vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool use_fast_reader, bool untilize_out, bool has_bias, bool fuse_relu, const MathFidelity math_fidelity, Tensor &output) { - return conv_as_large_bmm_single_core_(a, b, bias, conv_params, act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, output_channels, use_fast_reader, untilize_out, has_bias, fuse_relu, math_fidelity, output); -} - -operation::ProgramWithCallbacks conv_with_address_map_single_core(const Tensor& a, const Tensor &b, const vector conv_params, uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles, uint32_t weight_block_w_ntiles, - uint32_t out_subblock_h_ntiles, uint32_t out_subblock_w_ntiles, uint32_t output_channels, bool untilize_out, Tensor &output) { - return conv_as_large_bmm_with_address_map_single_core_(a, b, conv_params, act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, output_channels, untilize_out, output); -} - -void Conv::validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - // TODO: ... 
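Conv::compute_output_shapes below unpacks conv_params as {filter_h, filter_w, stride_h, stride_w, pad_h, pad_w} and derives the output face via compute_conv_output_face_shape. A standalone sketch of the standard output-size formula that helper is assumed to implement (symmetric per-side padding assumed; numbers are illustrative):

#include <cstdint>
#include <utility>

// Usual convolution output-face computation; assumed to mirror compute_conv_output_face_shape.
static std::pair<uint32_t, uint32_t> conv_output_face(
    uint32_t act_h, uint32_t act_w,
    uint32_t filter_h, uint32_t filter_w,
    uint32_t stride_h, uint32_t stride_w,
    uint32_t pad_h, uint32_t pad_w) {
    uint32_t out_h = (act_h + 2 * pad_h - filter_h) / stride_h + 1;
    uint32_t out_w = (act_w + 2 * pad_w - filter_w) / stride_w + 1;
    return {out_h, out_w};
}

int main() {
    // e.g. a 3x3 filter, stride 1, padding 1 keeps the face size: 56x56 -> 56x56
    auto [oh, ow] = conv_output_face(56, 56, 3, 3, 1, 1, 1, 1);
    return (oh == 56 && ow == 56) ? 0 : 1;
}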
-} - -std::vector Conv::compute_output_shapes(const std::vector& input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - uint32_t conv_activation_h = input_tensor_a.get_legacy_shape()[1]; - uint32_t conv_activation_w = input_tensor_a.get_legacy_shape()[2]; - // TODO: clean up here - uint32_t filter_h = (uint32_t) conv_params[0]; - uint32_t filter_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - auto [conv_output_h, conv_output_w] = conv_op_utils::compute_conv_output_face_shape(conv_activation_h, conv_activation_w, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w); - - if (untilize_out) { - // TODO: Update batch size below - // RM output has unpadded output height and padded output width to 32. - // pad the output channels to TILE_WIDTH as conv writer kernel does not remove padding for tile - // TODO (nshanker): specify padding explicitly here with "Padding" object and add unit test - auto output_channels = round_up(this->output_channels, TILE_WIDTH); - Shape output_tensor_shape = Shape(std::vector{1, conv_output_h, conv_output_w, output_channels}); - return {output_tensor_shape.value}; - } else { - // Tiled output shape is padded shape. Padded to tile shape. - auto shape_w = conv_output_h*conv_output_w; - auto shape_c = output_channels; - auto padded_shape_w = round_up(shape_w, TILE_HEIGHT); - auto padded_shape_c = round_up(this->output_channels, TILE_WIDTH); - auto output_padding = Padding({{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Any); - auto output_tensor_shape = Shape(tt::tt_metal::Shape({1, 1, padded_shape_w, padded_shape_c}, output_padding)); - return {output_tensor_shape.value}; - } -} - -std::vector Conv::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - auto output_layout = this->untilize_out ? 
Layout::ROW_MAJOR : Layout::TILE; - return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), output_layout, input_tensor.memory_config()); -} - -operation::ProgramWithCallbacks Conv::create_program(const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - std::vector& output_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - const auto& input_tensor_bias = optional_input_tensors.at(0); - auto& output_tensor = output_tensors.at(0); - if(use_address_map) { - return {conv_with_address_map_single_core(input_tensor_a, input_tensor_b, conv_params, act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, output_channels, untilize_out, output_tensor)}; - } else { - return {conv_single_core(input_tensor_a, input_tensor_b, input_tensor_bias, conv_params, act_block_h_ntiles, act_block_w_ntiles, weight_block_w_ntiles, out_subblock_h_ntiles, out_subblock_w_ntiles, output_channels, use_fast_reader, untilize_out, has_bias, fuse_relu, math_fidelity, output_tensor)}; - } -} - -} // namespace tt_metal - -} // namespace tt diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/bmm_tilize_untilize_all_weights_in_l1_single_output_block_width_dim.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/bmm_tilize_untilize_all_weights_in_l1_single_output_block_width_dim.cpp deleted file mode 100644 index 1c3b2db8171..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/bmm_tilize_untilize_all_weights_in_l1_single_output_block_width_dim.cpp +++ /dev/null @@ -1,326 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "mod_div_lib.h" -#include "compute_kernel_api/tilize.h" -#include "compute_kernel_api/untilize.h" -#include "compute_kernel_api/tile_move_copy.h" -#include "compute_kernel_api/matmul.h" -#ifdef FUSE_BIAS -#include "compute_kernel_api/bcast.h" -#endif -#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" - -inline void tilize_in( - uint32_t in_cb_id, - uint32_t in_subblock_h, - uint32_t in_block_w, - uint32_t in_num_subblocks, - uint32_t out_cb_id) { - - tilize_init_short(in_cb_id, in_block_w); - for (uint32_t in_subblock = 0; in_subblock < in_num_subblocks; ++in_subblock) { - for (uint32_t h = 0; h < in_subblock_h; ++h) { - cb_wait_front(in_cb_id, in_block_w); - cb_reserve_back(out_cb_id, in_block_w); - tilize_block(in_cb_id, in_block_w, out_cb_id); - cb_push_back(out_cb_id, in_block_w); - cb_pop_front(in_cb_id, in_block_w); - } - } - tilize_uninit(in_cb_id); -} // tilize_in() - -// NOTE: Bias is not supported with the untilize option -#ifndef FUSE_BIAS - inline void reblock_and_untilize( - uint32_t num_out_subblocks_in_col, - uint32_t out_subblock_num_tiles, - uint32_t out_subblock_h, - uint32_t out_subblock_w, - uint32_t out_block_w, - uint32_t interm_cb_id, - uint32_t reblock_cb_id, - uint32_t out_cb_id) { - - uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col); - cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks); - - uint32_t within_block_index = 0; - for (uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t block_offset = 0; - - // Reblock - copy_tile_to_dst_init_short(); - cb_reserve_back(reblock_cb_id, out_block_w); - for (uint32_t n = 0; n < num_out_subblocks_in_col; n++) { - for (uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t 
tile_index = block_offset + within_block_index + w; - tile_regs_acquire(); - copy_tile(interm_cb_id, tile_index, 0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile(0, reblock_cb_id); - tile_regs_release(); - } - block_offset += out_subblock_num_tiles; - } - cb_push_back(reblock_cb_id, out_block_w); - - // Untilize - untilize_init_short(reblock_cb_id); - cb_wait_front(reblock_cb_id, out_block_w); - cb_reserve_back(out_cb_id, out_block_w); - untilize_block(reblock_cb_id, out_block_w, out_cb_id); - cb_pop_front(reblock_cb_id, out_block_w); - cb_push_back(out_cb_id, out_block_w); - untilize_uninit(reblock_cb_id); - - within_block_index += out_subblock_w; - } - cb_pop_front(interm_cb_id, num_tiles_in_row_of_subblocks); - } // reblock_and_untilize() -#endif - -inline void pack_matmul_subblock(uint32_t cb_id, uint32_t out_subblock_num_tiles) { - cb_reserve_back(cb_id, out_subblock_num_tiles); - tile_regs_wait(); - for (uint32_t i = 0; i < out_subblock_num_tiles; ++i) { - pack_tile(i, cb_id); - } - tile_regs_release(); - cb_push_back(cb_id, out_subblock_num_tiles); -} - -namespace NAMESPACE { -void MAIN { - - constexpr uint32_t in0_block_w = get_compile_time_arg_val(0); // inner block size in tiles - constexpr uint32_t in0_num_subblocks = get_compile_time_arg_val(1); // outer row block size (in inner row blocks) - constexpr uint32_t in0_block_num_tiles = get_compile_time_arg_val(2); // out_subblock_h*in0_block_w*in0_num_subblocks; - constexpr uint32_t in0_subblock_num_tiles = get_compile_time_arg_val(3); // out_subblock_h*in0_block_w - constexpr uint32_t in0_subblock_h = get_compile_time_arg_val(4); - constexpr uint32_t in1_num_subblocks = get_compile_time_arg_val(5); // outer column block size (in inner column blocks) - constexpr uint32_t in1_block_num_tiles = get_compile_time_arg_val(6); //out_subblock_w*in0_block_w* in1_num_subblocks; - constexpr uint32_t in1_per_core_w = get_compile_time_arg_val(7); // out_subblock_w*in1_num_subblocks - // if these are not defined as volatile, it causes code size for TRISC2 to be too large if num_blocks > 1 - constexpr uint32_t in0_num_blocks_h = get_compile_time_arg_val(8); - constexpr uint32_t in0_num_blocks_w = get_compile_time_arg_val(9); - constexpr uint32_t in1_num_blocks_w = get_compile_time_arg_val(10); - constexpr uint32_t out_subblock_h = get_compile_time_arg_val(11); // inner row block size in tiles - constexpr uint32_t out_subblock_w = get_compile_time_arg_val(12); // inner column block size in tiles - constexpr uint32_t out_subblock_num_tiles = get_compile_time_arg_val(13); // out_subblock_h * out_subblock_w; - constexpr bool tilize_in0 = get_compile_time_arg_val(14); - constexpr bool untilize_out = get_compile_time_arg_val(15); - - constexpr uint32_t out_block_num_tiles = in0_num_subblocks * in1_num_subblocks * out_subblock_num_tiles; - - constexpr uint32_t out_block_w = in1_per_core_w; - constexpr bool spill = in0_num_blocks_w > 1; - - // CB indices - constexpr uint32_t in0_cb_id = tt::CB::c_in0; - constexpr uint32_t in1_cb_id = tt::CB::c_in1; - constexpr uint32_t matmul_partials_cb = tt::CB::c_intermed0; - constexpr uint32_t tilized_in0_cb_id = tt::CB::c_intermed1; - constexpr uint32_t untilize_mode_reblock_cb = tt::CB::c_intermed2; - constexpr uint32_t out_cb_id = tt::CB::c_out0; - - constexpr uint32_t untilize_mode_out_cb_id = untilize_out ? 
matmul_partials_cb : out_cb_id; - - #ifdef FUSE_BIAS - constexpr uint32_t bias_ntiles_w = get_compile_time_arg_val(16); - constexpr uint32_t bias_cb_id = tt::CB::c_in2; - constexpr uint32_t mm_out_cb_id = matmul_partials_cb; - #else - constexpr uint32_t mm_out_cb_id = untilize_mode_out_cb_id; - #endif - - constexpr uint32_t mm_in0_cb_id = tilize_in0 ? tilized_in0_cb_id : in0_cb_id; - - mm_init(mm_in0_cb_id, in1_cb_id, out_cb_id); - - #ifdef SFPU_OP_INIT_ACTIVATION - SFPU_OP_INIT_ACTIVATION - #endif - - cb_wait_front(in1_cb_id, in1_block_num_tiles * in0_num_blocks_w * in1_num_blocks_w); // wait for all weights, in_num_blocks_w == 1 - - for(uint32_t in0_block_h_i = 0; in0_block_h_i < in0_num_blocks_h; ++in0_block_h_i) { - bool enable_reload = false; - uint32_t in1_index_inner_dim_h_offset = 0; - - #ifdef PACK_RELU - PACK(( llk_pack_relu_config(ReluType::NO_RELU) )); - #endif - - uint32_t curr_matmul_out_cb = matmul_partials_cb; - for(uint32_t in0_block_w_i = 0; in0_block_w_i < in0_num_blocks_w; ++in0_block_w_i) { // inner dim of act (w) - bool last_out = (in0_block_w_i == in0_num_blocks_w - 1); - if constexpr (tilize_in0) { - #if defined PACK_RELU and not defined FUSE_BIAS - if (last_out) { - // if last block we pack the final result with relu enabled - PACK(( llk_pack_relu_config(ReluType::NO_RELU) )); - } - #endif - unpack_reconfig_data_format_srca(in1_cb_id, in0_cb_id); - tilize_in(in0_cb_id, in0_subblock_h, in0_block_w, in0_num_subblocks, tilized_in0_cb_id); - mm_init_short(); - unpack_reconfig_data_format_srca(in0_cb_id, in1_cb_id); - } - cb_wait_front(mm_in0_cb_id, in0_block_num_tiles); - - if (last_out) { - #if defined PACK_RELU and not defined FUSE_BIAS - // if last block we pack the final result with relu enabled - PACK(( llk_pack_relu_config(ReluType::ZERO_RELU) )); - #endif - curr_matmul_out_cb = mm_out_cb_id; - } - - uint32_t in0_index_subblock_offset = 0; - for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { - uint32_t in1_index_subblock_offset = 0; - for (uint32_t in1_subblock_i = 0; in1_subblock_i < in1_num_subblocks; ++in1_subblock_i) { - if (enable_reload) { - // Reconfigure input - copy_tile_to_dst_init_short(); - unpack_reconfig_data_format_srca(in1_cb_id, matmul_partials_cb); - cb_wait_front(matmul_partials_cb, out_subblock_num_tiles); - tile_regs_acquire(); - for (uint32_t i = 0; i < out_subblock_num_tiles; ++i) { - copy_tile(matmul_partials_cb, i, i); - } - cb_pop_front(matmul_partials_cb, out_subblock_num_tiles); - // Reconfigure srcA back - mm_init_short(); - unpack_reconfig_data_format_srca(matmul_partials_cb, in1_cb_id); - } else { - // just acquire - tile_regs_acquire(); - } - - // Compute output sub-block from in0_subblock x in1_subblock - uint32_t dst_index = 0; - uint32_t in0_index_h_offset = 0; - uint32_t in1_index_offset = in1_index_inner_dim_h_offset + in1_index_subblock_offset; - for (uint32_t h = 0; h < out_subblock_h; ++h) { - uint32_t in0_index_offset = in0_index_subblock_offset + in0_index_h_offset; - for (uint32_t w = 0; w < out_subblock_w; ++w) { - uint32_t in1_index_inner_dim_subblock_offset = 0; - uint32_t in1_index_offset_w = in1_index_offset + w; - for (uint32_t inner_dim = 0; inner_dim < in0_block_w; ++inner_dim) { - matmul_tiles(mm_in0_cb_id, // in0_cb - in1_cb_id, // in1_cb - in0_index_offset + inner_dim, // in0 tile - in1_index_offset_w + in1_index_inner_dim_subblock_offset, // in1 tile - dst_index, // dst - false); - in1_index_inner_dim_subblock_offset += in1_per_core_w; - } // for in0_block_w - 
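The loop nest here accumulates matmul partials across the in0_num_blocks_w inner-dimension blocks, reloading the previous partial from the intermediate CB (enable_reload) whenever the inner dimension is split (spill). A plain host-side reference of that accumulation order, with no device CBs and purely illustrative sizes:

#include <cstddef>
#include <vector>

// Reference semantics for block-wise matmul accumulation: C += A_kblk * B_kblk over all inner blocks.
int main() {
    const std::size_t M = 8, K = 12, N = 6, Kblk = 4;        // K split into K/Kblk inner-dimension blocks
    std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);

    for (std::size_t k0 = 0; k0 < K; k0 += Kblk) {           // analogous to the in0_num_blocks_w loop
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n) {
                float acc = C[m * N + n];                    // reload the previous partial
                for (std::size_t k = k0; k < k0 + Kblk; ++k) // inner dim within the current block
                    acc += A[m * K + k] * B[k * N + n];
                C[m * N + n] = acc;                          // store partial (final result after last block)
            }
    }
    return C[0] == static_cast<float>(K) ? 0 : 1;            // each element equals K with all-ones inputs
}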
++dst_index; - } // for out_subblock_w - in0_index_h_offset += in0_block_w; - } // for out_subblock_h - - #if not defined FUSE_BIAS and defined SFPU_OP_INIT_ACTIVATION - if (last_out) { - for (uint32_t i = 0; i < out_subblock_num_tiles; ++ i) { - SFPU_OP_FUNC_ACTIVATION - } - } - #endif - tile_regs_commit(); - pack_matmul_subblock(curr_matmul_out_cb, out_subblock_num_tiles); - in1_index_subblock_offset += out_subblock_w; - } // for in1_num_subblocks - in0_index_subblock_offset += in0_subblock_num_tiles; - } - - if constexpr (spill) enable_reload = true; - - cb_pop_front(mm_in0_cb_id, in0_block_num_tiles); - in1_index_inner_dim_h_offset += in1_block_num_tiles; - } // for in0_num_blocks_w - #ifdef FUSE_BIAS - #ifdef PACK_RELU - PACK(( llk_pack_relu_config(ReluType::ZERO_RELU) )); - #endif - add_bcast_rows_init_short(); - unpack_reconfig_data_format(in1_cb_id, matmul_partials_cb, mm_in0_cb_id, bias_cb_id); - cb_wait_front(bias_cb_id, bias_ntiles_w); - cb_wait_front(matmul_partials_cb, out_block_num_tiles); - for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { - uint32_t in1_index_subblock_offset = 0; - for (uint32_t in1_subblock_i = 0; in1_subblock_i < in1_num_subblocks; ++in1_subblock_i) { - // reconfig packer df for out - // pack_reconfig_data_format(out_cb_id); - tile_regs_acquire(); - uint32_t i = 0; - for (uint32_t h = 0; h < out_subblock_h; ++ h) { - uint32_t bcast_tile_i = in1_index_subblock_offset; - for (uint32_t w = 0; w < out_subblock_w; ++ w) { - add_tiles_bcast_rows(matmul_partials_cb, bias_cb_id, i, bcast_tile_i, i); - ++ bcast_tile_i; - ++ i; - } - } - // reconfig unpacker df for srcB - // unpack_reconfig_data_format(in1_cb_id, in0_cb_id); - - #ifdef SFPU_OP_INIT_ACTIVATION - for (uint32_t i = 0; i < out_subblock_num_tiles; ++ i) { - SFPU_OP_FUNC_ACTIVATION - } - #endif - tile_regs_commit(); - // do not pop front bias as it may be used again for subsequent blocks - cb_pop_front(matmul_partials_cb, out_subblock_num_tiles); - - pack_matmul_subblock(untilize_mode_out_cb_id, out_subblock_num_tiles); - in1_index_subblock_offset += out_subblock_w; - } // for in1_num_subblocks - } - if constexpr(in0_num_blocks_h > 1) { - if constexpr (!tilize_in0) { - mm_init_short(); - } - unpack_reconfig_data_format(matmul_partials_cb, in1_cb_id, bias_cb_id, mm_in0_cb_id); - } - #else - if constexpr(untilize_out) { - #ifdef PACK_RELU - PACK(( llk_pack_relu_config(ReluType::NO_RELU) )); - #endif - unpack_reconfig_data_format(in1_cb_id, matmul_partials_cb, mm_in0_cb_id, untilize_mode_reblock_cb); - for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { - uint32_t in1_index_subblock_offset = 0; - for (uint32_t in1_subblock_i = 0; in1_subblock_i < in1_num_subblocks; ++in1_subblock_i) { - reblock_and_untilize( - in1_num_subblocks, - out_subblock_num_tiles, - out_subblock_h, - out_subblock_w, - out_block_w, - matmul_partials_cb, - untilize_mode_reblock_cb, - out_cb_id); - } - } - if constexpr(in0_num_blocks_h > 1) { - if constexpr (!tilize_in) { - mm_init_short(); - } - unpack_reconfig_data_format(matmul_partials_cb, in1_cb_id, untilize_mode_reblock_cb, mm_in0_cb_id); - } - } - #endif - } // for in0_num_blocks_h - cb_pop_front(in1_cb_id, in1_block_num_tiles * in0_num_blocks_w * in1_num_blocks_w); -} // MAIN -} // NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_binary_dtx.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_binary_dtx.cpp deleted file mode 100644 index 
aeab6c49252..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_binary_dtx.cpp +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - -inline void noc_async_read_from_dram_to_l1(uint32_t dram_addr, uint32_t dram_noc_x, uint32_t dram_noc_y, uint32_t l1_dest_addr, uint32_t read_size) { - uint64_t src_noc_addr = get_noc_addr(dram_noc_x, dram_noc_y, dram_addr); - noc_async_read(src_noc_addr, l1_dest_addr, read_size); -} -inline void async_read_from_dram_using_address_map(uint32_t dram_start_addr, - uint32_t dram_noc_x, - uint32_t dram_noc_y, - uint32_t l1_write_addr, - uint32_t address_map_scratch_pad_l1_addr, - uint32_t address_map_group_size, - uint32_t address_map_group_dram_addr, - uint32_t address_map_dram_noc_x, - uint32_t address_map_dram_noc_y) { - volatile tt_l1_ptr uint32_t * address_map_scratch_pad_buffer = (volatile tt_l1_ptr uint32_t*)(address_map_scratch_pad_l1_addr); - uint32_t address_map_scratch_pad_buffer_size_bytes = 32; // TODO (nshanker): make this a compile time kernel arg - uint32_t address_map_scratch_pad_buffer_size = address_map_scratch_pad_buffer_size_bytes >> 2; - uint32_t address_map_scratch_pad_index = 0; - - for(uint32_t i = 0; i < address_map_group_size; i+=4) { - if (address_map_scratch_pad_index == 0) { - // Issue a read from DRAM to fill up the entire scratchpad buffer - // Scratch pad buffer size must be a multiple of 32B because DRAM read needs to be 32B aligned - // We want to always do DRAM to L1 write at the start of scratchpad beause l1 write address has to be 32B aligned - // Kernel assumptions - - // Host must ensure that "address_map_scratch_pad_l1_addr" % 32 == 0 - // Host must ensure that "address_map_scratch_pad_buffer_size_bytes" % 32 == 0 - noc_async_read_from_dram_to_l1(address_map_group_dram_addr, - address_map_dram_noc_x, address_map_dram_noc_y, - address_map_scratch_pad_l1_addr, address_map_scratch_pad_buffer_size_bytes); - noc_async_read_barrier(); - address_map_group_dram_addr += address_map_scratch_pad_buffer_size_bytes; - } - // There are 4 entries in the address map vector for one transfer - uint32_t src_address_offset = address_map_scratch_pad_buffer[address_map_scratch_pad_index]; - uint32_t dst_address_offset = address_map_scratch_pad_buffer[address_map_scratch_pad_index+1]; - uint32_t read_size = address_map_scratch_pad_buffer[address_map_scratch_pad_index+2]; - uint32_t pad = address_map_scratch_pad_buffer[address_map_scratch_pad_index+3]; - // DPRINT << "src_address_offset=" << src_address_offset << ENDL(); - // DPRINT << "dst_address_offset=" << dst_address_offset << ENDL(); - // DPRINT << "read_size=" << read_size << ENDL(); - // DPRINT << "pad=" << pad << ENDL(); - - if(pad == 1) { - // Insert zeroes in l1 - uint32_t dst_addr = l1_write_addr + dst_address_offset; - uint32_t pad_size = read_size; - volatile std::uint8_t* start_dst= (volatile uint8_t*)(dst_addr); - for (uint32_t offset = 0; offset < pad_size; offset++) { - *(start_dst + offset) = 0; - } - // TODO (nshanker): More performant version below but switched off because it fails non deterministically - // // source address is set to max. This refers to padding location. 
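Each address-map record consumed above is four uint32_t values: source offset, destination offset, transfer size in bytes, and a pad flag (1 means zero-fill the destination instead of reading from DRAM). A host-side reference interpreter for that record format, using memcpy/memset in place of NoC reads and illustrative buffer contents:

#include <cstdint>
#include <cstring>
#include <vector>

// Reference interpretation of a DTX-style address map: groups of
// {src_offset, dst_offset, size_bytes, pad_flag} applied to flat byte buffers.
void apply_address_map(const std::vector<uint32_t>& map,
                       const std::vector<uint8_t>& src,
                       std::vector<uint8_t>& dst) {
    for (std::size_t i = 0; i + 3 < map.size(); i += 4) {
        uint32_t src_off = map[i], dst_off = map[i + 1];
        uint32_t size = map[i + 2], pad = map[i + 3];
        if (pad == 1) {
            std::memset(dst.data() + dst_off, 0, size);      // padding region: write zeroes
        } else {
            std::memcpy(dst.data() + dst_off, src.data() + src_off, size);
        }
    }
}

int main() {
    std::vector<uint8_t> src(64, 0xAB), dst(64, 0xFF);
    // Copy 32 bytes from src offset 0, then zero-fill 32 bytes at dst offset 32.
    std::vector<uint32_t> map = {0, 0, 32, 0, 0, 32, 32, 1};
    apply_address_map(map, src, dst);
    return (dst[0] == 0xAB && dst[32] == 0) ? 0 : 1;
}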
- // // read zeroes from zero buffer - // uint32_t dst_addr = l1_write_addr + dst_address_offset; - // uint32_t pad_size = read_size; - // if (pad_size <= MEM_ZEROS_SIZE) { - // noc_async_read(zeros_base_noc_addr, dst_addr, pad_size); - // } - // else { - // // padding size is bigger than the zero buffer size - // // read from zero buffer multiple times - // uint32_t zeros_to_read = pad_size; - // uint32_t zeros_read_size = MEM_ZEROS_SIZE; - // while(zeros_to_read != 0) { - // noc_async_read(zeros_base_noc_addr, dst_addr, zeros_read_size); - // zeros_to_read -= zeros_read_size; - // if (zeros_to_read < zeros_read_size) { - // zeros_read_size = zeros_to_read; - // } - // } - // } - } - else { - uint32_t src_addr = dram_start_addr + src_address_offset; - uint32_t dst_addr = l1_write_addr + dst_address_offset; - noc_async_read_from_dram_to_l1(src_addr, dram_noc_x, dram_noc_y, dst_addr, read_size); - } - address_map_scratch_pad_index += 4; - if(address_map_scratch_pad_index == address_map_scratch_pad_buffer_size) { - // Reached the end of scratchpad buffer - // Reset the index to 0 for the next iteration - address_map_scratch_pad_index = 0; - } - } -} -void kernel_main() { - // Arguments for in0 - uint32_t in0_addr_base = get_arg_val(0); - uint32_t in0_noc_x = get_arg_val(1); - uint32_t in0_noc_y = get_arg_val(2); - uint32_t in0_address_map_dram_addr = get_arg_val(3); - uint32_t in0_address_map_dram_noc_x = get_arg_val(4); - uint32_t in0_address_map_dram_noc_y = get_arg_val(5); - uint32_t in0_address_map_metadata_l1_addr = get_arg_val(6); - uint32_t in0_block_num_tiles = get_arg_val(7); - - // Arguments for in1 - uint32_t in1_addr_base = get_arg_val(8); - uint32_t in1_noc_x = get_arg_val(9); - uint32_t in1_noc_y = get_arg_val(10); - uint32_t in1_address_map_dram_addr = get_arg_val(11); - uint32_t in1_address_map_dram_noc_x = get_arg_val(12); - uint32_t in1_address_map_dram_noc_y = get_arg_val(13); - uint32_t in1_address_map_metadata_l1_addr = get_arg_val(14); - uint32_t in1_block_num_tiles = get_arg_val(15); - - uint32_t scratch_pad_for_address_map_in_l1_addr = get_arg_val(16); - - constexpr uint32_t cb_id_in0 = 0; - constexpr uint32_t cb_id_in1 = 1; - // Scratchpad buffer in l1 to stream address map from DRAM into L1 - volatile tt_l1_ptr std::uint32_t* scratch_pad_for_address_map_l1_buffer = (volatile tt_l1_ptr uint32_t*)(scratch_pad_for_address_map_in_l1_addr); - // Address map metadata buffers in l1. 
Metadata is copied into L1 buffers by the host before kernel is launched - volatile tt_l1_ptr std::uint32_t* in0_address_map_metdata_l1_buffer = (volatile tt_l1_ptr uint32_t*)(in0_address_map_metadata_l1_addr); - volatile tt_l1_ptr std::uint32_t* in1_address_map_metdata_l1_buffer = (volatile tt_l1_ptr uint32_t*)(in1_address_map_metadata_l1_addr); - - // TODO (nshanker): For a more performant padding implementation which is switched off because it fails non deterministically - // // Put zeroes in the zero buffer for padding - // constexpr uint32_t num_elements_in_zeros_buffer = MEM_ZEROS_SIZE / sizeof(uint32_t); - // volatile tt_l1_ptr uint32_t* zero_base_ptr = reinterpret_cast(MEM_ZEROS_BASE); - // for (uint32_t zero_base_offset = 0; zero_base_offset < num_elements_in_zeros_buffer; zero_base_offset++) { - // *(zero_base_ptr + zero_base_offset) = 0; - // } - // uint64_t zeros_base_noc_addr = get_noc_addr(MEM_ZEROS_BASE); - - uint32_t in0_address_map_metadata_index = 0; - // address map metdata buffer contains number of groups in the first element - uint32_t num_groups = in0_address_map_metdata_l1_buffer[in0_address_map_metadata_index]; - in0_address_map_metadata_index += 1; - // in0 and in1 address maps should have same number of groups - // no need to get the in1 num of groups - uint32_t in1_address_map_metadata_index = 1; - //DPRINT << "num_groups=" << num_groups << ENDL(); - - for(uint32_t g = 0; g < num_groups; g++) { - - // Read in0 block from DRAM - - // Read in0 block - cb_reserve_back(cb_id_in0, in0_block_num_tiles); - uint32_t l1_write_addr_in0 = get_write_ptr(cb_id_in0); - uint32_t in0_address_map_current_group_dram_addr_offset = in0_address_map_metdata_l1_buffer[in0_address_map_metadata_index]; - uint32_t in0_address_map_current_group_dram_addr = in0_address_map_dram_addr + in0_address_map_current_group_dram_addr_offset; - in0_address_map_metadata_index += 1; - uint32_t in0_address_map_current_group_size = in0_address_map_metdata_l1_buffer[in0_address_map_metadata_index]; - in0_address_map_metadata_index += 1; - - async_read_from_dram_using_address_map(in0_addr_base, in0_noc_x, in0_noc_y, - l1_write_addr_in0, scratch_pad_for_address_map_in_l1_addr, in0_address_map_current_group_size, - in0_address_map_current_group_dram_addr, in0_address_map_dram_noc_x, in0_address_map_dram_noc_y); - noc_async_read_barrier(); - - - // Read in1 block from DRAM - // Read in1 block - cb_reserve_back(cb_id_in1, in1_block_num_tiles); - uint32_t l1_write_addr_in1 = get_write_ptr(cb_id_in1); - uint32_t in1_address_map_current_group_dram_addr_offset = in1_address_map_metdata_l1_buffer[in1_address_map_metadata_index]; - uint32_t in1_address_map_current_group_dram_addr = in1_address_map_dram_addr + in1_address_map_current_group_dram_addr_offset; - in1_address_map_metadata_index += 1; - uint32_t in1_address_map_current_group_size = in1_address_map_metdata_l1_buffer[in1_address_map_metadata_index]; - in1_address_map_metadata_index += 1; - async_read_from_dram_using_address_map(in1_addr_base, in1_noc_x, in1_noc_y, - l1_write_addr_in1, scratch_pad_for_address_map_in_l1_addr, in1_address_map_current_group_size, - in1_address_map_current_group_dram_addr, in1_address_map_dram_noc_x, in1_address_map_dram_noc_y); - noc_async_read_barrier(); - cb_push_back(cb_id_in0, in0_block_num_tiles); - cb_push_back(cb_id_in1, in1_block_num_tiles); - } -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv1x1_activations_fast_for_col_major_conv_out_blocks.cpp 
b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv1x1_activations_fast_for_col_major_conv_out_blocks.cpp deleted file mode 100644 index 8d639f6af0a..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv1x1_activations_fast_for_col_major_conv_out_blocks.cpp +++ /dev/null @@ -1,178 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -inline void pad_l1_buffer_with_zeroes(uint32_t l1_addr, uint32_t pad_size_bytes) { - volatile std::uint32_t* dst = reinterpret_cast(l1_addr); - volatile std::uint32_t* end_dst = dst + (pad_size_bytes >> 2); // Divide by 4 using right shift - - while (dst < end_dst) { - *dst++ = 0; - } - - uint32_t remainder = pad_size_bytes & 0x3; // Get the remainder using bitwise AND - if (remainder != 0) { - volatile std::uint8_t* byte_dst = reinterpret_cast(dst); - for (uint32_t i = 0; i < remainder; ++i) { - *byte_dst++ = 0; - } - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - uint32_t act_dram_noc_x = get_arg_val(i); i+=1; - uint32_t act_dram_noc_y = get_arg_val(i); i+=1; - - uint32_t conv_act_size_w_ = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t conv_act_size_c_ = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t stride_h_ = get_arg_val(i); i+=1; - uint32_t stride_w_ = get_arg_val(i); i+=1; - uint32_t pad_h = get_arg_val(i); i+=1; - uint32_t pad_w = get_arg_val(i); i+=1; - uint32_t conv_output_size_h = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - uint32_t num_blocks_weight_w = get_arg_val(i); i+=1; - uint32_t num_groups = get_arg_val(i); i+=1; - - uint32_t act_matrix_height_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_width_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_height = get_arg_val(i); i+=1; - uint32_t act_matrix_width = get_arg_val(i); i+=1; - uint32_t act_matrix_height_ntiles = get_arg_val(i); i+=1; - uint32_t act_matrix_width_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_w_datums = get_arg_val(i); i+=1; - uint32_t act_block_h_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_w_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t act_w_num_outer = get_arg_val(i); i+=1; - uint32_t src_dram_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t dst_l1_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t n_start = get_arg_val(i); i+=1; - uint32_t out_h_start = get_arg_val(i); i+=1; - uint32_t out_w_start = get_arg_val(i); i+=1; - uint32_t total_h_start = get_arg_val(i); i+=1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_c_read_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - constexpr uint32_t 
num_channel_slices = get_compile_time_arg_val(8); // arg index=7 unused - constexpr uint32_t channel_slice_size_bytes = get_compile_time_arg_val(9); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - const InterleavedPow2AddrGenFast s_act = { - .bank_base_address = act_addr_dram_base, - .log_base_2_of_page_size = log_base_2_of_conv_act_size_c_bytes - }; - - // Assumptions. Must be true. Validate on host. - // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - // DPRINT << "Running new conv reader" << ENDL(); - // DPRINT << "act matrix h unpadded " << act_matrix_height_unpadded << ENDL(); - // DPRINT << "num_blocks_act_h " << num_blocks_act_h << ENDL(); - // DPRINT << "act_block_h_datums " << act_block_h_datums << ENDL(); - // DPRINT << "num_blocks_weight_w " << num_blocks_weight_w << ENDL(); - // DPRINT << "num_blocks_act_w " << num_blocks_act_w << ENDL(); - // Outer loop is number of blocks in weight width dim - // Conv output blocks are computed in col major order - for(uint32_t nbr = 0; nbr < num_blocks_weight_w; nbr++) { - uint32_t out_h = out_h_start; - uint32_t out_w = out_w_start; - uint32_t out_h_reset = out_h_start; - uint32_t out_w_reset = out_w_start; - uint32_t total_h = total_h_start; - uint32_t total_h_reset = total_h_start; - uint32_t n = n_start; - uint32_t n_reset = n_start; - for(uint32_t nbh = 0; nbh < num_blocks_act_h; nbh++) { - uint32_t channel_slice_offset = 0; - for (uint32_t chs = 0; chs < num_channel_slices; chs++) { - out_h = out_h_reset; - out_w = out_w_reset; - total_h = total_h_reset; - n = n_reset; - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t l1_addr_offset = 0; - for(uint32_t bh = 0; bh < act_block_h_datums; bh++) { - uint32_t in_h_offset = out_h * stride_h; - uint32_t in_w_offset = out_w * stride_w; // expect stride 1 or 2.. make this compile time args - also conv input width - uint32_t read_size_bytes = channel_slice_size_bytes; - - if (total_h < act_matrix_height_unpadded) { - uint32_t in_h = in_h_offset; - uint32_t in_w = in_w_offset; - - if(in_h < pad_h || in_w < pad_w || in_h >= (conv_act_size_h + pad_h) || in_w >= (conv_act_size_w_ + pad_w)) { - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - uint32_t pad_size_bytes = read_size_bytes; - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - } else { - // read one channel from dram multi bank - row_id = channel_id - uint32_t in_h_raw = in_h - pad_h; - uint32_t in_w_raw = in_w - pad_w; - uint32_t channel_id = (n * conv_act_size_h * conv_act_size_w) + (in_h_raw * conv_act_size_w) + in_w_raw; - //DPRINT << "n=" << n << " h=" << in_h_raw << " w=" << in_w_raw << " conv_act_size_h=" << conv_act_size_h << " conv_act_size_w=" << conv_act_size_w << ENDL(); - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - s_act.noc_async_read_partial_page(channel_id, dst_addr, channel_slice_size_bytes, channel_slice_offset); - } - } //else { DPRINT << "total_h here =" << total_h << ENDL(); } //do nothing. 
let garbage rows be in l1 - l1_addr_offset += read_size_bytes; - if(out_w < conv_output_size_w - 1) { - out_w += 1; - } else { - out_w = 0; - if (out_h < conv_output_size_h - 1) { - out_h += 1; - } else if (total_h < act_matrix_height_unpadded){ - // next image in batch - out_h = 0; - n += 1; - } - } - total_h += 1; - } // for block height - //DPRINT << "waiting on read barrier" << ENDL(); - noc_async_read_barrier(); - //DPRINT << "done on read barrier" << ENDL(); - cb_push_back(cb_id_act, act_block_num_tiles); - channel_slice_offset += channel_slice_size_bytes; - } // for num channel slices - out_h_reset = out_h; - out_w_reset = out_w; - total_h_reset = total_h; - n_reset = n; - } // for num of act blocks in height dim - } // for num of weight blocks in width dim -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations.cpp deleted file mode 100644 index 914b0473689..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -inline void pad_l1_buffer_with_zeroes(uint32_t l1_addr, uint32_t pad_size_bytes) { - volatile std::uint32_t* dst = reinterpret_cast(l1_addr); - volatile std::uint32_t* end_dst = dst + (pad_size_bytes >> 2); // Divide by 4 using right shift - - while (dst < end_dst) { - *dst++ = 0; - } - - uint32_t remainder = pad_size_bytes & 0x3; // Get the remainder using bitwise AND - if (remainder != 0) { - volatile std::uint8_t* byte_dst = reinterpret_cast(dst); - for (uint32_t i = 0; i < remainder; ++i) { - *byte_dst++ = 0; - } - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - uint32_t act_dram_noc_x = get_arg_val(i); i+=1; - uint32_t act_dram_noc_y = get_arg_val(i); i+=1; - - uint32_t conv_act_size_w = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t conv_act_size_c = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t stride_h = get_arg_val(i); i+=1; - uint32_t stride_w = get_arg_val(i); i+=1; - uint32_t pad_h = get_arg_val(i); i+=1; - uint32_t pad_w = get_arg_val(i); i+=1; - uint32_t conv_output_size_h = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - uint32_t num_blocks_weight_w = get_arg_val(i); i+=1; - uint32_t num_groups = get_arg_val(i); i+=1; - - uint32_t act_matrix_height_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_width_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_height = get_arg_val(i); i+=1; - uint32_t act_matrix_width = get_arg_val(i); i+=1; - uint32_t act_matrix_height_ntiles = get_arg_val(i); i+=1; - uint32_t act_matrix_width_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_w_datums = get_arg_val(i); i+=1; - uint32_t act_block_h_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_w_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t src_dram_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t dst_l1_act_buffer_size_bytes = get_arg_val(i); i+=1; - - constexpr bool act_in_dram = 
get_compile_time_arg_val(0) == 1; - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - uint32_t channel_stick_size = conv_act_size_c; - uint32_t channel_stick_size_bytes = channel_stick_size << 1; - const InterleavedAddrGen s_act = { - .bank_base_address = act_addr_dram_base, - .page_size = channel_stick_size_bytes - }; - - for(uint32_t group_idx = 0; group_idx < num_groups; group_idx++) { - - // Read activations for this group - // Activations are in channels last layout in dram - { - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t block_idx_h = (uint32_t) (group_idx / num_blocks_act_w) / (num_blocks_weight_w); - uint32_t block_idx_w = (uint32_t) (group_idx % num_blocks_act_w); - uint32_t block_idx = (block_idx_h * num_blocks_act_w) + block_idx_w; - uint32_t start_block_2d_index_h = block_idx_h * act_block_h_datums; - uint32_t start_block_2d_index_w = block_idx_w * act_block_w_datums; - uint32_t start_block_2d_index = (start_block_2d_index_h * act_block_w_datums * num_blocks_act_w) + start_block_2d_index_w; - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - // TODO (nshanker): add macro to disable checks - if(start_block_2d_index_w >= act_matrix_width_unpadded) { - //DPRINT << "Problem" << ENDL(); - } - for(uint32_t h_b = 0; h_b < act_block_h_datums; h_b++) { - uint32_t h = start_block_2d_index_h + h_b; - uint32_t dst_address_offset_l1 = (h_b * act_block_w_datums)<<1; - if (h >= act_matrix_height_unpadded) { - // pad (block shape padding for height dim) - uint32_t pad_size_bytes = act_block_w_datums<<1; - // TODO (nshanker): add macro to disable checks - if(dst_address_offset_l1 + (pad_size_bytes-1) >= dst_l1_act_buffer_size_bytes) { - //DPRINT << "Problem" << ENDL(); - } - uint32_t dst_addr = l1_write_addr_act + dst_address_offset_l1; - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - } - else { - uint32_t w = start_block_2d_index_w; - uint32_t end_block_2d_index_w = start_block_2d_index_w + act_block_w_datums - 1; - // TODO (nshanker): add macro to disable checks - if(end_block_2d_index_w >= act_matrix_width) { - //DPRINT << "Problem" << ENDL(); - } - while (w <= end_block_2d_index_w) { - uint32_t src_address_offset_dram = 0; - uint32_t read_size_bytes = 0; - uint32_t pad = 0; - if (w >= act_matrix_width_unpadded) { - // pad (block shape padding for width dim) - // TODO (nshanker): add macro to disable checks - if(end_block_2d_index_w != act_matrix_width-1) { - //DPRINT << "Problem" << ENDL(); - } - uint32_t pad_size_bytes = (end_block_2d_index_w - w + 1)<<1; - if(dst_address_offset_l1 + (pad_size_bytes-1) >= dst_l1_act_buffer_size_bytes) { - //DPRINT << "Problem" << ENDL(); - } - uint32_t dst_addr = l1_write_addr_act + dst_address_offset_l1; - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - read_size_bytes = pad_size_bytes; - } - else { - uint32_t channel_stick_offset = w % channel_stick_size; - uint32_t channel_stick_col_id = w / channel_stick_size; - uint32_t channel_stick_row_id = h; - if(channel_stick_offset % 16 != 0) { // DRAM read address must be aligned to 32 bytes - //DPRINT << "Problem" << ENDL(); - } - uint32_t channel_stick_row_id_x = channel_stick_row_id % conv_output_size_w; - uint32_t channel_stick_row_id_y = channel_stick_row_id / conv_output_size_w; - uint32_t act_tensor_start_x = channel_stick_row_id_x * stride_w; - uint32_t act_tensor_start_y = channel_stick_row_id_y * stride_h; - uint32_t act_tensor_padded_x = act_tensor_start_x + 
(channel_stick_col_id % weight_size_w); - uint32_t act_tensor_padded_y = act_tensor_start_y + (channel_stick_col_id / weight_size_w); - if(w > end_block_2d_index_w) { - //DPRINT << "Problem" << ENDL(); - } - uint32_t a = channel_stick_size - channel_stick_offset; - uint32_t b = (end_block_2d_index_w+1)-w; - uint32_t read_size = a < b ? a : b; - read_size_bytes = read_size << 1; - if(act_tensor_padded_x < pad_w || act_tensor_padded_x >= (pad_w + conv_act_size_w) || act_tensor_padded_y < pad_h || act_tensor_padded_y >= (pad_h + conv_act_size_h)) { - // pad (conv padding) - uint32_t dst_addr = l1_write_addr_act + dst_address_offset_l1; - uint32_t pad_size_bytes = read_size_bytes; - if(dst_address_offset_l1 + (pad_size_bytes-1) >= dst_l1_act_buffer_size_bytes) { - //DPRINT << "Problem" << ENDL(); - } - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - } - else { - uint32_t act_tensor_x = act_tensor_padded_x - pad_w; - uint32_t act_tensor_y = act_tensor_padded_y - pad_h; - if(act_tensor_x >= conv_act_size_w || act_tensor_y >= conv_act_size_h) { - //DPRINT << "Problem" << ENDL(); - } - uint32_t act_tensor_channel_id = act_tensor_y * conv_act_size_w + act_tensor_x; - src_address_offset_dram = ((act_tensor_channel_id * channel_stick_size) + channel_stick_offset)<<1; - if(src_address_offset_dram % 32 != 0) { // DRAM read address must be aligned to 32 bytes - //DPRINT << "Problem1" << ENDL(); - } - if(src_address_offset_dram >= src_dram_act_buffer_size_bytes) { - //DPRINT << "Problem2" << ENDL(); - } - if(dst_address_offset_l1 + (read_size_bytes-1) >= dst_l1_act_buffer_size_bytes) { - //DPRINT << "Problem3" << ENDL(); - } - uint32_t src_addr = act_addr_dram_base + src_address_offset_dram; - uint32_t dst_addr = l1_write_addr_act + dst_address_offset_l1; - uint64_t act_noc_addr = get_noc_addr(act_tensor_channel_id, s_act, (channel_stick_offset<<1)); - noc_async_read(act_noc_addr, dst_addr, read_size_bytes); - } - } - dst_address_offset_l1 += read_size_bytes; - w += (read_size_bytes>>1); - if(w > end_block_2d_index_w+1) { - //DPRINT << "Problem" << ENDL(); - } - } - } - } - } - - noc_async_read_barrier(); - cb_push_back(cb_id_act, act_block_num_tiles); - } - -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights.cpp deleted file mode 100644 index 54839ff5ccb..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights.cpp +++ /dev/null @@ -1,238 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -FORCE_INLINE -void read_channels(uint32_t& l1_write_addr_act, const uint32_t act_l1_read_addr, const uint32_t reader_channel_idx, - const uint32_t log_base_2_of_conv_act_size_c_bytes, const uint32_t coalesced_read_bytes, const uint32_t stride_h_bytes) { - - constexpr uint32_t unroll_factor = WINDOW_INNER; - uint32_t act_l1_read_addr_plus_offset = act_l1_read_addr + (reader_channel_idx << log_base_2_of_conv_act_size_c_bytes); - #pragma GCC unroll unroll_factor - for (uint32_t inner = 0; inner < WINDOW_INNER; inner++) { - noc_async_read_one_packet_with_state(act_l1_read_addr_plus_offset, l1_write_addr_act); - l1_write_addr_act += coalesced_read_bytes; - // +2 is hard-coded, TODO: generalize - act_l1_read_addr_plus_offset += stride_h_bytes; - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t conv_act_size_w_ = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - // uint32_t act_block_h_datums = get_arg_val(i); i+=1; - i+=1; // skip an arg - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t act_w_num_outer = get_arg_val(i); i+=1; - - uint32_t first_partial_right_aligned_row_width = get_arg_val(i); i+=1; - uint32_t skip_after_partial_right_aligned_row = get_arg_val(i); i+=1; - uint32_t first_partial_image_num_rows = get_arg_val(i); i+=1; - uint32_t skip_after_first_partial_image_row = get_arg_val(i); i+=1; - uint32_t num_full_images = get_arg_val(i); i+=1; - uint32_t skip_after_full_image = get_arg_val(i); i+=1; - uint32_t last_partial_image_num_rows = get_arg_val(i); i+=1; - uint32_t last_partial_left_aligned_row_width = get_arg_val(i); i+=1; - - // moved these to compile-time args - // uint32_t window_outer = get_arg_val(i); i+=1; - // uint32_t window_inner = get_arg_val(i); i+=1; - i+=2; // skip 2 rt args - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - uint32_t act_mcast_dest_noc_start_x = get_arg_val(i); i+=1; - uint32_t act_mcast_dest_noc_start_y = get_arg_val(i); i+=1; - uint32_t act_mcast_dest_noc_end_x = get_arg_val(i); i+=1; - uint32_t act_mcast_dest_noc_end_y = get_arg_val(i); i+=1; - uint32_t act_mcast_num_dests = get_arg_val(i); i+=1; - uint32_t act_mcast_num_cores = get_arg_val(i); i+=1; - uint32_t act_mcast_sender_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - uint32_t act_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - - uint32_t act_mcast_sender_size_bytes = get_arg_val(i); i+=1; - uint32_t act_mcast_sender_id = get_arg_val(i); i+=1; - uint32_t act_mcast_sender_noc_x = get_arg_val(i); i+=1; - tt_l1_ptr uint32_t *act_mcast_sender_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(i)); - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_c_read_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - // TODO delete unused: get_compile_time_arg_val(7); (8), (9) - // need to have these as compile-time since we unroll loops based on them - constexpr uint32_t window_outer = get_compile_time_arg_val(10); - 
constexpr uint32_t window_inner = get_compile_time_arg_val(11); - constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(12); - - constexpr uint32_t cb_id_act = tt::CB::c_in0; - constexpr uint32_t tilized_in0_cb_id = tt::CB::c_intermed1; - constexpr uint32_t cb_id_sharded_act = tt::CB::c_in3; - constexpr uint32_t cb_id_act_row_major_bfloat16 = tt::CB::c_in6; - - // Assumptions. Must be true. Validate on host. - // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert( % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - // LOOP TO FILL READER INDICES - constexpr uint32_t cb_reader_indices = tt::CB::c_in4; - volatile tt_l1_ptr uint16_t* reader_indices_ptr = reinterpret_cast(get_write_ptr(cb_reader_indices)); - - uint32_t weights_top_left_corner_idx = 0; - uint32_t reader_idx = 0; - - // First partial right-aligned row - for (uint32_t k = 0; k < first_partial_right_aligned_row_width; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += skip_after_partial_right_aligned_row; // Skip padded width - - // First partial image - for (uint32_t j = 0; j < first_partial_image_num_rows; j++) { - for (uint32_t k = 0; k < conv_act_size_w_; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += weight_size_w - 1; - } - weights_top_left_corner_idx += skip_after_first_partial_image_row; // Skip padded rows - - // Full images - for (uint32_t i = 0; i < num_full_images; i++) { - for (uint32_t j = 0; j < conv_act_size_h; j++) { - for (uint32_t k = 0; k < conv_act_size_w; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += weight_size_w - 1; - } - weights_top_left_corner_idx += skip_after_full_image; // Skip padded rows - } - - // Last partial image - for (uint32_t j = 0; j < last_partial_image_num_rows; j++) { - for (uint32_t k = 0; k < conv_act_size_w; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += weight_size_w - 1; - } - - // Last partial left-alighted row - for (uint32_t k = 0; k < last_partial_left_aligned_row_width; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - - // Set ur local VALID value, to be mcasted to destinations flag address after the data has been mcasted - volatile tt_l1_ptr uint32_t* act_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(act_mcast_receiver_semaphore_addr); - noc_semaphore_set(act_mcast_receiver_semaphore_addr_ptr, VALID); - // local address that will be atomically incremented by mcast receivers, to know when all receivers are ready - // to receive the mcast - volatile tt_l1_ptr uint32_t* act_mcast_sender_semaphore_addr_ptr = reinterpret_cast(act_mcast_sender_semaphore_addr); - - uint64_t act_multicast_noc_addr = get_noc_multicast_addr( - act_mcast_dest_noc_start_x, - act_mcast_dest_noc_start_y, - act_mcast_dest_noc_end_x, - act_mcast_dest_noc_end_y, - 0 - ); - - uint64_t act_mcast_receiver_semaphore_noc_addr = act_multicast_noc_addr | act_mcast_receiver_semaphore_addr; - constexpr uint32_t num_issued_reads_per_block = act_block_h_datums * window_inner; - - // TODO: need to make the read coalescing 
optimization cleaner - // currently works for the case of num_coalesced_reads == weight_size_w since these reads are contiguous on both src/dst side - constexpr uint32_t num_coalesced_reads = 3; - constexpr uint32_t coalesced_read_bytes = num_coalesced_reads * conv_act_c_read_bytes; - - volatile tt_l1_ptr uint32_t* packed_reader_indices_ptr = reinterpret_cast(get_write_ptr(cb_reader_indices)); - - - // Fully create act matrix and tilize it before mcast - // set_state uses just x/y from the get_noc_addr, addr is ignored - uint32_t act_l1_read_addr = get_read_ptr(cb_id_sharded_act); - noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes); - - // Reset reader_idx to finish act_block_h_datums - reader_idx = 0; - cb_reserve_back(cb_id_act_row_major_bfloat16, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_row_major_bfloat16); - - constexpr uint32_t stride_h_bytes = (conv_act_size_w+2) << log_base_2_of_conv_act_size_c_bytes; - static_assert(act_block_h_datums % 2 == 0); // need to be even to read 2 in the body, due to packing of 2 indices in 1 uint32_t word - // #pragma GCC unroll 4 // didn't seem to help (neutral), manual unroll 2x perf drop - for (uint32_t bh = 0; bh < act_block_h_datums/2; bh++) { - uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx]; - read_channels(l1_write_addr_act, act_l1_read_addr, two_reader_indices & 0xffff, log_base_2_of_conv_act_size_c_bytes, coalesced_read_bytes, stride_h_bytes); - read_channels(l1_write_addr_act, act_l1_read_addr, two_reader_indices >> 16 , log_base_2_of_conv_act_size_c_bytes, coalesced_read_bytes, stride_h_bytes); - - reader_idx++; - } - // incrementing num issued in one shot is actually slower - // noc_async_read_inc_num_issued(num_issued_reads_per_block); // "false" on read - noc_async_read_barrier(); - cb_push_back(cb_id_act_row_major_bfloat16, act_block_num_tiles); - - // compute tilizes and pops cb_id_act and pushes to tilized_in0_cb_id - cb_wait_front(tilized_in0_cb_id, act_block_num_tiles); - - - // Round robin self-mcast and receive tilized act matrix in cb_id_act - // Compute should function like regular mm - for (uint32_t act_w_outer_i = 0; act_w_outer_i < act_w_num_outer; act_w_outer_i++) { - if (act_w_outer_i == act_mcast_sender_id) { - // MCAST SENDER: send entire tilized input to other cores in column - cb_reserve_back(cb_id_act, act_block_num_tiles); - - // wait until all act mcast destinations have atomically incremented the act semaphore_addr (i.e. its value should be act_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(act_mcast_sender_semaphore_addr_ptr, act_mcast_num_dests); - noc_semaphore_set(act_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint32_t tilized_act_start_address = get_read_ptr(tilized_in0_cb_id); - uint64_t act_multicast_data_addr = act_multicast_noc_addr | get_write_ptr(cb_id_act); - // num_dests will source, since we are copying to a different local CB as well - noc_async_write_multicast_loopback_src(tilized_act_start_address, act_multicast_data_addr, act_mcast_sender_size_bytes, act_mcast_num_cores + 1, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). 
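
For reference, the act_block_h_datums/2 loop above consumes two 16-bit reader indices per 32-bit word, taking the low half-word for the first read_channels call and the high half-word for the second. A minimal host-style sketch of the packing convention that unpack implies; pack_two_indices is a hypothetical helper, not code from this repo:

```cpp
// Sketch only: models the "& 0xffff" / ">> 16" unpack used by the reader kernel above.
#include <cstdint>
#include <cstdio>

static uint32_t pack_two_indices(uint16_t first, uint16_t second) {
    // First index goes in the low half-word so the kernel recovers it with "& 0xffff".
    return static_cast<uint32_t>(first) | (static_cast<uint32_t>(second) << 16);
}

int main() {
    uint32_t two_reader_indices = pack_two_indices(7, 42);
    uint16_t first  = two_reader_indices & 0xffff;  // fed to the first read_channels call
    uint16_t second = two_reader_indices >> 16;     // fed to the second read_channels call
    std::printf("first=%u second=%u\n", first, second);
    return 0;
}
```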
- - // We should also multicast VALID flag to destinations for receiver semaphore - noc_semaphore_set_multicast(act_mcast_receiver_semaphore_addr, act_mcast_receiver_semaphore_noc_addr, act_mcast_num_cores); - - noc_async_write_barrier(); - } else { - // MCAST RECEIVER: receive entire tilized input from sender core - cb_reserve_back(cb_id_act, act_block_num_tiles); - - // Set act semaphore value to INVALID - noc_semaphore_set(act_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t act_mcast_sender_semaphore_noc_addr = get_noc_addr(act_mcast_sender_noc_x, act_mcast_sender_noc_y[act_w_outer_i], act_mcast_sender_semaphore_addr); - noc_semaphore_inc(act_mcast_sender_semaphore_noc_addr, 1); - - // wait on act semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(act_mcast_receiver_semaphore_addr_ptr, VALID); - } - cb_push_back(cb_id_act, act_block_num_tiles); - } -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_act_block_w_equals_channels_X_filter_width.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_act_block_w_equals_channels_X_filter_width.cpp deleted file mode 100644 index 431fca5c2af..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_act_block_w_equals_channels_X_filter_width.cpp +++ /dev/null @@ -1,191 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -inline void pad_l1_buffer_with_zeroes(uint32_t l1_addr, uint32_t pad_size_bytes) { - volatile std::uint32_t* dst = reinterpret_cast(l1_addr); - volatile std::uint32_t* end_dst = dst + (pad_size_bytes >> 2); // Divide by 4 using right shift - - while (dst < end_dst) { - *dst++ = 0; - } - - uint32_t remainder = pad_size_bytes & 0x3; // Get the remainder using bitwise AND - if (remainder != 0) { - volatile std::uint8_t* byte_dst = reinterpret_cast(dst); - for (uint32_t i = 0; i < remainder; ++i) { - *byte_dst++ = 0; - } - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - uint32_t act_dram_noc_x = get_arg_val(i); i+=1; - uint32_t act_dram_noc_y = get_arg_val(i); i+=1; - - uint32_t conv_act_size_w_ = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t conv_act_size_c_ = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t stride_h_ = get_arg_val(i); i+=1; - uint32_t stride_w_ = get_arg_val(i); i+=1; - uint32_t pad_h = get_arg_val(i); i+=1; - uint32_t pad_w = get_arg_val(i); i+=1; - uint32_t conv_output_size_h = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - uint32_t num_blocks_weight_w = get_arg_val(i); i+=1; - uint32_t num_groups = get_arg_val(i); i+=1; - - uint32_t act_matrix_height_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_width_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_height = get_arg_val(i); i+=1; - uint32_t act_matrix_width = get_arg_val(i); i+=1; - uint32_t act_matrix_height_ntiles = get_arg_val(i); i+=1; - uint32_t act_matrix_width_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_w_datums = get_arg_val(i); i+=1; - uint32_t 
act_block_h_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_w_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t act_w_num_outer = get_arg_val(i); i+=1; - uint32_t src_dram_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t dst_l1_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t n_start = get_arg_val(i); i+=1; - uint32_t out_h_start = get_arg_val(i); i+=1; - uint32_t out_w_start = get_arg_val(i); i+=1; - uint32_t total_h_start = get_arg_val(i); i+=1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_c_read_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - const InterleavedPow2AddrGenFast s_act = { - .bank_base_address = act_addr_dram_base, - .log_base_2_of_page_size = log_base_2_of_conv_act_size_c_bytes - }; - - - - // Assumptions. Must be true. Validate on host. - // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - // DPRINT << "Running new conv reader" << ENDL(); - // DPRINT << "act matrix h unpadded " << act_matrix_height_unpadded << ENDL(); - // DPRINT << "num_blocks_act_h " << num_blocks_act_h << ENDL(); - // DPRINT << "act_block_h_datums " << act_block_h_datums << ENDL(); - // DPRINT << "num_blocks_weight_w " << num_blocks_weight_w << ENDL(); - // DPRINT << "num_blocks_act_w " << num_blocks_act_w << ENDL(); - // Outer loop is number of blocks in weight width dim - // Conv output blocks are computed in col major order - for(uint32_t nbr = 0; nbr < num_blocks_weight_w; nbr++) { - uint32_t out_h = out_h_start; - uint32_t out_w = out_w_start; - uint32_t out_h_reset = out_h_start; - uint32_t out_w_reset = out_w_start; - uint32_t total_h = total_h_start; - uint32_t total_h_reset = total_h_start; - uint32_t n = n_start; - uint32_t n_reset = n_start; - for(uint32_t nbh = 0; nbh < num_blocks_act_h; nbh++) { - uint32_t in_h_offset_within_kernel_window = 0; - for (uint32_t nbw = 0; nbw < num_blocks_act_w; nbw++) { - out_h = out_h_reset; - out_w = out_w_reset; - total_h = total_h_reset; - n = n_reset; - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t l1_addr_offset = 0; - for(uint32_t bh = 0; bh < act_block_h_datums; bh++) { - uint32_t in_h_offset = out_h * stride_h; - uint32_t in_w_offset = out_w * stride_w; // expect stride 1 or 2.. 
make this compile time args - also conv input width - uint32_t in_w_offset_within_kernel_window = 0; - for(uint32_t bw = 0; bw < weight_size_w; bw++) { - uint32_t read_size_bytes = conv_act_c_read_bytes; - - if (total_h < act_matrix_height_unpadded) { - uint32_t in_h = in_h_offset + in_h_offset_within_kernel_window; - uint32_t in_w = in_w_offset + in_w_offset_within_kernel_window; - - if(in_h < pad_h || in_w < pad_w || in_h >= (conv_act_size_h + pad_h) || in_w >= (conv_act_size_w_ + pad_w)) { - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - uint32_t pad_size_bytes = read_size_bytes; - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - } else { - // read one channel from dram multi bank - row_id = channel_id - uint32_t in_h_raw = in_h - pad_h; - uint32_t in_w_raw = in_w - pad_w; - uint32_t channel_id = (n * conv_act_size_h * conv_act_size_w) + (in_h_raw * conv_act_size_w) + in_w_raw; - //DPRINT << "n=" << n << " h=" << in_h_raw << " w=" << in_w_raw << " conv_act_size_h=" << conv_act_size_h << " conv_act_size_w=" << conv_act_size_w << ENDL(); - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - s_act.noc_async_read_page(channel_id, dst_addr); - } - } //else { DPRINT << "total_h here =" << total_h << ENDL(); } //do nothing. let garbage rows be in l1 - l1_addr_offset += read_size_bytes; - in_w_offset_within_kernel_window += 1; - } // for block width - // pad 0s for block padding on the right side of block.. only first conv since C%32 != 0.. ifdef with compile time define - #ifdef ACT_BLOCK_WIDTH_PADDING_BYTES - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - pad_l1_buffer_with_zeroes(dst_addr, (uint32_t) ACT_BLOCK_WIDTH_PADDING_BYTES); - l1_addr_offset += (uint32_t) ACT_BLOCK_WIDTH_PADDING_BYTES; - #endif - if(out_w < conv_output_size_w - 1) { - out_w += 1; - } else { - out_w = 0; - //DPRINT << "total_h=" << total_h << ENDL(); - if (out_h < conv_output_size_h - 1) { - out_h += 1; - } else if (total_h < act_matrix_height_unpadded){ - // next image in batch - out_h = 0; - n += 1; - //DPRINT << "next image, n=" << n << ENDL(); - } - } - total_h += 1; - } // for block height - in_h_offset_within_kernel_window += 1; - //DPRINT << "waiting on read barrier" << ENDL(); - noc_async_read_barrier(); - //DPRINT << "done on read barrier" << ENDL(); - cb_push_back(cb_id_act, act_block_num_tiles); - } // for num of act blocks in inner width dim - out_h_reset = out_h; - out_w_reset = out_w; - total_h_reset = total_h; - n_reset = n; - } // for num of act blocks in height dim - } // for num of weight blocks in width dim -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast.cpp deleted file mode 100644 index 2e6187dd584..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast.cpp +++ /dev/null @@ -1,175 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -inline void pad_l1_buffer_with_zeroes(uint32_t l1_addr, uint32_t pad_size_bytes) { - volatile std::uint32_t* dst = reinterpret_cast(l1_addr); - volatile std::uint32_t* end_dst = dst + (pad_size_bytes >> 2); // Divide by 4 using right shift - - while (dst < end_dst) { - *dst++ = 0; - } - - uint32_t remainder = pad_size_bytes & 0x3; // Get the remainder using bitwise AND - if (remainder != 0) { - volatile std::uint8_t* byte_dst = reinterpret_cast(dst); - for (uint32_t i = 0; i < remainder; ++i) { - *byte_dst++ = 0; - } - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - uint32_t act_dram_noc_x = get_arg_val(i); i+=1; - uint32_t act_dram_noc_y = get_arg_val(i); i+=1; - - uint32_t conv_act_size_w_ = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t conv_act_size_c_ = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t stride_h_ = get_arg_val(i); i+=1; - uint32_t stride_w_ = get_arg_val(i); i+=1; - uint32_t pad_h = get_arg_val(i); i+=1; - uint32_t pad_w = get_arg_val(i); i+=1; - uint32_t conv_output_size_h = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - uint32_t num_blocks_weight_w = get_arg_val(i); i+=1; - uint32_t num_groups = get_arg_val(i); i+=1; - - uint32_t act_matrix_height_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_width_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_height = get_arg_val(i); i+=1; - uint32_t act_matrix_width = get_arg_val(i); i+=1; - uint32_t act_matrix_height_ntiles = get_arg_val(i); i+=1; - uint32_t act_matrix_width_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_w_datums = get_arg_val(i); i+=1; - uint32_t act_block_h_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_w_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t src_dram_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t dst_l1_act_buffer_size_bytes = get_arg_val(i); i+=1; - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_size_c_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - const InterleavedPow2AddrGenFast s_act = { - .bank_base_address = act_addr_dram_base, - .log_base_2_of_page_size = log_base_2_of_conv_act_size_c_bytes - }; - - // Assumptions. Must be true. Validate on host. 
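
The address generator set up just above treats one channel stick as one DRAM page, so the page size must be a power of two and is passed as a log2. A minimal sketch of how the two related compile-time args (conv_act_size_c_bytes and its log2) could be derived on the host, assuming 2-byte bfloat16 activations and a power-of-two channel count; both helpers are illustrative, not the actual host code:

```cpp
// Sketch only: derives the page-size compile-time args consumed by
// InterleavedPow2AddrGenFast above. Assumes bfloat16 (2 bytes/element) and a
// power-of-two channel count.
#include <cassert>
#include <cstdint>

static uint32_t conv_act_c_bytes(uint32_t conv_act_size_c) {
    return conv_act_size_c * 2;  // one channel stick in bfloat16
}

static uint32_t log2_of(uint32_t value) {
    assert(value != 0 && (value & (value - 1)) == 0);  // must be a power of two
    uint32_t log2_value = 0;
    while (value >>= 1) {
        ++log2_value;
    }
    return log2_value;
}

int main() {
    uint32_t c_bytes = conv_act_c_bytes(64);   // e.g. C = 64 -> 128 bytes per stick
    uint32_t log2_c_bytes = log2_of(c_bytes);  // -> 7, passed as a compile-time arg
    return (log2_c_bytes == 7) ? 0 : 1;
}
```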
- // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - uint32_t out_h = 0; - uint32_t out_w = 0; - uint32_t out_h_start = 0; - uint32_t out_w_start = 0; - uint32_t total_h = 0; - uint32_t total_h_start = 0; - uint32_t n = 0; - uint32_t n_start = 0; - // DPRINT << "Running new conv reader" << ENDL(); - // DPRINT << "act matrix h unpadded " << act_matrix_height_unpadded << ENDL(); - // DPRINT << "num_blocks_act_h " << num_blocks_act_h << ENDL(); - // DPRINT << "act_block_h_datums " << act_block_h_datums << ENDL(); - // DPRINT << "num_blocks_weight_w " << num_blocks_weight_w << ENDL(); - // DPRINT << "num_blocks_act_w " << num_blocks_act_w << ENDL(); - for(uint32_t nbh = 0; nbh < num_blocks_act_h; nbh++) { - for(uint32_t nbr = 0; nbr < num_blocks_weight_w; nbr++) { - uint32_t in_h_offset_within_kernel_window = 0; - for (uint32_t nbw = 0; nbw < num_blocks_act_w; nbw++) { - out_h = out_h_start; - out_w = out_w_start; - total_h = total_h_start; - n = n_start; - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t l1_addr_offset = 0; - for(uint32_t bh = 0; bh < act_block_h_datums; bh++) { - uint32_t in_h_offset = out_h * stride_h; - uint32_t in_w_offset = out_w * stride_w; // expect stride 1 or 2.. make this compile time args - also conv input width - uint32_t in_w_offset_within_kernel_window = 0; - for(uint32_t bw = 0; bw < weight_size_w; bw++) { - uint32_t read_size_bytes = conv_act_size_c_bytes; - - if (total_h < act_matrix_height_unpadded) { - uint32_t in_h = in_h_offset + in_h_offset_within_kernel_window; - uint32_t in_w = in_w_offset + in_w_offset_within_kernel_window; - - if(in_h < pad_h || in_w < pad_w || in_h >= (conv_act_size_h + pad_h) || in_w >= (conv_act_size_w_ + pad_w)) { - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - uint32_t pad_size_bytes = read_size_bytes; - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - } else { - // read one channel from dram multi bank - row_id = channel_id - uint32_t in_h_raw = in_h - pad_h; - uint32_t in_w_raw = in_w - pad_w; - uint32_t channel_id = (n * conv_act_size_h * conv_act_size_w) + (in_h_raw * conv_act_size_w) + in_w_raw; - //DPRINT << "n=" << n << " h=" << in_h_raw << " w=" << in_w_raw << " conv_act_size_h=" << conv_act_size_h << " conv_act_size_w=" << conv_act_size_w << ENDL(); - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - s_act.noc_async_read_page(channel_id, dst_addr); - } - } //else { DPRINT << "total_h here =" << total_h << ENDL(); } //do nothing. let garbage rows be in l1 - l1_addr_offset += read_size_bytes; - in_w_offset_within_kernel_window += 1; - } // for block width - // pad 0s for block padding on the right side of block.. only first conv since C%32 != 0.. 
ifdef with compile time define - #ifdef ACT_BLOCK_WIDTH_PADDING_BYTES - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - pad_l1_buffer_with_zeroes(dst_addr, (uint32_t) ACT_BLOCK_WIDTH_PADDING_BYTES); - l1_addr_offset += (uint32_t) ACT_BLOCK_WIDTH_PADDING_BYTES; - #endif - if(out_w < conv_output_size_w - 1) { - out_w += 1; - } else { - out_w = 0; - //DPRINT << "total_h=" << total_h << ENDL(); - if (out_h < conv_output_size_h - 1) { - out_h += 1; - } else if (total_h < act_matrix_height_unpadded){ - // next image in batch - out_h = 0; - n += 1; - //DPRINT << "next image, n=" << n << ENDL(); - } - } - total_h += 1; - } // for block height - in_h_offset_within_kernel_window += 1; - noc_async_read_barrier(); - cb_push_back(cb_id_act, act_block_num_tiles); - } // for num of act blocks in inner width dim - } // for num of weight blocks in width dim - out_h_start = out_h; - out_w_start = out_w; - total_h_start = total_h; - n_start = n; - } // for num of act blocks in height dim -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_for_col_major_conv_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_for_col_major_conv_out_blocks.cpp deleted file mode 100644 index 6374673b9c0..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_for_col_major_conv_out_blocks.cpp +++ /dev/null @@ -1,198 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -inline void pad_l1_buffer_with_zeroes(uint32_t l1_addr, uint32_t pad_size_bytes) { - volatile std::uint32_t* dst = reinterpret_cast(l1_addr); - volatile std::uint32_t* end_dst = dst + (pad_size_bytes >> 2); // Divide by 4 using right shift - - while (dst < end_dst) { - *dst++ = 0; - } - - uint32_t remainder = pad_size_bytes & 0x3; // Get the remainder using bitwise AND - if (remainder != 0) { - volatile std::uint8_t* byte_dst = reinterpret_cast(dst); - for (uint32_t i = 0; i < remainder; ++i) { - *byte_dst++ = 0; - } - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - uint32_t act_dram_noc_x = get_arg_val(i); i+=1; - uint32_t act_dram_noc_y = get_arg_val(i); i+=1; - - uint32_t conv_act_size_w_ = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t conv_act_size_c_ = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t stride_h_ = get_arg_val(i); i+=1; - uint32_t stride_w_ = get_arg_val(i); i+=1; - uint32_t pad_h = get_arg_val(i); i+=1; - uint32_t pad_w = get_arg_val(i); i+=1; - uint32_t conv_output_size_h = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - uint32_t num_blocks_weight_w = get_arg_val(i); i+=1; - uint32_t num_groups = get_arg_val(i); i+=1; - - uint32_t act_matrix_height_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_width_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_height = get_arg_val(i); i+=1; - uint32_t act_matrix_width = get_arg_val(i); i+=1; - uint32_t act_matrix_height_ntiles = get_arg_val(i); i+=1; - uint32_t act_matrix_width_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_w_datums = get_arg_val(i); 
i+=1; - uint32_t act_block_h_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_w_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t act_w_num_outer = get_arg_val(i); i+=1; - uint32_t src_dram_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t dst_l1_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t n_start = get_arg_val(i); i+=1; - uint32_t out_h_start = get_arg_val(i); i+=1; - uint32_t out_w_start = get_arg_val(i); i+=1; - uint32_t total_h_start = get_arg_val(i); i+=1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_c_read_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - const InterleavedPow2AddrGenFast s_act = { - .bank_base_address = act_addr_dram_base, - .log_base_2_of_page_size = log_base_2_of_conv_act_size_c_bytes - }; - - // Assumptions. Must be true. Validate on host. - // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - // DPRINT << "Running new conv reader" << ENDL(); - // DPRINT << "act matrix h unpadded " << act_matrix_height_unpadded << ENDL(); - // DPRINT << "num_blocks_act_h " << num_blocks_act_h << ENDL(); - // DPRINT << "act_block_h_datums " << act_block_h_datums << ENDL(); - // DPRINT << "num_blocks_weight_w " << num_blocks_weight_w << ENDL(); - // DPRINT << "num_blocks_act_w " << num_blocks_act_w << ENDL(); - // Outer loop is number of blocks in weight width dim - // Conv output blocks are computed in col major order - - uint32_t read_size_bytes = conv_act_c_read_bytes; - for(uint32_t nbr = 0; nbr < num_blocks_weight_w; nbr++) { - uint32_t out_h = out_h_start; - uint32_t out_w = out_w_start; - uint32_t out_h_reset = out_h_start; - uint32_t out_w_reset = out_w_start; - uint32_t total_h = total_h_start; - uint32_t total_h_reset = total_h_start; - uint32_t n = n_start; - uint32_t n_reset = n_start; - - uint32_t in_h_offset_within_kernel_window_start = 0; - for(uint32_t nbh = 0; nbh < num_blocks_act_h; nbh++) { - - uint32_t in_h_offset_within_kernel_window = in_h_offset_within_kernel_window_start; - uint32_t act_w_offset_bytes = 0; - #ifdef ACT_W_OUTER_BLOCKS // Adding an additional loop here when not needed seems to add about 10k ns - for(uint32_t act_w_outer_i = 0; act_w_outer_i < act_w_num_outer; act_w_outer_i++) { - #endif - for (uint32_t channel_stick_h = 0; channel_stick_h < weight_size_h; channel_stick_h++) { - uint32_t in_w_offset_within_kernel_window = 0; - for (uint32_t channel_stick_w = 0; channel_stick_w < weight_size_w; channel_stick_w++) { - out_h = out_h_reset; - out_w = out_w_reset; - total_h = total_h_reset; - n = n_reset; - 
cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t l1_addr_offset = 0; - for(uint32_t bh = 0; bh < act_block_h_datums; bh++) { - uint32_t in_h_offset = out_h * stride_h; - uint32_t in_w_offset = out_w * stride_w; // expect stride 1 or 2.. make this compile time args - also conv input width - - if (total_h < act_matrix_height_unpadded) { - uint32_t in_h = in_h_offset + in_h_offset_within_kernel_window; - uint32_t in_w = in_w_offset + in_w_offset_within_kernel_window; - - if(in_h < pad_h || in_w < pad_w || in_h >= (conv_act_size_h + pad_h) || in_w >= (conv_act_size_w_ + pad_w)) { - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - uint32_t pad_size_bytes = read_size_bytes; - pad_l1_buffer_with_zeroes(dst_addr, pad_size_bytes); - } else { - // read one channel from dram multi bank - row_id = channel_id - uint32_t in_h_raw = in_h - pad_h; - uint32_t in_w_raw = in_w - pad_w; - uint32_t channel_id = (n * conv_act_size_h * conv_act_size_w) + (in_h_raw * conv_act_size_w) + in_w_raw; - - //DPRINT << "n=" << n << " h=" << in_h_raw << " w=" << in_w_raw << " conv_act_size_h=" << conv_act_size_h << " conv_act_size_w=" << conv_act_size_w << ENDL(); - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - s_act.noc_async_read_partial_page(channel_id, dst_addr, read_size_bytes, act_w_offset_bytes); - } - } //else { DPRINT << "total_h here =" << total_h << ENDL(); } //do nothing. let garbage rows be in l1 - l1_addr_offset += read_size_bytes; - if(out_w < conv_output_size_w - 1) { - out_w += 1; - } else { - out_w = 0; - if (out_h < conv_output_size_h - 1) { - out_h += 1; - } else if (total_h < act_matrix_height_unpadded){ - // next image in batch - out_h = 0; - n += 1; - } - } - total_h += 1; - } // for block height - in_w_offset_within_kernel_window += 1; - //DPRINT << "waiting on read barrier" << ENDL(); - noc_async_read_barrier(); - //DPRINT << "done on read barrier" << ENDL(); - - cb_push_back(cb_id_act, act_block_num_tiles); - } // for filter window width - in_h_offset_within_kernel_window += 1; - } // for filter window height - in_h_offset_within_kernel_window = in_h_offset_within_kernel_window_start; - act_w_offset_bytes += read_size_bytes; - #ifdef ACT_W_OUTER_BLOCKS - } // for act_w_outer_i - #endif - - out_h_reset = out_h; - out_w_reset = out_w; - total_h_reset = total_h; - n_reset = n; - - in_h_offset_within_kernel_window_start = in_h_offset_within_kernel_window; - } // for num of act blocks in height dim - } // for num of weight blocks in width dim -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_resnet50_first_conv.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_resnet50_first_conv.cpp deleted file mode 100644 index d96f7acb84f..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_resnet50_first_conv.cpp +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - - uint32_t conv_act_size_c = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t in_h_start = get_arg_val(i); i+=1; - uint32_t out_w_start = get_arg_val(i); i+=1; - uint32_t last_start_in_h_curr_image = get_arg_val(i); i+=1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - // 5,6 not used - constexpr uint32_t extra_padding_for_32B_alignment = get_compile_time_arg_val(7); - //constexpr uint32_t act_block_width_padding_bytes = get_compile_time_arg_val(1); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - uint32_t channel_stick_size = conv_act_size_c; - uint32_t channel_stick_size_bytes = channel_stick_size << 1; - - const InterleavedPow2AddrGenFast s_act = { - .bank_base_address = act_addr_dram_base, - //.log_base_2_of_page_size = 5 // TODO: send as a compile-time arg, currently C=16 in FP16_B (so 32 B) - //.log_base_2_of_page_size = 13 // TODO: send as a compile-time arg, currently C=16 x W=256 in FP16_B = 8192 - .log_base_2_of_page_size = 11 // TODO: send as a compile-time arg, currently C=4 x W=256 in FP16_B = 2048 - }; - uint32_t read_size_bytes = channel_stick_size_bytes << 3; // channel stick size * 8 - // Assumptions. Must be true. Validate on host. 
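
Several of these readers hard-code tile_size_pow2_exponent = 11, which is consistent with a 32x32 bfloat16 tile (32 * 32 * 2 bytes = 2048 = 2^11). A small compile-time check of that arithmetic:

```cpp
// Sketch only: confirms that tile_size_pow2_exponent = 11 corresponds to a
// 32x32 tile of 2-byte (bfloat16) elements.
#include <cstdint>

int main() {
    constexpr uint32_t tile_height = 32;
    constexpr uint32_t tile_width = 32;
    constexpr uint32_t bytes_per_element = 2;  // bfloat16 / FP16_B
    constexpr uint32_t tile_bytes = tile_height * tile_width * bytes_per_element;  // 2048
    constexpr uint32_t tile_size_pow2_exponent = 11;
    static_assert(tile_bytes == (1u << tile_size_pow2_exponent), "tile size mismatch");
    return 0;
}
```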
- // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - - constexpr uint32_t in_w_padded_for_32_alignment = 231 + extra_padding_for_32B_alignment; - - uint32_t in_h = in_h_start; - uint32_t in_h_reset = in_h; - uint32_t out_w = out_w_start; - uint32_t out_w_reset = out_w; - uint32_t page_offset_h_2d_matrix = out_w_start * (channel_stick_size_bytes << 1); - uint32_t page_offset_h_2d_matrix_reset = page_offset_h_2d_matrix; - uint32_t last_start_in_h_stride = 222; - uint32_t last_start_in_h_curr_image_reset = last_start_in_h_curr_image; - for(uint32_t nbh = 0; nbh < num_blocks_act_h; nbh++) { - uint32_t c_id_offset_inter_block_col = 0; - uint32_t page_id_offset_inter_block_w = 0; - for (uint32_t nbw = 0; nbw < num_blocks_act_w; nbw++) { - out_w = out_w_reset; - in_h = in_h_reset; - page_offset_h_2d_matrix = page_offset_h_2d_matrix_reset; - last_start_in_h_curr_image = last_start_in_h_curr_image_reset; - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t l1_addr_offset = 0; - for(uint32_t bh = 0; bh < act_block_h_datums; bh++) { - //uint32_t c_id_offset_inra_block_col = 0; - - // channel_stick * filter window width is contiguous in page - uint32_t page_id = in_h + page_id_offset_inter_block_w; - uint32_t page_offset = page_offset_h_2d_matrix; - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - s_act.noc_async_read_partial_page(page_id, dst_addr, read_size_bytes, page_offset); - l1_addr_offset += read_size_bytes; - if(out_w < conv_output_size_w - 1) { - out_w += 1; - //first_c_id_in_2d_row += 2; // channel id stride in the w dimension - page_offset_h_2d_matrix += (channel_stick_size_bytes << 1); // * 2 for conv stride in the w dimension - } else { - out_w = 0; - page_offset_h_2d_matrix = 0; - if (in_h < last_start_in_h_curr_image) { - in_h += 2; // stride_h - } else { - // next image in batch - // stride in_h for next image.. assume shape is 1, N*H, W, C, in_h represents h coordinate in this shape. - in_h += 8; - last_start_in_h_curr_image = in_h + last_start_in_h_stride; - } - } - } // for block height - c_id_offset_inter_block_col += in_w_padded_for_32_alignment; - page_id_offset_inter_block_w += 1; - noc_async_read_barrier(); - cb_push_back(cb_id_act, act_block_num_tiles); - } // for num of act blocks in inner width dim - out_w_reset = out_w; - in_h_reset = in_h; - page_offset_h_2d_matrix_reset = page_offset_h_2d_matrix; - last_start_in_h_curr_image_reset = last_start_in_h_curr_image; - } // for num of act blocks in height dim -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_without_conv_padding.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_without_conv_padding.cpp deleted file mode 100644 index 9cec0f43c7f..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_without_conv_padding.cpp +++ /dev/null @@ -1,146 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -inline void pad_l1_buffer_with_zeroes(uint32_t l1_addr, uint32_t pad_size_bytes) { - volatile std::uint32_t* dst = reinterpret_cast(l1_addr); - volatile std::uint32_t* end_dst = dst + (pad_size_bytes >> 2); // Divide by 4 using right shift - - while (dst < end_dst) { - *dst++ = 0; - } - - uint32_t remainder = pad_size_bytes & 0x3; // Get the remainder using bitwise AND - if (remainder != 0) { - volatile std::uint8_t* byte_dst = reinterpret_cast(dst); - for (uint32_t i = 0; i < remainder; ++i) { - *byte_dst++ = 0; - } - } -} - -void kernel_main() { - uint32_t i = 0; - uint32_t act_addr_dram_base = get_arg_val(i); i+=1; - uint32_t act_dram_noc_x = get_arg_val(i); i+=1; - uint32_t act_dram_noc_y = get_arg_val(i); i+=1; - - uint32_t conv_act_size_w_ = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t conv_act_size_c_ = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - uint32_t stride_h_ = get_arg_val(i); i+=1; - uint32_t stride_w_ = get_arg_val(i); i+=1; - uint32_t pad_h = get_arg_val(i); i+=1; - uint32_t pad_w = get_arg_val(i); i+=1; - uint32_t conv_output_size_h = get_arg_val(i); i+=1; - uint32_t conv_output_size_w = get_arg_val(i); i+=1; - uint32_t num_blocks_act_h = get_arg_val(i); i+=1; - uint32_t num_blocks_act_w = get_arg_val(i); i+=1; - uint32_t num_blocks_weight_w = get_arg_val(i); i+=1; - uint32_t num_groups = get_arg_val(i); i+=1; - - uint32_t act_matrix_height_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_width_unpadded = get_arg_val(i); i+=1; - uint32_t act_matrix_height = get_arg_val(i); i+=1; - uint32_t act_matrix_width = get_arg_val(i); i+=1; - uint32_t act_matrix_height_ntiles = get_arg_val(i); i+=1; - uint32_t act_matrix_width_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_h_datums = get_arg_val(i); i+=1; - uint32_t act_block_w_datums = get_arg_val(i); i+=1; - uint32_t act_block_h_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_w_ntiles = get_arg_val(i); i+=1; - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - uint32_t src_dram_act_buffer_size_bytes = get_arg_val(i); i+=1; - uint32_t dst_l1_act_buffer_size_bytes = get_arg_val(i); i+=1; - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_size_c_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t tile_size_pow2_exponent = 11; - const DataFormat data_format = get_dataformat(cb_id_act); - const InterleavedPow2AddrGenFast s_act = { - .bank_base_address = act_addr_dram_base, - .log_base_2_of_page_size = log_base_2_of_conv_act_size_c_bytes - }; - - // Assumptions. Must be true. Validate on host. 
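
pad_l1_buffer_with_zeroes, repeated in several of these readers, fills whole 32-bit words first and then clears a sub-4-byte tail byte by byte. A host-testable sketch of the same split, written against an ordinary buffer instead of an L1 address:

```cpp
// Sketch only: same word-then-byte zero-fill split as pad_l1_buffer_with_zeroes,
// exercised on a plain byte buffer so it can run on the host.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

static void pad_buffer_with_zeroes(uint8_t* buf, uint32_t pad_size_bytes) {
    uint32_t full_words = pad_size_bytes >> 2;  // number of whole 32-bit words
    uint32_t remainder = pad_size_bytes & 0x3;  // leftover bytes (0..3)
    uint32_t zero = 0;
    for (uint32_t w = 0; w < full_words; ++w) {
        std::memcpy(buf + 4 * w, &zero, 4);     // word-wide stores
    }
    for (uint32_t b = 0; b < remainder; ++b) {
        buf[4 * full_words + b] = 0;            // byte-wide tail
    }
}

int main() {
    std::vector<uint8_t> buf(10, 0xff);
    pad_buffer_with_zeroes(buf.data(), 10);     // 2 words + 2 tail bytes
    for (uint8_t byte : buf) assert(byte == 0);
    return 0;
}
```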
- // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - uint32_t out_h = 0; - uint32_t out_w = 0; - uint32_t out_h_start = 0; - uint32_t out_w_start = 0; - //DPRINT << "Running new conv reader" << ENDL(); - for(uint32_t nbh = 0; nbh < num_blocks_act_h; nbh++) { - for(uint32_t nbr = 0; nbr < num_blocks_weight_w; nbr++) { - uint32_t in_h_offset_within_kernel_window = 0; - for (uint32_t nbw = 0; nbw < num_blocks_act_w; nbw++) { - out_h = out_h_start; - out_w = out_w_start; - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t l1_addr_offset = 0; - for(uint32_t bh = 0; bh < act_block_h_datums; bh++) { - uint32_t in_h_offset = out_h * stride_h; - uint32_t in_w_offset = out_w * stride_w; // expect stride 1 or 2.. make this compile time args - also conv input width - uint32_t in_w_offset_within_kernel_window = 0; - for(uint32_t bw = 0; bw < weight_size_w; bw++) { - uint32_t read_size_bytes = conv_act_size_c_bytes; - #ifdef ACT_BLOCK_HEIGHT_PADDING - if (out_h < conv_output_size_h) { - #endif - uint32_t in_h = in_h_offset + in_h_offset_within_kernel_window; - uint32_t in_w = in_w_offset + in_w_offset_within_kernel_window; - - // read one channel from dram multi bank - row_id = channel_id - uint32_t channel_id = (in_h * conv_act_size_w) + in_w; - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - s_act.noc_async_read_page(channel_id, dst_addr); - #ifdef ACT_BLOCK_HEIGHT_PADDING - } // else { do nothing. let garbage rows be in l1 } - #endif - l1_addr_offset += read_size_bytes; - in_w_offset_within_kernel_window += 1; - } // for block width - // pad 0s for block padding on the right side of block.. only first conv since C%32 != 0.. ifdef with compile time define - #ifdef ACT_BLOCK_WIDTH_PADDING_BYTES - // pad 0s in l1 - uint32_t dst_addr = l1_write_addr_act + l1_addr_offset; - pad_l1_buffer_with_zeroes(dst_addr, (uint32_t) ACT_BLOCK_WIDTH_PADDING_BYTES); - l1_addr_offset += (uint32_t) ACT_BLOCK_WIDTH_PADDING_BYTES; - #endif - if(out_w < conv_output_w_last_index) { - out_w += 1; - } else { - out_h += 1; - out_w = 0; - } - } // for block height - in_h_offset_within_kernel_window += 1; - noc_async_read_barrier(); - cb_push_back(cb_id_act, act_block_num_tiles); - } // for num of act blocks in inner width dim - } // for num of weight blocks in width dim - out_h_start = out_h; - out_w_start = out_w; - } // for num of act blocks in height dim -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights.cpp deleted file mode 100644 index 5b41a25d49a..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights.cpp +++ /dev/null @@ -1,220 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - - -void kernel_main() { - uint32_t i = 0; - uint32_t conv_act_size_w = get_arg_val(i); i+=1; - uint32_t conv_act_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_h = get_arg_val(i); i+=1; - uint32_t weight_size_w = get_arg_val(i); i+=1; - - uint32_t act_num_blocks_h = get_arg_val(i); i+=1; - // inner loop bounds as compile-time args improve pef - // uint32_t act_block_h_datums = get_arg_val(i); i+=1; - // i+=1; // skip an arg - - uint32_t act_block_num_tiles = get_arg_val(i); i+=1; - - uint32_t first_partial_right_aligned_row_width = get_arg_val(i); i+=1; - uint32_t skip_after_partial_right_aligned_row = get_arg_val(i); i+=1; - uint32_t first_partial_image_num_rows = get_arg_val(i); i+=1; - uint32_t skip_after_first_partial_image_row = get_arg_val(i); i+=1; - uint32_t num_full_images = get_arg_val(i); i+=1; - uint32_t skip_after_full_image = get_arg_val(i); i+=1; - uint32_t last_partial_image_num_rows = get_arg_val(i); i+=1; - uint32_t last_partial_left_aligned_row_width = get_arg_val(i); i+=1; - - // moved these to compile-time args - // uint32_t window_outer = get_arg_val(i); i+=1; - // uint32_t window_inner = get_arg_val(i); i+=1; - i+=2; // skip 2 rt args - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t stride_h = get_compile_time_arg_val(1); - constexpr uint32_t stride_w = get_compile_time_arg_val(2); - constexpr uint32_t conv_act_size_w_ = get_compile_time_arg_val(3); - constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1; - constexpr uint32_t conv_act_c_read_bytes = get_compile_time_arg_val(5); - constexpr uint32_t log_base_2_of_conv_act_size_c_bytes = get_compile_time_arg_val(6); - // TODO delete unused: get_compile_time_arg_val(7); (8), (9) - // need to have these as compile-time, they are inner loop bouds / unroll loops / constexpr conditionals based on them - constexpr uint32_t window_outer = get_compile_time_arg_val(10); - constexpr uint32_t window_inner = get_compile_time_arg_val(11); - constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(12); - - constexpr uint32_t cb_id_act = 0; - constexpr uint32_t cb_id_sharded_act = 3; - - // Assumptions. Must be true. Validate on host. 
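
The comment above notes why window_outer, window_inner, and act_block_h_datums were promoted to compile-time args: with constexpr bounds the compiler can fully unroll or eliminate the inner loops, as read_channels earlier in this diff does with #pragma GCC unroll. A minimal sketch of that pattern, with a simple accumulation standing in for the per-row NoC read:

```cpp
// Sketch only: a compile-time loop bound plus GCC's unroll pragma, mirroring
// the WINDOW_INNER unroll in read_channels. The sum stands in for the NoC read.
#include <cstdint>

static uint32_t sum_window_rows(const uint32_t* row_values) {
    constexpr uint32_t window_inner = 3;  // compile-time bound, e.g. 3 rows of a 3x3 filter
    uint32_t sum = 0;
    #pragma GCC unroll 3
    for (uint32_t inner = 0; inner < window_inner; ++inner) {
        sum += row_values[inner];         // placeholder for the per-window-row read
    }
    return sum;
}

int main() {
    uint32_t rows[3] = {1, 2, 3};
    return sum_window_rows(rows) == 6 ? 0 : 1;
}
```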
- // assert(act_block_w_datums == C * weight_size_w) - // assert(num_blocks_act_w == weight_size_h) - // assert(act_block_w_datums % C == 0) - // assert(act_block_w_datums % 32 == 0) - // assert(act_block_h_datums % 32 == 0) - // assert(act_block_h_ntiles == act_block_h_datums/32) - // assert(act_block_w_ntiles == act_block_w_datums/32) - // assert(act_block_num_tiles == (act_block_h_datums * act_block_w_datums)/1024) - - // LOOP TO FILL READER INDICES - constexpr uint32_t cb_reader_indices = tt::CB::c_in4; - volatile tt_l1_ptr uint32_t* reader_indices_ptr = reinterpret_cast(get_write_ptr(cb_reader_indices)); - - uint32_t weights_top_left_corner_idx = 0; - uint32_t reader_idx = 0; - - // First partial right-aligned row - for (uint32_t k = 0; k < first_partial_right_aligned_row_width; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += skip_after_partial_right_aligned_row; // Skip padded width - - // First partial image - for (uint32_t j = 0; j < first_partial_image_num_rows; j++) { - for (uint32_t k = 0; k < conv_act_size_w_; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += weight_size_w - 1; - } - weights_top_left_corner_idx += skip_after_first_partial_image_row; // Skip padded rows - - // Full images - for (uint32_t i = 0; i < num_full_images; i++) { - for (uint32_t j = 0; j < conv_act_size_h; j++) { - for (uint32_t k = 0; k < conv_act_size_w; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += weight_size_w - 1; - } - weights_top_left_corner_idx += skip_after_full_image; // Skip padded rows - } - - // Last partial image - for (uint32_t j = 0; j < last_partial_image_num_rows; j++) { - for (uint32_t k = 0; k < conv_act_size_w; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - weights_top_left_corner_idx += weight_size_w - 1; - } - - // Last partial left-alighted row - for (uint32_t k = 0; k < last_partial_left_aligned_row_width; k++) { - reader_indices_ptr[reader_idx++] = weights_top_left_corner_idx++; - } - - - // LOOP TO FILL READER OFFSETS - /* We can add another loop to read chunks of a stick as well. 
- * - Duplicate reader_offset for same stick X times (window_inner must be 1) - * - New loop between outer and inner that loops X times reading from same stick - * - Read conv_act_c_read_bytes / X each time - * - Update l1_write_addr_act by conv_act_c_read_bytes - */ - constexpr uint32_t cb_reader_offsets = tt::CB::c_in5; - volatile tt_l1_ptr uint32_t* reader_offsets_ptr = reinterpret_cast(get_write_ptr(cb_reader_offsets)); - uint32_t reader_offset = 0; // Constant offset for each pixel within filter window - uint32_t reader_offset_idx = 0; - for (uint32_t channel_stick_h = 0; channel_stick_h < weight_size_h; channel_stick_h++) { - for (uint32_t channel_stick_w = 0; channel_stick_w < weight_size_w; channel_stick_w++) { - reader_offsets_ptr[reader_offset_idx++] = reader_offset++; - } - // -1 to go back to previous reader_offset - reader_offset += conv_act_size_w - 1; // Assuming (weight_size_w - 1) / 2 == pad_w - } - - - // TODO: need to make the read coalescing optimization cleaner - // pass coalesce_window_inner_reads as a compile time arg and num_coalesced_reads so we can constexpr the if - // currently works for the case of num_coalesced_reads == weight_size_w since these reads are contiguous on both src/dst side - // we check if window_inner == weight_size_w to make sure coalescing is legal along full window_inner so the loop can be removed - constexpr bool coalesce_window_inner_reads = true; - constexpr uint32_t num_coalesced_reads = 3; - constexpr uint32_t coalesced_read_bytes = num_coalesced_reads * conv_act_c_read_bytes; - // the conditional selecting between coalescing and no-colescing must be constexpr to that compiler can optimized the other path away - // this has shown to be a big perf win - if constexpr (coalesce_window_inner_reads and window_inner == num_coalesced_reads) { - // coalesce reads along weight_size_w - reader_offset_idx = 0; - uint32_t act_l1_offset = 0; - uint32_t act_l1_read_addr = get_read_ptr(cb_id_sharded_act); - - static_assert(coalesced_read_bytes <= NOC_MAX_BURST_SIZE); - // set_state uses just x/y from the get_noc_addr, addr is ignored - noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes); - uint32_t start_reader_idx = 0; - for (uint32_t bh = 0; bh < act_num_blocks_h; bh++) { - for (uint32_t outer = 0; outer < window_outer; outer++) { - // Reset reader_idx to finish act_block_h_datums - reader_idx = start_reader_idx; - - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - uint32_t reader_offset = act_l1_read_addr + (reader_offsets_ptr[reader_offset_idx] << log_base_2_of_conv_act_size_c_bytes); - // #pragma GCC unroll 4 // unroll didn't help, but act_block_h_datums (loop bound) being const does help - for (uint32_t bhd = 0; bhd < act_block_h_datums; bhd++) { - // local read from reader_index + reader_offset; - act_l1_offset = reader_offset + (reader_indices_ptr[reader_idx] << log_base_2_of_conv_act_size_c_bytes); - noc_async_read_one_packet_with_state(act_l1_offset, l1_write_addr_act); - l1_write_addr_act += coalesced_read_bytes; - reader_idx++; - } - noc_async_read_barrier(); - cb_push_back(cb_id_act, act_block_num_tiles); - - reader_offset_idx += window_inner; - } - reader_offset_idx = 0; - start_reader_idx = reader_idx; - } - - } else { - // no coalescing of reads - reader_offset_idx = 0; - uint32_t act_l1_offset = 0; - uint32_t act_l1_read_addr = get_read_ptr(cb_id_sharded_act); - - static_assert(conv_act_c_read_bytes <= NOC_MAX_BURST_SIZE); - // set_state 
uses just x/y from the get_noc_addr, addr is ignored - noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), conv_act_c_read_bytes); - - uint32_t start_reader_idx = 0; - for (uint32_t bh = 0; bh < act_num_blocks_h; bh++) { - for (uint32_t outer = 0; outer < window_outer; outer++) { - // Reset reader_idx to finish act_block_h_datums - reader_idx = start_reader_idx; - cb_reserve_back(cb_id_act, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act); - for (uint32_t bhd = 0; bhd < act_block_h_datums; bhd++) { - // when no read coalesing, main use case is window_inner == 1, - // and if window_inner is const this loop should be removed by the compiler - for (uint32_t inner = 0; inner < window_inner; inner++) { - // local read from reader_index + reader_offset; - act_l1_offset = act_l1_read_addr + ((reader_indices_ptr[reader_idx] + reader_offsets_ptr[reader_offset_idx + inner]) << log_base_2_of_conv_act_size_c_bytes); - noc_async_read_one_packet_with_state(act_l1_offset, l1_write_addr_act); - l1_write_addr_act += conv_act_c_read_bytes; - - } - reader_idx++; - } - noc_async_read_barrier(); - cb_push_back(cb_id_act, act_block_num_tiles); - - reader_offset_idx += window_inner; - reader_offset_idx += window_inner; - reader_offset_idx += window_inner; - } - reader_offset_idx = 0; - start_reader_idx = reader_idx; - } - } -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_receiver_weights_resnet50_first_conv_tiled_out.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_receiver_weights_resnet50_first_conv_tiled_out.cpp deleted file mode 100644 index 8e4c45c77b4..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_receiver_weights_resnet50_first_conv_tiled_out.cpp +++ /dev/null @@ -1,169 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - - -void kernel_main() { - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias args. Unused if bias fusion is not enabled. 
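// The receiver-side weight handshake that this kernel (and the other mcast receiver writers in
// this diff) repeats for every weight block boils down to: reserve CB space, mark the local
// "weights valid" semaphore INVALID, bump the sender's ready counter, then block until the sender
// flips the local semaphore to VALID after the multicast. A minimal sketch of that sequence, using
// only dataflow_api calls that already appear in this file; the helper name and parameter names
// are illustrative (a free helper like this would sit at file scope, next to kernel_main):
FORCE_INLINE void receive_mcast_weight_block(uint32_t cb_id, uint32_t num_tiles,
                                             uint32_t sender_noc_x, uint32_t sender_noc_y,
                                             uint32_t sender_semaphore_addr,
                                             volatile tt_l1_ptr uint32_t* receiver_semaphore_ptr) {
    cb_reserve_back(cb_id, num_tiles);                   // space for the incoming block
    noc_semaphore_set(receiver_semaphore_ptr, INVALID);  // block not received yet
    // tell the sender this core is ready to receive
    noc_semaphore_inc(get_noc_addr(sender_noc_x, sender_noc_y, sender_semaphore_addr), 1);
    noc_semaphore_wait(receiver_semaphore_ptr, VALID);   // sender sets VALID after mcasting the data
    cb_push_back(cb_id, num_tiles);                      // hand the block to compute
}
// (the original kernel's bias and output runtime arguments continue below)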
- const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_num_outer = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias args. Unused if bias fusion is not enabled. - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - // mcast args - uint32_t weights_mcast_sender_noc_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_noc_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - const InterleavedAddrGenFast s = { - .bank_base_address = out_addr, - .page_size = tile_nbytes, - .data_format = out_df - }; - - // MCAST RECEIVE WEIGHTS - // read weight blocks inner dim - // read weight slice - 1 block of weights in width dim and full weight matrix height - // read slice only once for all activation blocks - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - // Set weights semaphore value to INVALID - noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr); - noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1); - - // wait on weights semaphore value to become VALID (set by mcast sender after it 
multicasts data) - noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID); - - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h - - // first read in bias if enabled (done only once for all blocks) - #ifdef FUSE_BIAS - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - cb_reserve_back(bias_cb_id, bias_ntiles); - - // Set weights semaphore value to INVALID - noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr); - noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1); - - // wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID); - - cb_push_back(bias_cb_id, bias_ntiles); - #endif - - #ifndef SHARDED_OUT - uint32_t out_block_h_start_tile_id = out_start_tile_id; - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - uint32_t out_block_w_start_tile_id = out_block_h_start_tile_id; - uint32_t out_block_w_start_tile_id_w = 0; - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - - uint32_t out_sbh_start_tile_id = out_block_w_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - s.noc_async_write_tile(out_tile_id, l1_read_addr); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." 
<< ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - } // out_num_blocks_w - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - - #else - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_sender_weights_resnet50_first_conv_tiled_out.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_sender_weights_resnet50_first_conv_tiled_out.cpp deleted file mode 100644 index 116232d0655..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_sender_weights_resnet50_first_conv_tiled_out.cpp +++ /dev/null @@ -1,270 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - - -void kernel_main() { - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias args. Unused if bias fusion is not enabled. - const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_num_outer = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias args. Unused if bias fusion is not enabled. 
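// The sender-side counterpart used further down in this kernel can be sketched the same way: wait
// until every receiver has incremented the sender semaphore, reset it, multicast the staged block,
// then multicast the VALID flag. Only calls already used in this file appear here; the helper name
// and parameters are illustrative, and a helper like this would live at file scope:
FORCE_INLINE void mcast_block_to_receivers(uint32_t local_l1_addr, uint32_t size_bytes,
                                           uint64_t data_mcast_noc_addr,
                                           uint32_t num_dests, uint32_t num_cores,
                                           volatile tt_l1_ptr uint32_t* sender_semaphore_ptr,
                                           uint32_t receiver_semaphore_l1_addr,
                                           uint64_t receiver_semaphore_mcast_noc_addr) {
    noc_semaphore_wait(sender_semaphore_ptr, num_dests);  // all receivers reported ready
    noc_semaphore_set(sender_semaphore_ptr, 0);           // reset for the next block
    // multicast the data; the destination count excludes the sender since this is not a local copy
    noc_async_write_multicast(local_l1_addr, data_mcast_noc_addr, size_bytes, num_cores, true, true);
    // flip every receiver's semaphore to VALID (the sender's local copy was pre-set to VALID)
    noc_semaphore_set_multicast(receiver_semaphore_l1_addr, receiver_semaphore_mcast_noc_addr, num_cores);
}
// (the bias arguments of the original sender kernel resume below)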
- const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - // mcast args - uint32_t weights_mcast_dest_noc_start_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_start_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_end_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_end_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_num_dests = get_arg_val(i); i+=1; - uint32_t weights_mcast_num_cores = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - - #ifndef SKIP_MCAST - // Set ur local VALID value, to be mcasted to destinations flag address after the data has been mcasted - volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); - *(weights_mcast_receiver_semaphore_addr_ptr) = VALID; - // local address that will be atomically incremented by mcast receivers, to know when all receivers are ready - // to receive the mcast - volatile tt_l1_ptr uint32_t* weights_mcast_sender_semaphore_addr_ptr = reinterpret_cast(weights_mcast_sender_semaphore_addr); - - uint64_t weights_mcast_receiver_semaphore_noc_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - weights_mcast_receiver_semaphore_addr); - #endif - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - const InterleavedAddrGenFast s = { - .bank_base_address = out_addr, - .page_size = tile_nbytes, - .data_format = out_df - }; - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const DataFormat weight_df = get_dataformat(cb_id_weight); - const InterleavedAddrGenFast s_weight = { - .bank_base_address = weight_addr_dram_base, - .page_size = weight_tile_nbytes, - .data_format = weight_df - }; - - // TODO: Review to support weight_block_height_num_outer - // READ WEIGHTS + MCAST SEND WEIGHTS - // read weight blocks inner dim - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_start_tile_id = 0; - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - - // mcast args - uint32_t weights_start_address = weight_write_l1_addr; - uint32_t weights_block_size_bytes = 0; - - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - s_weight.noc_async_read_tile(weight_tile_id, weight_write_l1_addr); - weight_write_l1_addr += weight_tile_nbytes; - weights_block_size_bytes += weight_tile_nbytes; - weight_tile_id += 1; 
- } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - #ifndef SKIP_MCAST - // wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests); - noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint64_t weights_multicast_data_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - weights_start_address); - // num_dests must not include source, since we are NOT really doing a local copy! - noc_async_write_multicast(weights_start_address, weights_multicast_data_addr, weights_block_size_bytes, weights_mcast_num_cores, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). - - // We should also multicast the flag to destinations - // num_dests must not include source, since we are NOT really doing a local copy! - noc_semaphore_set_multicast(weights_mcast_receiver_semaphore_addr, weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); - #endif - - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h - - - // first read in bias if enabled (done only once for all blocks) - #ifdef FUSE_BIAS - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(4) == 1; - - const uint32_t bias_pagesize = get_tile_size(bias_cb_id); - const DataFormat bias_df = get_dataformat(bias_cb_id); - const InterleavedAddrGenFast s_bias = { - .bank_base_address = bias_addr, - .page_size = bias_pagesize, - .data_format = bias_df - }; - - cb_reserve_back(bias_cb_id, bias_ntiles); - uint32_t bias_l1_addr = get_write_ptr(bias_cb_id); - - // mcast args - uint32_t bias_start_address = bias_l1_addr; - uint32_t bias_block_size_bytes = 0; - for (uint32_t bias_tile = 0; bias_tile < bias_ntiles; ++ bias_tile) { - s_bias.noc_async_read_tile(bias_tile, bias_l1_addr); - bias_l1_addr += bias_pagesize; - bias_block_size_bytes += bias_pagesize; - } - noc_async_read_barrier(); - - // MCAST BIAS (shares some mcast args with weights) - #ifndef SKIP_MCAST - // wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests); - noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint64_t bias_multicast_data_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - bias_start_address); - // num_dests must not include source, since we are NOT really doing a local copy! 
- noc_async_write_multicast(bias_start_address, bias_multicast_data_addr, bias_block_size_bytes, weights_mcast_num_cores, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). - - // We should also multicast the flag to destinations - // num_dests must not include source, since we are NOT really doing a local copy! - noc_semaphore_set_multicast(weights_mcast_receiver_semaphore_addr, weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); - #endif - - cb_push_back(bias_cb_id, bias_ntiles); - #endif - - #ifndef SHARDED_OUT - uint32_t out_block_h_start_tile_id = out_start_tile_id; - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - uint32_t out_block_w_start_tile_id = out_block_h_start_tile_id; - uint32_t out_block_w_start_tile_id_w = 0; - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - - uint32_t out_sbh_start_tile_id = out_block_w_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - s.noc_async_write_tile(out_tile_id, l1_read_addr); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." 
<< ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - } // out_num_blocks_w - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - - #else - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_tiled_out.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_tiled_out.cpp deleted file mode 100644 index b5fb1f3dfc5..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_tiled_out.cpp +++ /dev/null @@ -1,222 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -#ifdef FUSE_BIAS - #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/reader_bmm_single_core_bias.hpp" -#endif - -FORCE_INLINE void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -template -FORCE_INLINE void write_tiles_in_output_block(uint32_t cb_id_out0, - uint32_t block_height_ntiles, - uint32_t block_width_ntiles, - uint32_t block_start_row_id, - uint32_t block_row_offset, - uint32_t block_row_size, - uint32_t block_row_size_unpadded, // to remove padding from the last block in the row - uint32_t num_rows_unpadded, - const InterleavedPow2AddrGenFast& s) { - constexpr uint32_t 
TILE_HEIGHT = 32; // TODO: use common source of truth - uint32_t block_row_id = block_start_row_id; - for (uint32_t tile_row_id = 0; tile_row_id < block_height_ntiles; tile_row_id++) { - // We reserve back an entire row of tiles in a block and issue a bunch of reads - cb_wait_front(cb_id_out0, block_width_ntiles); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for (uint32_t j = 0; j < TILE_HEIGHT; j++) { - if (block_row_id >= num_rows_unpadded) { - break; - } - s.noc_async_write_page(block_row_id, l1_read_addr, block_row_size_unpadded, block_row_offset); - l1_read_addr += block_row_size; - block_row_id++; - } // for tile_nrows - noc_async_write_barrier(); - cb_pop_front(cb_id_out0, block_width_ntiles); - } // for block_height_ntiles -} - -void kernel_main() { - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias args. Unused if bias fusion is not enabled. - const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias args. Unused if bias fusion is not enabled. 
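// The output and weight address generators set up below use a log2 page size
// (tile_size_pow2_exponent = 11) instead of a byte count: a 32x32 tile of 2-byte datums occupies
// 32 * 32 * 2 = 2048 = 2^11 bytes, so interleaved bank/page offsets reduce to shifts. A small
// compile-time check of that arithmetic; the constant names are illustrative and assume the
// bfloat16 (2 bytes per datum) case these kernels target:
constexpr uint32_t kTileDim = 32;                                  // tile is 32 x 32 datums
constexpr uint32_t kDatumBytes = 2;                                // bfloat16
constexpr uint32_t kTileBytes = kTileDim * kTileDim * kDatumBytes; // 2048 bytes per tile
static_assert(kTileBytes == (1u << 11), "tile_size_pow2_exponent = 11 assumes 2-byte datums");
// (the original kernel's remaining runtime arguments continue below)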
- const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - const DataFormat out_df = get_dataformat(cb_id_out0); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - constexpr uint32_t tile_size_pow2_exponent = 11; - const InterleavedPow2AddrGen s = { - .bank_base_address = out_addr, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - //DPRINT << "Going to read all weights " << ENDL(); - uint32_t weight_start_tile_id = 0; - // read weight blocks inner dim - read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - //DPRINT << "Read all weights " << ENDL(); - - // first read in bias if enabled (done only once for all blocks) - #ifdef FUSE_BIAS - - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_log2_of_pagesize = get_compile_time_arg_val(4); - constexpr uint32_t bias_pagesize = get_compile_time_arg_val(5); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(6) == 1; - - read_bias(bias_addr, bias_ntiles, bias_cb_id, bias_log2_of_pagesize, bias_pagesize); - #endif - - #ifndef SHARDED_OUT - uint32_t out_block_h_start_tile_id = out_start_tile_id; - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - uint32_t out_block_w_start_tile_id = out_block_h_start_tile_id; - uint32_t out_block_w_start_tile_id_w = 0; - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - - uint32_t out_sbh_start_tile_id = out_block_w_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // 
out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - } // out_num_blocks_w - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - - #else - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_untilize_out.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_untilize_out.cpp deleted file mode 100644 index 9004ba251d1..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_reader_weights_resnet50_first_conv_untilize_out.cpp +++ /dev/null @@ -1,151 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -FORCE_INLINE void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -template -FORCE_INLINE void write_tiles_in_output_block(uint32_t cb_id_out0, - uint32_t block_height_ntiles, - uint32_t block_width_ntiles, - uint32_t block_start_row_id, - uint32_t block_row_offset, - uint32_t block_row_size, - uint32_t block_row_size_unpadded, // to remove padding from the last block in the row - uint32_t num_rows_unpadded, - const InterleavedPow2AddrGenFast& s) { - constexpr uint32_t TILE_HEIGHT = 32; // 
TODO: use common source of truth - uint32_t block_row_id = block_start_row_id; - for (uint32_t tile_row_id = 0; tile_row_id < block_height_ntiles; tile_row_id++) { - // We reserve back an entire row of tiles in a block and issue a bunch of reads - cb_wait_front(cb_id_out0, block_width_ntiles); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for (uint32_t j = 0; j < TILE_HEIGHT; j++) { - if (block_row_id >= num_rows_unpadded) { - break; - } - s.noc_async_write_page(block_row_id, l1_read_addr, block_row_size_unpadded, block_row_offset); - l1_read_addr += block_row_size; - block_row_id++; - } // for tile_nrows - noc_async_write_barrier(); - cb_pop_front(cb_id_out0, block_width_ntiles); - } // for block_height_ntiles -} - -void kernel_main() { - uint32_t i = 0; - uint32_t dst_addr = get_arg_val(i); i+=1; // out_dram_addr - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - - uint32_t num_rows_block = get_arg_val(i); i+=1; - uint32_t block_row_size = get_arg_val(i); i+=1; // in0_block_w * TILE_WIDTH * dtype_nbytes - uint32_t batch = get_arg_val(i); i+=1; - uint32_t num_blocks_h = get_arg_val(i); i+=1; - uint32_t num_blocks_w = get_arg_val(i); i+=1; - uint32_t output_row_size = get_arg_val(i); i+=1; // output row size bytes - uint32_t last_block_row_size_unpadded = get_arg_val(i); i+=1; // unpadded last block width - uint32_t num_output_rows_unpadded = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - constexpr uint32_t log_2_of_output_row_size = get_compile_time_arg_val(3); - // NOTE: Row major layout only supports bfp16 - // TT_ASSERT(out_df != DataFormat::Bfp8_b); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t TILE_HEIGHT = 32; // TODO: use common source of truth - - const uint32_t block_width_ntiles = block_row_size >> 6; // Assuming 2 bytes per datum, there are 64 bytes per tile row - const uint32_t block_height_ntiles = num_rows_block / TILE_HEIGHT; - uint32_t block_start_row_id = 0; - - const InterleavedPow2AddrGenFast s = { - .bank_base_address = dst_addr, - .log_base_2_of_page_size = log_2_of_output_row_size - }; - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - constexpr uint32_t tile_size_pow2_exponent = 11; - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - //DPRINT << "Going to read all weights " << ENDL(); - uint32_t weight_start_tile_id = 0; - // read weight blocks inner dim - read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - //DPRINT << "Read all weights " << ENDL(); - - - for(uint32_t block_h = 0; block_h < num_blocks_h; block_h++) { // num_blocks_w == 1 - uint32_t block_row_offset = 0; - uint32_t current_block_row_size_unpadded = 
block_row_size; - write_tiles_in_output_block(cb_id_out0, - block_height_ntiles, - block_width_ntiles, - block_start_row_id, - block_row_offset, - block_row_size, - current_block_row_size_unpadded, // padding is only in the last block - num_output_rows_unpadded, - s); - block_row_offset += block_row_size; - block_start_row_id += num_rows_block; - } // for num_blocks_h -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_matmul_tile_layout.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_matmul_tile_layout.cpp deleted file mode 100644 index 84fac8e1431..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_matmul_tile_layout.cpp +++ /dev/null @@ -1,74 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - - - // out tensor args - uint32_t out_tensor_addr = get_arg_val(0); - uint32_t out_tensor_start_tile_id = get_arg_val(1); - uint32_t out_tensor_stride_w = get_arg_val(2); - uint32_t out_tensor_stride_h = get_arg_val(3); - uint32_t out_tensor_next_subblock_stride_w = get_arg_val(4); - uint32_t out_tensor_next_subblock_stride_h = get_arg_val(5); - - // out subblock args - uint32_t out_subblock_w = get_arg_val(6); - uint32_t out_subblock_h = get_arg_val(7); - uint32_t out_subblock_tile_count = get_arg_val(8); - uint32_t out_num_subblocks_w = get_arg_val(9); - uint32_t out_num_subblocks_h = get_arg_val(10); - - // const args for tile-based bank-swizzled layout - // could be added to the arg list in the future to test different - // bank-swizzling configurations - constexpr uint32_t num_used_dram_ch = 8; - constexpr uint32_t num_used_dram_ch_pow2_exponent = 3; - constexpr uint32_t tile_size_pow2_exponent = 11; - - constexpr uint32_t cb_id_out0 = 16; - - // single-tile - uint32_t single_tile_size_bytes = get_tile_size(cb_id_out0); - - const InterleavedPow2AddrGen s = { - .bank_base_address = out_tensor_addr, - - - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - - bool one_time_profile = true; - uint32_t out_tensor_sbh_start_tile_id = out_tensor_start_tile_id; - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_tensor_sbw_start_tile_id = out_tensor_sbh_start_tile_id; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_tensor_sb_row_start_tile_id = out_tensor_sbw_start_tile_id; - - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tensor_tile_id = out_tensor_sb_row_start_tile_id; - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint64_t out_tensor_tile_noc_addr = get_noc_addr(out_tensor_tile_id, s); - - noc_async_write(l1_read_addr, out_tensor_tile_noc_addr, single_tile_size_bytes); - l1_read_addr+=single_tile_size_bytes; - - out_tensor_tile_id += out_tensor_stride_w; - } - out_tensor_sb_row_start_tile_id += out_tensor_stride_h; - } - - noc_async_write_barrier(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_tensor_sbw_start_tile_id += out_tensor_next_subblock_stride_w; - } - out_tensor_sbh_start_tile_id += out_tensor_next_subblock_stride_h; - } -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp deleted file mode 100644 index 
11fa9cee5c9..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ /dev/null @@ -1,223 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias arg. Unused if bias fusion is not enabled. - const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_num_outer = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias arg. Unused if bias fusion is not enabled. 
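// In the tiled writers below, the output is treated as a flat grid of tiles and the nested loops
// only ever advance a running tile id by host-provided strides (out_next_tile_stride_w/h plus the
// subblock and block strides). For a row-major tile grid the equivalent closed form would look
// like this sketch; the stride values of 1 and out_width_num_tiles are an assumption for
// illustration, since the kernel itself just consumes whatever strides the host computed
// (file-scope helper, illustrative name):
FORCE_INLINE uint32_t output_tile_id(uint32_t start_tile_id, uint32_t tile_row, uint32_t tile_col,
                                     uint32_t out_width_num_tiles) {
    const uint32_t stride_w = 1;                    // assumed: next tile within the same tile row
    const uint32_t stride_h = out_width_num_tiles;  // assumed: first tile of the next tile row
    return start_tile_id + tile_row * stride_h + tile_col * stride_w;
}
// Output blocks arrive from compute in column-major order (width block outer, height block inner),
// which is why the write loop further down iterates over out_num_blocks_w first.
// (the bias arguments of the original receiver kernel resume below)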
- const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - // mcast args - uint32_t weights_mcast_sender_noc_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_noc_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - const InterleavedAddrGenFast s = { - .bank_base_address = out_addr, - .page_size = tile_nbytes, - .data_format = out_df - }; - - // read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - bool load_bias = true; - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT << "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - // const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - // const InterleavedPow2AddrGen s_weight = { - // .bank_base_address = weight_addr_dram_base, - // .log_base_2_of_page_size = tile_size_pow2_exponent - // }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - // OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order. 
- // Write out col major blocks in row major layout to output - uint32_t out_block_w_start_tile_id = out_start_tile_id; - //DPRINT << "out_start_tile_id=" << out_start_tile_id << ENDL(); - uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w; - uint32_t weight_start_tile_id = out_start_tile_id_w; - //DPRINT << "weight_start_tile_id=" << weight_start_tile_id << ENDL(); - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id; - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - // MCAST RECEIVE WEIGHTS - // read weight blocks inner dim - // read weight slice - 1 block of weights in width dim and full weight matrix height - // read slice only once for all activation blocks - for(uint32_t weight_tile_h_outer_i = 0; weight_tile_h_outer_i < weight_block_height_num_outer; weight_tile_h_outer_i++) { - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - // Set weights semaphore value to INVALID - noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr); - noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1); - - // wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID); - - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h - } // for weight_block_height_num_outer - - #ifdef FUSE_BIAS - if (load_bias) { - cb_reserve_back(bias_cb_id, bias_ntiles); - - // Set weights semaphore value to INVALID - noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr); - noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1); - - // wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID); - - cb_push_back(bias_cb_id, bias_ntiles); - load_bias = false; - } - #endif - - #ifndef SHARDED_OUT - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += 
tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - //DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - #endif - } // out_num_blocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - } // out_num_blocks_w - - #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks_num_blocks_weight_h_eq_1.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks_num_blocks_weight_h_eq_1.cpp deleted file mode 100644 index b9892d485aa..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks_num_blocks_weight_h_eq_1.cpp +++ /dev/null @@ -1,222 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias arg. Unused if bias fusion is not enabled. 
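// All of these tiled-output write loops deal with block-shape padding the same way: a tile row
// that falls entirely past the real output height ends the row loop early, while a tile past the
// real output width is skipped by advancing the CB read pointer without issuing a write (the
// padded tile still occupies CB space and must be consumed). Condensed from the surrounding
// kernels, with the same variable names:
//
//   if (out_tile_id_h >= out_height_num_tiles) break;      // height padding: row is done
//   if (out_tile_id_w >= out_width_num_tiles) {
//       l1_read_addr += tile_nbytes;                       // width padding: consume, don't write
//   } else {
//       uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s);
//       noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes);
//       l1_read_addr += tile_nbytes;
//       out_tile_id += out_next_tile_stride_w;
//   }
//
// (the original kernel's bias address argument follows below)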
- const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias arg. Unused if bias fusion is not enabled. - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - // mcast args - uint32_t weights_mcast_sender_noc_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_noc_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_semaphore_addr = get_arg_val(i); i+=1; - uint32_t weights_mcast_receiver_semaphore_addr = get_arg_val(i); i+=1; - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t tile_size_pow2_exponent = 11; // == 2^11 = 2048 = 2 * 32 * 32 (assuming dtype = 2 bytes) - const InterleavedPow2AddrGen s = { - .bank_base_address = out_addr, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - bool load_bias = true; - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << 
ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT << "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - // OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order. - // Write out col major blocks in row major layout to output - uint32_t out_block_w_start_tile_id = out_start_tile_id; - //DPRINT << "out_start_tile_id=" << out_start_tile_id << ENDL(); - uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w; - uint32_t weight_start_tile_id = out_start_tile_id_w; - //DPRINT << "weight_start_tile_id=" << weight_start_tile_id << ENDL(); - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - - // MCAST RECEIVE WEIGHTS - // read weight blocks inner dim - // read weight slice - 1 block of weights in width dim and full weight matrix height - // read slice only once for all activation blocks - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - - // Set weights semaphore value to INVALID - noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr); - noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1); - - // wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID); - - cb_push_back(cb_id_weight, weight_block_num_tiles); - - #ifdef FUSE_BIAS - if (load_bias) { - cb_reserve_back(bias_cb_id, bias_ntiles); - - // Set weights semaphore value to INVALID - noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t weights_mcast_sender_semaphore_noc_addr = get_noc_addr(weights_mcast_sender_noc_x, weights_mcast_sender_noc_y, weights_mcast_sender_semaphore_addr); - noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1); - - // wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID); - - cb_push_back(bias_cb_id, bias_ntiles); - load_bias = false; - } - #endif - - #ifndef SHARDED_OUT - uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id; - //DPRINT << "out_block_h_start_tile_id=" << out_block_h_start_tile_id << ENDL(); - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t 
out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - //DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - #endif - - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - } // out_num_blocks_w - - #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp deleted file mode 100644 index c3b9ac0bea3..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ /dev/null @@ -1,322 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias arg. Unused if bias fusion is not enabled. 
- const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_num_outer = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias arg. Unused if bias fusion is not enabled. - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - // mcast args - uint32_t weights_mcast_dest_noc_start_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_start_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_end_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_end_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_num_dests = get_arg_val(i); i+=1; - uint32_t weights_mcast_num_cores = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - - #ifndef SKIP_MCAST - // Set ur local VALID value, to be mcasted to destinations flag address after the data has been mcasted - volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); - *(weights_mcast_receiver_semaphore_addr_ptr) = VALID; - // local address that will be atomically incremented by mcast receivers, to know when all receivers are ready - // to receive the mcast - volatile tt_l1_ptr uint32_t* weights_mcast_sender_semaphore_addr_ptr = reinterpret_cast(weights_mcast_sender_semaphore_addr); - - uint64_t weights_mcast_receiver_semaphore_noc_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - weights_mcast_receiver_semaphore_addr); - #endif - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = 
get_dataformat(cb_id_out0); - - const InterleavedAddrGenFast s = { - .bank_base_address = out_addr, - .page_size = tile_nbytes, - .data_format = out_df - }; - - // read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(4) == 1; - - const uint32_t bias_pagesize = get_tile_size(bias_cb_id); - const DataFormat bias_df = get_dataformat(bias_cb_id); - const InterleavedAddrGenFast s_bias = { - .bank_base_address = bias_addr, - .page_size = bias_pagesize, - .data_format = bias_df - }; - - bool load_bias = true; - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT << "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const DataFormat weight_df = get_dataformat(cb_id_weight); - const InterleavedAddrGenFast s_weight = { - .bank_base_address = weight_addr_dram_base, - .page_size = weight_tile_nbytes, - .data_format = weight_df - }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - - // OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order. 
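// --- Illustrative sketch (not part of the deleted kernel): the output write-back loop nest ---
// The code that follows interleaves the weights mcast with the output write-back, which makes
// the tile-id bookkeeping hard to follow. Stripped down to just the indexing, the traversal is:
// width blocks outermost (matching the column-major block order produced by compute), then
// height blocks, then subblocks. Names mirror the runtime args above; write_subblock() stands
// in for the cb_wait_front / noc_async_write sequence shown later in the kernel.
inline void write_out_blocks_sketch(
    uint32_t out_start_tile_id,
    uint32_t out_num_blocks_w, uint32_t out_num_blocks_h,
    uint32_t out_num_subblocks_h, uint32_t out_num_subblocks_w,
    uint32_t out_next_block_stride_w, uint32_t out_next_block_stride_h,
    uint32_t out_next_subblock_stride_h, uint32_t out_next_subblock_stride_w) {
    uint32_t block_w_start = out_start_tile_id;
    for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) {          // col-major block order from compute
        uint32_t block_h_start = block_w_start;
        for (uint32_t bh = 0; bh < out_num_blocks_h; bh++) {
            uint32_t sbh_start = block_h_start;
            for (uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) {
                uint32_t sbw_start = sbh_start;
                for (uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) {
                    // write_subblock(sbw_start);                 // hypothetical: per-tile writes live here
                    sbw_start += out_next_subblock_stride_w;
                }
                sbh_start += out_next_subblock_stride_h;
            }
            block_h_start += out_next_block_stride_h;
        }
        block_w_start += out_next_block_stride_w;                 // lands row-major in the output tensor
    }
}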
- // Write out col major blocks in row major layout to output - uint32_t out_block_w_start_tile_id = out_start_tile_id; - //DPRINT << "out_start_tile_id=" << out_start_tile_id << ENDL(); - uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w; - uint32_t weight_start_tile_id = out_start_tile_id_w; - uint32_t weight_inner_block_stride_h = weight_next_block_stride_h / weight_block_height_num_outer; // TODO: Pass as args - //DPRINT << "weight_start_tile_id=" << weight_start_tile_id << ENDL(); - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id; - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - // READ WEIGHTS + MCAST SEND WEIGHTS - // read weight blocks inner dim - // read weight slice - 1 block of weights in width dim and full weight matrix height - // read slice only once for all activation blocks - uint32_t weight_h_offset = 0; - for(uint32_t weight_tile_h_outer_i = 0; weight_tile_h_outer_i < weight_block_height_num_outer; weight_tile_h_outer_i++) { - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id + weight_h_offset; - - // mcast args - uint32_t weights_start_address = weight_write_l1_addr; - uint32_t weights_block_size_bytes = 0; - - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - //DPRINT << "weight_tile_id=" << weight_tile_id << ENDL(); - s_weight.noc_async_read_tile(weight_tile_id, weight_write_l1_addr); - weight_write_l1_addr += weight_tile_nbytes; - weights_block_size_bytes += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - #ifndef SKIP_MCAST - // wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests); - noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint64_t weights_multicast_data_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - weights_start_address); - // num_dests must not include source, since we are NOT really doing a local copy! - noc_async_write_multicast(weights_start_address, weights_multicast_data_addr, weights_block_size_bytes, weights_mcast_num_cores, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). 
- - // We should also multicast the flag to destinations - // num_dests must not include source, since we are NOT really doing a local copy! - noc_semaphore_set_multicast(weights_mcast_receiver_semaphore_addr, weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); - #endif - - weight_current_block_start_tile_id += weight_next_block_stride_h; - - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h - weight_h_offset += weight_inner_block_stride_h; - } // for weight_block_height_num_outer - - - #ifdef FUSE_BIAS - if (load_bias) { - cb_reserve_back(bias_cb_id, bias_ntiles); - uint32_t bias_l1_addr = get_write_ptr(bias_cb_id); - - // mcast args - uint32_t bias_start_address = bias_l1_addr; - uint32_t bias_block_size_bytes = 0; - for (uint32_t bias_tile = bias_tile_offset; bias_tile < bias_tile_offset + bias_ntiles; ++ bias_tile) { - s_bias.noc_async_read_tile(bias_tile, bias_l1_addr); - bias_l1_addr += bias_pagesize; - bias_block_size_bytes += bias_pagesize; - } - noc_async_read_barrier(); - - // MCAST BIAS (shares some mcast args with weights) - #ifndef SKIP_MCAST - // wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests); - noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint64_t bias_multicast_data_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - bias_start_address); - // num_dests must not include source, since we are NOT really doing a local copy! - noc_async_write_multicast(bias_start_address, bias_multicast_data_addr, bias_block_size_bytes, weights_mcast_num_cores, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). - - // We should also multicast the flag to destinations - // num_dests must not include source, since we are NOT really doing a local copy! 
- noc_semaphore_set_multicast(weights_mcast_receiver_semaphore_addr, weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); - #endif - - cb_push_back(bias_cb_id, bias_ntiles); - load_bias = false; - } - #endif - - #ifndef SHARDED_OUT - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - s.noc_async_write_tile(out_tile_id, l1_read_addr); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - #endif - } // out_num_blocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - } // out_num_blocks_w - #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks_num_blocks_weight_h_eq_1.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks_num_blocks_weight_h_eq_1.cpp deleted file mode 100644 index e518b29eccd..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks_num_blocks_weight_h_eq_1.cpp +++ /dev/null @@ -1,315 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias arg. Unused if bias fusion is not enabled. 
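// --- Illustrative sketch (not from the deleted kernels): the weights-mcast handshake ---
// The sender kernel above and the receiver kernel earlier in this diff implement a simple
// two-semaphore protocol. Condensed into hypothetical helpers, assuming the dataflow_api.h
// declarations used throughout these kernels (noc_semaphore_*, VALID/INVALID, tt_l1_ptr):
//
// Receiver side: invalidate the local flag, tell the sender this core is ready for the next
// block, then wait for the sender to flip the flag to VALID once the data has been mcast.
inline void mcast_receive_block(
    volatile tt_l1_ptr uint32_t* receiver_sem_ptr, uint64_t sender_sem_noc_addr) {
    noc_semaphore_set(receiver_sem_ptr, INVALID);
    noc_semaphore_inc(sender_sem_noc_addr, 1);       // "this receiver is ready"
    noc_semaphore_wait(receiver_sem_ptr, VALID);     // data for this block has landed
}
//
// Sender side: wait until every receiver has checked in, reset the counter, multicast the
// data, then multicast the VALID flag (the sender staged VALID at receiver_sem_l1_addr up
// front, and no write barrier is needed because both transfers share the same static VC,
// as the comments above note).
inline void mcast_send_block(
    volatile tt_l1_ptr uint32_t* sender_sem_ptr,
    uint32_t num_dests, uint32_t num_cores,
    uint32_t l1_src_addr, uint64_t data_mcast_noc_addr, uint32_t size_bytes,
    uint32_t receiver_sem_l1_addr, uint64_t receiver_sem_mcast_noc_addr) {
    noc_semaphore_wait(sender_sem_ptr, num_dests);
    noc_semaphore_set(sender_sem_ptr, 0);
    noc_async_write_multicast(l1_src_addr, data_mcast_noc_addr, size_bytes, num_cores, true, true);
    noc_semaphore_set_multicast(receiver_sem_l1_addr, receiver_sem_mcast_noc_addr, num_cores);
}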
- const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias arg. Unused if bias fusion is not enabled. - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - // mcast args - uint32_t weights_mcast_dest_noc_start_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_start_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_end_x = get_arg_val(i); i+=1; - uint32_t weights_mcast_dest_noc_end_y = get_arg_val(i); i+=1; - uint32_t weights_mcast_num_dests = get_arg_val(i); i+=1; - uint32_t weights_mcast_num_cores = get_arg_val(i); i+=1; - uint32_t weights_mcast_sender_semaphore_addr = get_arg_val(i); i+=1; - uint32_t weights_mcast_receiver_semaphore_addr = get_arg_val(i); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - - #ifndef SKIP_MCAST - // Set ur local VALID value, to be mcasted to destinations flag address after the data has been mcasted - volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); - *(weights_mcast_receiver_semaphore_addr_ptr) = VALID; - // local address that will be atomically incremented by mcast receivers, to know when all receivers are ready - // to receive the mcast - volatile tt_l1_ptr uint32_t* weights_mcast_sender_semaphore_addr_ptr = reinterpret_cast(weights_mcast_sender_semaphore_addr); - - uint64_t weights_mcast_receiver_semaphore_noc_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - weights_mcast_receiver_semaphore_addr); - #endif - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t tile_size_pow2_exponent = 11; // == 2^11 = 2048 = 2 * 
32 * 32 (assuming dtype = 2 bytes) - const InterleavedPow2AddrGen s = { - .bank_base_address = out_addr, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_log2_of_pagesize = get_compile_time_arg_val(4); - constexpr uint32_t bias_pagesize = get_compile_time_arg_val(5); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(6) == 1; - - const InterleavedPow2AddrGenFast s_bias = { - .bank_base_address = bias_addr, - .log_base_2_of_page_size = bias_log2_of_pagesize - }; - - bool load_bias = true; - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT << "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - // OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order. 
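// --- Illustrative note (not from the deleted kernel): why tile_size_pow2_exponent == 11 ---
// InterleavedPow2AddrGen maps a page (tile) id to a bank-interleaved NOC address and requires
// the page size to be a power of two, presumably so the offset can be computed with a shift.
// For a 32x32 tile of a 2-byte format: 32 * 32 * 2 = 2048 bytes = 2^11, hence the hard-coded
// exponent above. A different element size would need a different exponent, which is why other
// kernels in this diff use InterleavedAddrGenFast keyed on the actual tile size and data format
// instead. Minimal usage sketch, mirroring the write path below; <true> assumes a DRAM buffer
// (the original kernels key this off the out_in_dram compile-time arg):
inline void write_one_tile_pow2(
    uint32_t out_addr, uint32_t l1_read_addr, uint32_t tile_id, uint32_t tile_nbytes) {
    constexpr uint32_t tile_size_pow2_exponent = 11;             // 2 KB pages, 2-byte dtype assumed
    const InterleavedPow2AddrGen<true> s = {
        .bank_base_address = out_addr,
        .log_base_2_of_page_size = tile_size_pow2_exponent
    };
    uint64_t out_tile_noc_addr = get_noc_addr(tile_id, s);       // bank + offset from the page id
    noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes);
    noc_async_write_barrier();
}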
- // Write out col major blocks in row major layout to output - uint32_t out_block_w_start_tile_id = out_start_tile_id; - //DPRINT << "out_start_tile_id=" << out_start_tile_id << ENDL(); - uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w; - uint32_t weight_start_tile_id = out_start_tile_id_w; - //DPRINT << "weight_start_tile_id=" << weight_start_tile_id << ENDL(); - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - - // READ WEIGHTS + MCAST SEND WEIGHTS - // read weight blocks inner dim - // read weight slice - 1 block of weights in width dim and full weight matrix height - // read slice only once for all activation blocks - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - - // mcast args - uint32_t weights_start_address = weight_write_l1_addr; - uint32_t weights_block_size_bytes = 0; - - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - //DPRINT << "weight_tile_id=" << weight_tile_id << ENDL(); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weights_block_size_bytes += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - #ifndef SKIP_MCAST - // wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests); - noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint64_t weights_multicast_data_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - weights_start_address); - // num_dests must not include source, since we are NOT really doing a local copy! - noc_async_write_multicast(weights_start_address, weights_multicast_data_addr, weights_block_size_bytes, weights_mcast_num_cores, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). - - // We should also multicast the flag to destinations - // num_dests must not include source, since we are NOT really doing a local copy! 
- noc_semaphore_set_multicast(weights_mcast_receiver_semaphore_addr, weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); - #endif - - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - - - #ifdef FUSE_BIAS - if (load_bias) { - cb_reserve_back(bias_cb_id, bias_ntiles); - uint32_t bias_l1_addr = get_write_ptr(bias_cb_id); - - // mcast args - uint32_t bias_start_address = bias_l1_addr; - uint32_t bias_block_size_bytes = 0; - for (uint32_t bias_tile = bias_tile_offset; bias_tile < bias_tile_offset + bias_ntiles; ++ bias_tile) { - s_bias.noc_async_read_page(bias_tile, bias_l1_addr); - bias_l1_addr += bias_pagesize; - bias_block_size_bytes += bias_pagesize; - } - noc_async_read_barrier(); - - // MCAST BIAS (shares some mcast args with weights) - #ifndef SKIP_MCAST - // wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests); - noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint64_t bias_multicast_data_addr = get_noc_multicast_addr( - weights_mcast_dest_noc_start_x, - weights_mcast_dest_noc_start_y, - weights_mcast_dest_noc_end_x, - weights_mcast_dest_noc_end_y, - bias_start_address); - // num_dests must not include source, since we are NOT really doing a local copy! - noc_async_write_multicast(bias_start_address, bias_multicast_data_addr, bias_block_size_bytes, weights_mcast_num_cores, true, true); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). - - // We should also multicast the flag to destinations - // num_dests must not include source, since we are NOT really doing a local copy! 
- noc_semaphore_set_multicast(weights_mcast_receiver_semaphore_addr, weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); - #endif - - cb_push_back(bias_cb_id, bias_ntiles); - load_bias = false; - } - #endif - - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - - #ifndef SHARDED_OUT - uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id; - //DPRINT << "out_block_h_start_tile_id=" << out_block_h_start_tile_id << ENDL(); - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - //DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - #endif - } // out_num_blocks_w - #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled.cpp deleted file mode 100644 index 2fa5490ac49..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled.cpp +++ /dev/null @@ -1,205 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
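// --- Illustrative sketch (not from the deleted kernels): padding handling in the write loop ---
// Block dimensions are rounded up to whole tiles, so compute may hand back a block that is
// taller or wider than the real output. The write loops in the kernels above handle this in
// two ways: rows past out_height_num_tiles break out of the row loop (the whole subblock is
// still popped from the CB afterwards), while columns past out_width_num_tiles skip the
// noc_async_write but still advance l1_read_addr, keeping the L1 read pointer in sync with
// what compute actually produced. Condensed, with hypothetical names:
inline void write_subblock_row_with_padding(
    uint32_t& l1_read_addr, uint32_t& out_tile_id, uint32_t out_tile_id_w_start,
    uint32_t out_subblock_w, uint32_t out_width_num_tiles,
    uint32_t out_next_tile_stride_w, uint32_t tile_nbytes,
    const InterleavedPow2AddrGen<true>& s) {
    for (uint32_t w = 0; w < out_subblock_w; w++) {
        uint32_t out_tile_id_w = out_tile_id_w_start + w;
        if (out_tile_id_w >= out_width_num_tiles) {
            l1_read_addr += tile_nbytes;                         // padded tile: consume, don't write
        } else {
            uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s);
            noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes);
            l1_read_addr += tile_nbytes;
            out_tile_id += out_next_tile_stride_w;
        }
    }
}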
-// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - -#ifdef FUSE_BIAS - #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/reader_bmm_single_core_bias.hpp" -#endif - -inline void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = 
get_compile_time_arg_val(2); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t tile_size_pow2_exponent = 11; // == 2^11 = 2048 = 2 * 32 * 32 (assuming dtype = 2 bytes) - const InterleavedPow2AddrGen s = { - .bank_base_address = out_addr, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // first read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - const uint32_t bias_addr = get_arg_val(i); i += 1; - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_log2_of_pagesize = get_compile_time_arg_val(4); - constexpr uint32_t bias_pagesize = get_compile_time_arg_val(5); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(6) == 1; - - read_bias(bias_addr, bias_ntiles, bias_cb_id, bias_log2_of_pagesize, bias_pagesize); - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT << "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - uint32_t out_block_h_start_tile_id = 0; - uint32_t out_block_h_start_tile_id_h = 0; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - // Reset weight start tile index - uint32_t weight_start_tile_id = 0; - uint32_t out_block_w_start_tile_id = out_block_h_start_tile_id; - uint32_t out_block_w_start_tile_id_w = 0; - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - // read weight blocks inner dim - read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - - uint32_t out_sbh_start_tile_id = out_block_w_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t 
out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - } // out_num_blocks_w - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled_col_to_rm_blocks.cpp deleted file mode 100644 index cdd10d90e6a..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled_col_to_rm_blocks.cpp +++ /dev/null @@ -1,223 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - -#ifdef FUSE_BIAS - #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/reader_bmm_single_core_bias.hpp" -#endif - -inline void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - //DPRINT << "weight_tile_id=" << weight_tile_id << ENDL(); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias arg. Unused if bias fusion is not enabled. 
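// --- Illustrative sketch (not from the deleted kernel): the circular-buffer producer/consumer contract ---
// read_weight_blocks_inner_h_dim above is a standard CB producer: reserve space for a block of
// tiles, stream the tiles in over the NOC, wait for the reads to land, then publish the block.
// The compute kernel (not part of this diff) is assumed to be the matching consumer. A minimal
// producer/consumer pair with hypothetical names and a single-tile "block" for brevity:
inline void produce_one_tile(uint32_t cb_id, uint32_t tile_id,
                             const InterleavedPow2AddrGen<true>& s, uint32_t tile_nbytes) {
    cb_reserve_back(cb_id, 1);                                   // block until a slot is free
    uint32_t l1_write_addr = get_write_ptr(cb_id);
    noc_async_read(get_noc_addr(tile_id, s), l1_write_addr, tile_nbytes);
    noc_async_read_barrier();                                    // data must be in L1 before publishing
    cb_push_back(cb_id, 1);                                      // make the tile visible to the consumer
}
inline void consume_one_tile(uint32_t cb_id) {
    cb_wait_front(cb_id, 1);                                     // block until a tile is available
    // ... use get_read_ptr(cb_id) ...
    cb_pop_front(cb_id, 1);                                      // free the slot for the producer
}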
- const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias arg. Unused if bias fusion is not enabled. - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t tile_size_pow2_exponent = 11; // == 2^11 = 2048 = 2 * 32 * 32 (assuming dtype = 2 bytes) - const InterleavedPow2AddrGen s = { - .bank_base_address = out_addr, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // first read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_log2_of_pagesize = get_compile_time_arg_val(4); - constexpr uint32_t bias_pagesize = get_compile_time_arg_val(5); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(6) == 1; - - read_bias_with_offset(bias_addr, bias_tile_offset, bias_ntiles, bias_cb_id, bias_log2_of_pagesize, bias_pagesize); - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT 
<< "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - // OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order. - // Write out col major blocks in row major layout to output - uint32_t out_block_w_start_tile_id = out_start_tile_id; - //DPRINT << "out_start_tile_id=" << out_start_tile_id << ENDL(); - uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w; - uint32_t weight_start_tile_id = out_start_tile_id_w; - //DPRINT << "weight_start_tile_id=" << weight_start_tile_id << ENDL(); - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id; - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - // read weight blocks inner dim - read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - //DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." 
<< ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - } // out_num_blocks_w -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled_col_to_rm_blocks_read_weight_slices_once.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled_col_to_rm_blocks_read_weight_slices_once.cpp deleted file mode 100644 index 385073c31cf..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_reader_conv_weights_tiled_col_to_rm_blocks_read_weight_slices_once.cpp +++ /dev/null @@ -1,234 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -// #include "debug/dprint.h" - -#ifdef FUSE_BIAS - #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/reader_bmm_single_core_bias.hpp" -#endif - -inline void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - //DPRINT << "weight_tile_id=" << weight_tile_id << ENDL(); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -void kernel_main() { - // This writer is for output tensor in tile format - uint32_t i = 0; - uint32_t out_addr = get_arg_val(i); i+=1; - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - // Bias arg. Unused if bias fusion is not enabled. 
- const uint32_t bias_addr = get_arg_val(i); i += 1; - - uint32_t out_next_tile_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_tile_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_subblock_stride_w = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t out_next_block_stride_w = get_arg_val(i); i+=1; - uint32_t out_subblock_h = get_arg_val(i); i+=1; - uint32_t out_subblock_w = get_arg_val(i); i+=1; - uint32_t out_subblock_tile_count = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_h = get_arg_val(i); i+=1; - uint32_t out_num_subblocks_w = get_arg_val(i); i+=1; - uint32_t out_num_blocks_h = get_arg_val(i); i+=1; - uint32_t out_num_blocks_w = get_arg_val(i); i+=1; - uint32_t out_block_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_height_num_tiles = get_arg_val(i); i+=1; - uint32_t out_width_num_tiles = get_arg_val(i); i+=1; - uint32_t out_start_tile_id = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_h = get_arg_val(i); i+=1; - uint32_t out_start_tile_id_w = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - // Bias arg. Unused if bias fusion is not enabled. - const uint32_t bias_ntiles = get_arg_val(i); i += 1; - const uint32_t bias_tile_offset = get_arg_val(i); i += 1; - - uint32_t noop = get_arg_val(i); i+=1; - if(noop) { - return; - } - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - - const uint32_t tile_nbytes = get_tile_size(cb_id_out0); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t tile_size_pow2_exponent = 11; // == 2^11 = 2048 = 2 * 32 * 32 (assuming dtype = 2 bytes) - const InterleavedPow2AddrGen s = { - .bank_base_address = out_addr, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // first read in bias if enabled (done only once for all batches) - #ifdef FUSE_BIAS - - constexpr uint32_t bias_cb_id = get_compile_time_arg_val(3); - constexpr uint32_t bias_log2_of_pagesize = get_compile_time_arg_val(4); - constexpr uint32_t bias_pagesize = get_compile_time_arg_val(5); - constexpr uint32_t bias_in_dram = get_compile_time_arg_val(6) == 1; - - read_bias_with_offset(bias_addr, bias_tile_offset, bias_ntiles, bias_cb_id, bias_log2_of_pagesize, bias_pagesize); - #endif - - // DPRINT << "tile_nbytes - " << tile_nbytes << ENDL(); - // DPRINT << "out_num_blocks_h - " << out_num_blocks_h << ENDL(); - // DPRINT << "out_num_blocks_w - " << out_num_blocks_w << ENDL(); - - // DPRINT << "out_num_subblocks_h - " << out_num_subblocks_h << ENDL(); - // DPRINT << "out_num_subblocks_w - " << out_num_subblocks_w << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - - // DPRINT << "out_subblock_tile_count - " << out_subblock_tile_count << ENDL(); - - // DPRINT << "num_blocks_weight_h - " << num_blocks_weight_h << ENDL(); - // DPRINT << "weight_block_height_ntiles - " << weight_block_height_ntiles << ENDL(); - // DPRINT 
<< "weight_block_width_ntiles - " << weight_block_width_ntiles << ENDL(); - - // DPRINT << "out_subblock_h - " << out_subblock_h << ENDL(); - // DPRINT << "out_subblock_w - " << out_subblock_w << ENDL(); - // DPRINT << "out_block_height_num_tiles - " << out_block_height_num_tiles << ENDL(); - // DPRINT << "out_height_num_tiles - " << out_height_num_tiles << ENDL(); - // DPRINT << "out_width_num_tiles - " << out_width_num_tiles << ENDL(); - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = out_addr, - // .page_size = tile_nbytes, - // .data_format = out_df - // }; - - // OUTER most loop is looping over out blocks in width dim because blocks from compute are in col major order. - // Write out col major blocks in row major layout to output - uint32_t out_block_w_start_tile_id = out_start_tile_id; - //DPRINT << "out_start_tile_id=" << out_start_tile_id << ENDL(); - uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w; - uint32_t weight_start_tile_id = out_start_tile_id_w; - //DPRINT << "weight_start_tile_id=" << weight_start_tile_id << ENDL(); - for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) { - - // read weight blocks inner dim - // read weight slice - 1 block of weights in width dim and full weight matrix height - // read slice only once for all activation blocks - read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - - #ifndef SHARDED_OUT - uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id; - //DPRINT << "out_block_h_start_tile_id=" << out_block_h_start_tile_id << ENDL(); - uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h; - for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) { - - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for(uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for(uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for(uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for(uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - //DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - //DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += 
out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - //DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; - } // out_num_blocks_h - out_block_w_start_tile_id += out_next_block_stride_w; - out_block_w_start_tile_id_w += weight_block_width_ntiles; - #endif - } // out_num_blocks_w - - #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #endif -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_8bank_blocks_reader_weight_tile_with_pow2_addr_gen_fast.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_8bank_blocks_reader_weight_tile_with_pow2_addr_gen_fast.cpp deleted file mode 100644 index 82945621314..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_8bank_blocks_reader_weight_tile_with_pow2_addr_gen_fast.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -FORCE_INLINE void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -template -FORCE_INLINE void write_tiles_in_output_block(uint32_t cb_id_out0, - uint32_t block_height_ntiles, - uint32_t block_width_ntiles, - uint32_t block_start_row_id, - uint32_t block_row_offset, - uint32_t block_row_size, - uint32_t 
block_row_size_unpadded, // to remove padding from the last block in the row - uint32_t num_rows_unpadded, - const InterleavedPow2AddrGenFast& s) { - constexpr uint32_t TILE_HEIGHT = 32; // TODO: use common source of truth - uint32_t block_row_id = block_start_row_id; - for (uint32_t tile_row_id = 0; tile_row_id < block_height_ntiles; tile_row_id++) { - // We reserve back an entire row of tiles in a block and issue a bunch of reads - cb_wait_front(cb_id_out0, block_width_ntiles); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for (uint32_t j = 0; j < TILE_HEIGHT; j++) { - if (block_row_id >= num_rows_unpadded) { - break; - } - s.noc_async_write_page(block_row_id, l1_read_addr, block_row_size_unpadded, block_row_offset); - l1_read_addr += block_row_size; - block_row_id++; - } // for tile_nrows - noc_async_write_barrier(); - cb_pop_front(cb_id_out0, block_width_ntiles); - } // for block_height_ntiles -} - -void kernel_main() { - uint32_t i = 0; - uint32_t dst_addr = get_arg_val(i); i+=1; // out_dram_addr - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - - uint32_t num_rows_block = get_arg_val(i); i+=1; - uint32_t block_row_size = get_arg_val(i); i+=1; // in0_block_w * TILE_WIDTH * dtype_nbytes - uint32_t batch = get_arg_val(i); i+=1; - uint32_t num_blocks_h = get_arg_val(i); i+=1; - uint32_t num_blocks_w = get_arg_val(i); i+=1; - uint32_t output_row_size = get_arg_val(i); i+=1; // output row size bytes - uint32_t last_block_row_size_unpadded = get_arg_val(i); i+=1; // unpadded last block width - uint32_t num_output_rows_unpadded = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - constexpr uint32_t log_2_of_output_row_size = get_compile_time_arg_val(3); - - //DPRINT << "cb id weight " << cb_id_weight << ENDL(); - // NOTE: Row major layout only supports bfp16 - // TT_ASSERT(out_df != DataFormat::Bfp8_b); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t TILE_HEIGHT = 32; // TODO: use common source of truth - - const uint32_t block_width_ntiles = block_row_size >> 6; // Assuming 2 bytes per datum, there are 64 bytes per tile row - const uint32_t block_height_ntiles = num_rows_block / TILE_HEIGHT; - uint32_t block_start_row_id = 0; - - const InterleavedPow2AddrGenFast s = { - .bank_base_address = dst_addr, - .log_base_2_of_page_size = log_2_of_output_row_size - }; - - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - constexpr uint32_t tile_size_pow2_exponent = 11; - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - for(uint32_t b = 0; b < batch; ++b) { - for(uint32_t block_h = 0; block_h < num_blocks_h; block_h++) { - uint32_t block_row_offset = 0; - // Reset weight start tile index - uint32_t weight_start_tile_id = 0; - for(uint32_t block_w = 0; block_w < num_blocks_w; block_w++) { - - // read weight blocks inner dim - 
read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - - uint32_t current_block_row_size_unpadded = block_row_size; - if(block_w == (num_blocks_w - 1)) { - current_block_row_size_unpadded = last_block_row_size_unpadded; - } - write_tiles_in_output_block(cb_id_out0, - block_height_ntiles, - block_width_ntiles, - block_start_row_id, - block_row_offset, - block_row_size, - current_block_row_size_unpadded, // padding is only in the last block - num_output_rows_unpadded, - s); - block_row_offset += block_row_size; - } // for num_blocks_w - block_start_row_id += num_rows_block; - } // for num_blocks_h - } // for batch -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_layout_8bank_blocks_reader_weight_tile_layout.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_layout_8bank_blocks_reader_weight_tile_layout.cpp deleted file mode 100644 index dc6f878cb7c..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_unary_stick_layout_8bank_blocks_reader_weight_tile_layout.cpp +++ /dev/null @@ -1,167 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -// #include "debug/dprint.h" - -// TODO: FORCE INLINE -inline void read_weight_blocks_inner_h_dim(uint32_t cb_id_weight, - uint32_t num_blocks_weight_h, - uint32_t weight_block_num_tiles, - uint32_t weight_start_tile_id, - uint32_t weight_block_height_ntiles, - uint32_t weight_block_width_ntiles, - const InterleavedPow2AddrGen& s_weight, - uint32_t weight_tile_nbytes, - uint32_t weight_stride_h, - uint32_t weight_next_block_stride_h) { - // weight DRAM -> L1 (weights in tiled form) - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; - for(uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { - cb_reserve_back(cb_id_weight, weight_block_num_tiles); - uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id; - // loop over weight block tiles along h - for(uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; ++weight_tile_h_i) { - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for(uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { - uint64_t weight_tile_noc_addr = get_noc_addr(weight_tile_id, s_weight); - noc_async_read(weight_tile_noc_addr, weight_write_l1_addr, weight_tile_nbytes); - weight_write_l1_addr += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - noc_async_read_barrier(); - - weight_current_block_start_tile_id += weight_next_block_stride_h; - cb_push_back(cb_id_weight, weight_block_num_tiles); - } // for num_blocks_weight_h -} - -template -inline void write_tiles_in_output_block(uint32_t cb_id_out0, - uint32_t block_height_ntiles, - uint32_t block_width_ntiles, - uint32_t block_start_row_id, - uint32_t block_row_offset, - uint32_t block_row_size, - uint32_t block_row_size_unpadded, // to remove padding from the last block 
in the row - uint32_t num_rows_unpadded, - const InterleavedAddrGen& s) { - constexpr uint32_t TILE_HEIGHT = 32; // TODO: use common source of truth - uint32_t block_row_id = block_start_row_id; - for (uint32_t tile_row_id = 0; tile_row_id < block_height_ntiles; tile_row_id++) { - // We reserve back an entire row of tiles in a block and issue a bunch of reads - cb_wait_front(cb_id_out0, block_width_ntiles); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for (uint32_t j = 0; j < TILE_HEIGHT; j++) { - if (block_row_id >= num_rows_unpadded) { - break; - } - uint64_t dst_noc_addr = get_noc_addr(block_row_id, s, block_row_offset); - noc_async_write(l1_read_addr, dst_noc_addr, block_row_size_unpadded); - l1_read_addr += block_row_size; - block_row_id++; - } // for tile_nrows - noc_async_write_barrier(); - cb_pop_front(cb_id_out0, block_width_ntiles); - } // for block_height_ntiles -} - -void kernel_main() { - uint32_t i = 0; - uint32_t dst_addr = get_arg_val(i); i+=1; // out_dram_addr - uint32_t weight_addr_dram_base = get_arg_val(i); i+=1; - - uint32_t num_rows_block = get_arg_val(i); i+=1; - uint32_t block_row_size = get_arg_val(i); i+=1; // in0_block_w * TILE_WIDTH * dtype_nbytes - uint32_t batch = get_arg_val(i); i+=1; - uint32_t num_blocks_h = get_arg_val(i); i+=1; - uint32_t num_blocks_w = get_arg_val(i); i+=1; - uint32_t output_row_size = get_arg_val(i); i+=1; // output row size bytes - uint32_t last_block_row_size_unpadded = get_arg_val(i); i+=1; // unpadded last block width - uint32_t num_output_rows_unpadded = get_arg_val(i); i+=1; - - uint32_t num_blocks_weight_h = get_arg_val(i); i+=1; - uint32_t weight_block_num_tiles = get_arg_val(i); i+=1; - uint32_t weight_block_height_ntiles = get_arg_val(i); i+=1; - uint32_t weight_block_width_ntiles = get_arg_val(i); i+=1; - uint32_t weight_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_h = get_arg_val(i); i+=1; - uint32_t weight_next_block_stride_w = get_arg_val(i); i+=1; - - - constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); - constexpr uint32_t cb_id_weight = get_compile_time_arg_val(2); - //DPRINT << "cb id weight " << cb_id_weight << ENDL(); - // NOTE: Row major layout only supports bfp16 - // TT_ASSERT(out_df != DataFormat::Bfp8_b); - const DataFormat out_df = get_dataformat(cb_id_out0); - - constexpr uint32_t TILE_HEIGHT = 32; // TODO: use common source of truth - - const uint32_t block_width_ntiles = block_row_size >> 6; // Assuming 2 bytes per datum, there are 64 bytes per tile row - const uint32_t block_height_ntiles = num_rows_block / TILE_HEIGHT; - uint32_t block_start_row_id = 0; - - // const InterleavedAddrGenFast s = { - // .bank_base_address = dst_addr, - // .page_size = output_row_size, - // .data_format = out_df - // }; - const InterleavedAddrGen s = { - .bank_base_address = dst_addr, - .page_size = output_row_size - }; - const uint32_t weight_tile_nbytes = get_tile_size(cb_id_weight); - constexpr uint32_t tile_size_pow2_exponent = 11; - const InterleavedPow2AddrGen s_weight = { - .bank_base_address = weight_addr_dram_base, - .log_base_2_of_page_size = tile_size_pow2_exponent - }; - - for(uint32_t b = 0; b < batch; ++b) { - for(uint32_t block_h = 0; block_h < num_blocks_h; block_h++) { - uint32_t block_row_offset = 0; - // Reset weight start tile index - uint32_t weight_start_tile_id = 0; - for(uint32_t block_w = 0; block_w < num_blocks_w; block_w++) { - - // read weight blocks inner dim - 
read_weight_blocks_inner_h_dim(cb_id_weight, - num_blocks_weight_h, - weight_block_num_tiles, - weight_start_tile_id, - weight_block_height_ntiles, - weight_block_width_ntiles, - s_weight, - weight_tile_nbytes, - weight_stride_h, - weight_next_block_stride_h); - // Increment weight start tile id for next block in width dim - weight_start_tile_id += weight_next_block_stride_w; - - uint32_t current_block_row_size_unpadded = block_row_size; - if(block_w == (num_blocks_w - 1)) { - current_block_row_size_unpadded = last_block_row_size_unpadded; - } - write_tiles_in_output_block(cb_id_out0, - block_height_ntiles, - block_width_ntiles, - block_start_row_id, - block_row_offset, - block_row_size, - current_block_row_size_unpadded, // padding is only in the last block - num_output_rows_unpadded, - s); - block_row_offset += block_row_size; - } // for num_blocks_w - block_start_row_id += num_rows_block; - } // for num_blocks_h - } // for batch -} diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv/optimized_conv_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv/optimized_conv_op.cpp deleted file mode 100644 index 77c67e0595f..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv/optimized_conv_op.cpp +++ /dev/null @@ -1,1261 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp" -#include "tt_metal/host_api.hpp" -#include "tt_metal/detail/tt_metal.hpp" -#include "tt_metal/detail/util.hpp" -#include "tt_metal/common/constants.hpp" - -#include "tt_stl/reflection.hpp" - -#include "ttnn/deprecated/tt_dnn/op_library/sharding_utilities.hpp" -#include "ttnn/operations/experimental/auto_format/auto_format.hpp" - -#include "ttnn/tensor/tensor_utils.hpp" - -using namespace tt::constants; -namespace ttnn::operations::conv { -namespace conv2d { - -using namespace tt; - -const uint32_t act_cb = CB::c_in0; -const uint32_t weight_cb = CB::c_in1; -const uint32_t bias_cb = CB::c_in2; -const uint32_t sharded_act_cb = CB::c_in3; -const uint32_t cb_for_reader_indices = CB::c_in4; -const uint32_t cb_for_reader_offsets = CB::c_in5; -const uint32_t sharded_act_mcast_receiver_cb = CB::c_in6; -const uint32_t matmul_partials_cb = CB::c_intermed0; -const uint32_t tilize_mode_tilized_act_cb = CB::c_intermed1; -const uint32_t untilize_mode_reblock_cb = CB::c_intermed2; -const uint32_t out0_cb = CB::c_out0; - - -std::tuple create_CBs(tt_metal::Program &program, - const Tensor& input, - CoreRange core, - uint32_t num_cb0_tiles, - uint32_t num_cb1_tiles, - uint32_t num_cb0_tilized_tiles, - uint32_t num_output_tiles, - uint32_t num_reblock_cb_tiles, - uint32_t num_writer_output_tiles, - bool untilize_out, - tt::DataFormat act_df, - tt::DataFormat weight_df, - tt::DataFormat tilized_act_df, - tt::DataFormat out_df, - tt::DataFormat bias_df, - bool weight_width_sliced, - const Tensor& output, - uint32_t bias_ntiles = 0, - bool with_bias = false -) { - - uint32_t act_tile_size = tt_metal::detail::TileSize(act_df); - uint32_t weight_tile_size = tt_metal::detail::TileSize(weight_df); - uint32_t tilized_act_tile_size = tt_metal::detail::TileSize(tilized_act_df); - uint32_t out_tile_size = tt_metal::detail::TileSize(out_df); - - // Invariants - CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * act_tile_size, {{act_cb, act_df}}) - .set_page_size(act_cb, act_tile_size); - auto cb_act = 
tt_metal::CreateCircularBuffer(program, core, cb_act_config); - - CBHandle cb_sharded_act = 0; - CBHandle cb_sharded_act_mcast_receiver = 0; - if (input.is_sharded()) { - uint32_t num_bytes_for_df = datum_size(act_df); - auto shard_shape = input.shard_spec().value().shape; - CircularBufferConfig cb_sharded_act_config = CircularBufferConfig(shard_shape[0] * shard_shape[1] * num_bytes_for_df, {{sharded_act_cb, act_df}}) - .set_page_size(sharded_act_cb, shard_shape[1] * num_bytes_for_df); - // incoming data is the input cb instead of raw l1/dram addr - cb_sharded_act_config.set_globally_allocated_address(*input.buffer()); - cb_sharded_act = tt_metal::CreateCircularBuffer(program, core, cb_sharded_act_config); - - // For 2D convs, we need a separate cb to receive mcasted input shards - if (weight_width_sliced) { - CircularBufferConfig cb_sharded_act_mcast_receiver_config = CircularBufferConfig(shard_shape[0] * shard_shape[1] * num_bytes_for_df, {{sharded_act_mcast_receiver_cb, tt::DataFormat::Float16_b}}) - .set_page_size(sharded_act_mcast_receiver_cb, shard_shape[1] * num_bytes_for_df); - cb_sharded_act_mcast_receiver = tt_metal::CreateCircularBuffer(program, core, cb_sharded_act_mcast_receiver_config); - } - } - - CircularBufferConfig cb_weight_config = CircularBufferConfig(num_cb1_tiles * weight_tile_size, {{weight_cb, weight_df}}) - .set_page_size(weight_cb, weight_tile_size); - auto cb_weight = tt_metal::CreateCircularBuffer(program, core, cb_weight_config); - - // Used for placing tilized activations - CircularBufferConfig cb_src0_tilized_config = CircularBufferConfig(num_cb0_tilized_tiles * tilized_act_tile_size, {{tilize_mode_tilized_act_cb, tilized_act_df}}) - .set_page_size(tilize_mode_tilized_act_cb, tilized_act_tile_size); - auto cb_src0_tilized = tt_metal::CreateCircularBuffer(program, core, cb_src0_tilized_config); - - CBHandle cb_output = 0; - if (untilize_out) { - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * out_tile_size, {{matmul_partials_cb, out_df}}) - .set_page_size(matmul_partials_cb, out_tile_size); - auto cb_matmul_partials = tt_metal::CreateCircularBuffer(program, core, cb_matmul_partials_config); - - // Supposed to be a small CB only responsible for reorganizing - // the output blocks to fill the whole "per core output block width" - CircularBufferConfig cb_reblock_config = CircularBufferConfig(num_reblock_cb_tiles * out_tile_size, {{untilize_mode_reblock_cb, out_df}}) - .set_page_size(untilize_mode_reblock_cb, out_tile_size); - auto cb_reblock = tt_metal::CreateCircularBuffer(program, core, cb_reblock_config); - - CircularBufferConfig cb_output_config = CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{out0_cb, out_df}}) - .set_page_size(out0_cb, out_tile_size); - if (output.is_sharded()) { - cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); - } - cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); - } else { - CoreRangeSet cores(std::set({core})); - std::map cb_output_data_format_spec = { - {out0_cb, out_df}, - {matmul_partials_cb, out_df} - }; - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * out_tile_size, cb_output_data_format_spec) - .set_page_size(out0_cb, out_tile_size) - .set_page_size(matmul_partials_cb, out_tile_size); - if (output.is_sharded()) { - cb_matmul_partials_config = cb_matmul_partials_config.set_globally_allocated_address(*output.buffer()); - } - cb_output = 
tt_metal::CreateCircularBuffer(program, cores, cb_matmul_partials_config); - } - - if (with_bias) { - uint32_t bias_tile_size = tt_metal::detail::TileSize(bias_df); - // bias input - uint32_t bias_pagesize = bias_tile_size; - CircularBufferConfig cb_bias_config = CircularBufferConfig(bias_ntiles * bias_pagesize, {{bias_cb, bias_df}}) - .set_page_size(bias_cb, bias_pagesize); - auto cb_bias = tt_metal::CreateCircularBuffer(program, core, cb_bias_config); - - log_debug("BIAS CBs: {} {} {}", bias_cb, bias_ntiles, bias_pagesize); - } - - return {cb_sharded_act, cb_output}; -} - -operation::ProgramWithCallbacks multi_core_optimized_conv_(const Tensor& a, const Tensor &b, const Shape& ashape, std::optional bias, vector conv_params, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, const MathFidelity math_fidelity, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, Tensor &output) { - bool pass = true; - tt_metal::Device *device = a.device(); - TT_ASSERT(a.get_layout() == Layout::ROW_MAJOR, "Conv activation should be in row major layout"); - TT_ASSERT(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. Incorrect weight tensor."); - uint32_t act_block_h_ntiles = block_config.act_block_h_ntiles; - uint32_t act_block_w_ntiles = block_config.act_block_w_ntiles; - uint32_t weight_block_w_ntiles = parallelization_config.per_core_out_matrix_width_ntiles; - uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles; - uint32_t out_subblock_h_ntiles = block_config.out_subblock_h_ntiles; - uint32_t out_subblock_w_ntiles = block_config.out_subblock_w_ntiles; - //assert(out_block_h_ntiles == act_block_h_ntiles); // TODO: fix output block sizing - TT_ASSERT(out_block_h_ntiles >= act_block_h_ntiles, "Output block height (in # of tiles) should be greater than or equal to activation block height (in # of tiles)"); - - // Partitions conv inner dim into blocks to support sharding along this dim - // TODO: Only 2D convs with sharded input use this, but we can uplift to support generically - // TODO: Only updated variables which is affected, but there may be more that needs to account for this - // TODO: Loop naming in reader, writer, and compute kernels could also be cleaned up - // TODO: Can conv_act_c_blocks be same as num_blocks_act_w? 
- - uint32_t conv_act_size_h = ashape[1]; - uint32_t conv_act_size_w = ashape[2]; - uint32_t conv_act_size_c = ashape[3]; - uint32_t weight_size_h = (uint32_t) conv_params[0]; - uint32_t weight_size_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - - bool rn50_first_conv = (conv_act_size_h == 230 && conv_act_size_w == (231 + extra_padding_for_32B_alignment) && - weight_size_h == 7 && weight_size_w == 8 && - stride_h == 2 && stride_w == 2); - // Compute the 2d matrix shape - auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(ashape.value, conv_params, out_block_h_ntiles, extra_padding_for_32B_alignment); - assert(act_matrix_shape.size() == 3); - assert(act_matrix_shape[0] == 1); - uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1]; - uint32_t act_matrix_width = (uint32_t) act_matrix_shape[2]; - uint32_t act_matrix_height_unpadded = (uint32_t) act_matrix_shape_unpadded[1]; - uint32_t act_matrix_width_unpadded = (uint32_t) act_matrix_shape_unpadded[2]; - - // Tensor b has weights and it should be tiled layout after converting conv weights into weight matrix - TT_ASSERT(b.get_layout() == Layout::TILE, "Conv weights should be in tiled layout"); - TT_ASSERT(b.get_legacy_shape()[0] == 1, "Conv weight matrix shape is invalid"); - TT_ASSERT(b.get_legacy_shape()[1] == 1, "Conv weight matrix shape is invalid"); - uint32_t weight_matrix_height = b.get_legacy_shape()[2]; - uint32_t weight_matrix_width = b.get_legacy_shape()[3]; - - if (has_bias) { - // Tensor bias is of shape {output_channels} - TT_ASSERT(bias.has_value()); - TT_ASSERT(bias.value().buffer() != nullptr); - auto bias_shape_without_padding = bias.value().get_legacy_shape().without_padding(); - TT_ASSERT(bias_shape_without_padding[0] == 1, "Bias should have batch == 1"); - // TT_ASSERT(bias_shape_without_padding[1] == 1 && bias_shape_without_padding[2] == 1, "Bias should have H == W == 1"); - TT_ASSERT(bias_shape_without_padding[3] == output_channels, "Bias should have output_channels"); - } - - // Normal matrix shape check - TT_ASSERT(act_matrix_width == weight_matrix_height, "The width of tensor a needs to match the height of tensor b"); - - // Tile size divisibility checks - TT_ASSERT(act_matrix_height % TILE_HEIGHT == 0, "Height of activation matrix needs to be divisible by 32"); - TT_ASSERT(act_matrix_width % TILE_WIDTH == 0, "Width of activation matrix needs to be divisible by 32"); - TT_ASSERT(weight_matrix_height % TILE_HEIGHT == 0, "Height of weight matrix needs to be divisible by 32"); - TT_ASSERT(weight_matrix_width % TILE_WIDTH == 0, "Width of weight matrix needs to be divisible by 32"); - - // Device compatibility checks - TT_ASSERT(a.storage_type() == StorageType::DEVICE && - b.storage_type() == StorageType::DEVICE && - "Operands to large matmul need to be on device!"); - TT_ASSERT(a.device() == b.device(), "Operands to conv need to be on the same device!"); - TT_ASSERT(a.buffer() != nullptr && b.buffer() != nullptr, "Operands to conv need to be allocated in buffers on device!"); - if (has_bias) { - TT_ASSERT(bias.value().storage_type() == StorageType::DEVICE, "Bias should be on device"); - TT_ASSERT(bias.value().device() == a.device(), "Bias should be on the same device as act tensor"); - } - - // Convert tensor dims to tile dims - uint32_t act_matrix_height_ntiles = 
act_matrix_height / TILE_HEIGHT; - uint32_t act_matrix_width_ntiles = act_matrix_width / TILE_WIDTH; - uint32_t weight_matrix_height_ntiles = weight_matrix_height / TILE_HEIGHT; - uint32_t weight_matrix_width_ntiles = weight_matrix_width / TILE_WIDTH; - - assert(act_matrix_height_ntiles % act_block_h_ntiles == 0); - assert(act_matrix_width_ntiles % act_block_w_ntiles == 0); - assert(weight_matrix_width_ntiles % weight_block_w_ntiles == 0); - assert(act_matrix_height_ntiles % out_block_h_ntiles == 0); - - uint32_t conv_act_c_blocks = 1; // input is HS - uint32_t num_blocks_act_h = act_matrix_height_ntiles / act_block_h_ntiles; - uint32_t num_blocks_out_h = act_matrix_height_ntiles / out_block_h_ntiles; - uint32_t num_blocks_act_w = act_matrix_width_ntiles / act_block_w_ntiles; - uint32_t num_blocks_weight_w = weight_matrix_width_ntiles / weight_block_w_ntiles; - - if (rn50_first_conv) { - assert(num_blocks_weight_w == 1); - } - - // act block info - uint32_t act_block_w_datums = act_matrix_width / num_blocks_act_w; - uint32_t act_block_h_datums = act_matrix_height / num_blocks_act_h; - TT_ASSERT((act_block_w_datums == conv_act_size_c * weight_size_w) || ((act_block_w_datums <= conv_act_size_c) && (conv_act_size_c % act_block_w_datums == 0))); - - - // weight block info - uint32_t weight_block_w_datums = weight_matrix_width / num_blocks_weight_w; - assert(weight_block_w_ntiles % out_subblock_w_ntiles == 0); - uint32_t weight_num_subblocks = weight_block_w_ntiles / out_subblock_w_ntiles; - uint32_t weight_block_h_ntiles = act_block_w_ntiles; - uint32_t weight_block_num_tiles = weight_block_w_ntiles * weight_block_h_ntiles; - - uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w; - // writer of conv op partially removes padding on the width - // it removes the padding done for block width but it doesn't remove padding done for tiled width - uint32_t output_channels_padded_to_tile_width = round_up(output_channels, TILE_WIDTH); - assert(output_channels_padded_to_tile_width <= weight_matrix_width); - uint32_t output_width_num_tiles = output_channels_padded_to_tile_width / TILE_WIDTH; - uint32_t num_blocks_output_w = (uint32_t) std::ceil((double) output_channels_padded_to_tile_width / (double) weight_block_w_datums); - uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums == 0) ? weight_block_w_datums : (output_channels_padded_to_tile_width % weight_block_w_datums); - assert(last_block_width_datums % TILE_WIDTH == 0); - - // sanity check - assert(num_blocks_output_w == num_blocks_weight_w); - tt_metal::Program program = tt_metal::CreateProgram(); - //CoreCoord core_coord = {0, 0}; // TODO: avoid another var here. Find a way to use core range instead. - //CoreRange core = {{0, 0}, {0, 0}}; - - tt::DataFormat act_df = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::DataFormat weight_df = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); - tt::DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - tt::DataFormat bias_df = has_bias ? 
tt_metal::datatype_to_dataformat_converter(bias.value().get_dtype()) : tt::DataFormat::Float16_b; - tt::DataFormat tilized_act_df = out_df; - - tt_metal::Buffer *src0_dram_buffer = a.buffer(); - tt_metal::Buffer *src1_dram_buffer = b.buffer(); - - tt_metal::Buffer *dst_dram_buffer = output.buffer(); - TT_ASSERT(dst_dram_buffer != nullptr, "Output buffer should be allocated on device!"); - - // out - uint32_t out_dram_addr = dst_dram_buffer->address(); - uint32_t out_subblock_num_tiles = out_subblock_h_ntiles * out_subblock_w_ntiles; - TT_ASSERT(out_subblock_num_tiles <= 8, "Need to ensure that matmul partials fit in dst"); - - // act - uint32_t act_dram_addr = src0_dram_buffer->address(); - auto act_dram_noc_xy = src0_dram_buffer->noc_coordinates(); - uint32_t act_noc_x = act_dram_noc_xy.x; - uint32_t act_noc_y = act_dram_noc_xy.y; - - assert(act_matrix_width_ntiles % act_block_w_ntiles == 0); - assert(act_block_h_ntiles % out_subblock_h_ntiles == 0); - assert(out_block_h_ntiles % out_subblock_h_ntiles == 0); - uint32_t act_num_subblocks = act_block_h_ntiles / out_subblock_h_ntiles; - uint32_t act_block_num_tiles = act_block_h_ntiles * act_block_w_ntiles; - uint32_t act_subblock_h_ntiles = out_subblock_h_ntiles; - uint32_t act_subblock_num_tiles = act_subblock_h_ntiles * act_block_w_ntiles; - - // weight - uint32_t weight_dram_addr = src1_dram_buffer->address(); - auto weight_dram_noc_xy = src1_dram_buffer->noc_coordinates(); - uint32_t weight_noc_x = weight_dram_noc_xy.x; - uint32_t weight_noc_y = weight_dram_noc_xy.y; - - // bias - tt_metal::Buffer *bias_buffer = nullptr; - uint32_t bias_dram_addr = 0; - uint32_t bias_ntiles = 0; - if (has_bias) { - bias_buffer = bias.value().buffer(); - bias_dram_addr = bias_buffer->address(); - bias_ntiles = bias.value().get_legacy_shape()[3] / constants::TILE_WIDTH; // TODO: support non tile multiple sizes - } - - //uint32_t conv_output_size_h = ((conv_act_size_h - weight_size_h + (2 * pad_h)) / stride_h) + 1; - //uint32_t conv_output_size_w = ((conv_act_size_w - weight_size_w + (2 * pad_w)) / stride_w) + 1; - - auto [conv_output_size_h, conv_output_size_w] = optimized_conv_op_utils::compute_opt_conv_output_face_shape(conv_act_size_h, conv_act_size_w, weight_size_h, weight_size_w, stride_h, stride_w, pad_h, pad_w, extra_padding_for_32B_alignment); - - std::map reader_defines; - - if (act_matrix_height_unpadded < act_block_h_datums * num_blocks_act_h) { - reader_defines["ACT_BLOCK_HEIGHT_PADDING"] = "1"; - } - - if (conv_act_c_blocks > 1) { - reader_defines["ACT_W_OUTER_BLOCKS"] = "1"; - } - - uint32_t output_height_padded_to_tile_height = round_up(act_matrix_height_unpadded, TILE_HEIGHT); - uint32_t output_height_num_tiles = output_height_padded_to_tile_height / TILE_HEIGHT; - assert(output_height_num_tiles <= act_matrix_height_ntiles); - - uint32_t src_dram_act_buffer_size_bytes = src0_dram_buffer->size(); - uint32_t src_dram_weight_buffer_size_bytes = src1_dram_buffer->size(); - uint32_t dst_l1_act_buffer_size_bytes = act_block_h_ntiles * act_block_w_ntiles * tt::tt_metal::detail::TileSize(act_df); - uint32_t dst_l1_weight_buffer_size_bytes = weight_block_h_ntiles * weight_block_w_ntiles * tt::tt_metal::detail::TileSize(weight_df); - - - // For debug - { - log_debug(tt::LogOp, "conv_act_size_c: {}", conv_act_size_c); - log_debug(tt::LogOp, "conv_act_size_h: {}", conv_act_size_h); - log_debug(tt::LogOp, "conv_act_size_w: {}", conv_act_size_w); - log_debug(tt::LogOp, "act_matrix_height: {}", act_matrix_height); - log_debug(tt::LogOp, 
"act_matrix_width: {}", act_matrix_width); - log_debug(tt::LogOp, "act_matrix_height_unpadded: {}", act_matrix_height_unpadded); - log_debug(tt::LogOp, "act_matrix_width_unpadded: {}", act_matrix_width_unpadded); - log_debug(tt::LogOp, "act_matrix_height_ntiles: {}", act_matrix_height_ntiles); - log_debug(tt::LogOp, "act_matrix_width_ntiles: {}", act_matrix_width_ntiles); - log_debug(tt::LogOp, "weight_matrix_width_ntiles: {}", weight_matrix_width_ntiles); - log_debug(tt::LogOp, "num_blocks_act_h: {}", num_blocks_act_h); - log_debug(tt::LogOp, "num_blocks_act_w: {}", num_blocks_act_w); - log_debug(tt::LogOp, "num_blocks_weight_w: {}", num_blocks_weight_w); - log_debug(tt::LogOp, "num_blocks_out_h: {}", num_blocks_out_h); - log_debug(tt::LogOp, "act_dram_addr: {}", act_dram_addr); - log_debug(tt::LogOp, "act_block_h_ntiles: {}", act_block_h_ntiles); - log_debug(tt::LogOp, "act_block_h_datums: {}", act_block_h_datums); - log_debug(tt::LogOp, "act_block_w_ntiles: {}", act_block_w_ntiles); - log_debug(tt::LogOp, "act_block_w_datums: {}", act_block_w_datums); - log_debug(tt::LogOp, "out_block_h_ntiles: {}", out_block_h_ntiles); - log_debug(tt::LogOp, "act_num_subblocks: {}", act_num_subblocks); - log_debug(tt::LogOp, "act_block_num_tiles: {}", act_block_num_tiles); - log_debug(tt::LogOp, "act_subblock_h_ntiles: {}", act_subblock_h_ntiles); - log_debug(tt::LogOp, "act_subblock_num_tiles: {}", act_subblock_num_tiles); - log_debug(tt::LogOp, "out_subblock_num_tiles: {}", out_subblock_num_tiles); - log_debug(tt::LogOp, "weight_dram_addr: {}", weight_dram_addr); - log_debug(tt::LogOp, "weight_num_subblocks: {}", weight_num_subblocks); - log_debug(tt::LogOp, "weight_block_num_tiles: {}", weight_block_num_tiles); - log_debug(tt::LogOp, "weight_block_w_ntiles: {}", weight_block_w_ntiles); - log_debug(tt::LogOp, "weight_block_h_ntiles: {}", weight_block_h_ntiles); - log_debug(tt::LogOp, "has_bias: {}", has_bias); - log_debug(tt::LogOp, "bias_dram_addr: {}", bias_dram_addr); - log_debug(tt::LogOp, "bias_ntiles: {}", bias_ntiles); - log_debug(tt::LogOp, "out_dram_addr: {}", out_dram_addr); - log_debug(tt::LogOp, "out_subblock_h_ntiles: {}", out_subblock_h_ntiles); - log_debug(tt::LogOp, "out_subblock_w_ntiles: {}", out_subblock_w_ntiles); - log_debug(tt::LogOp, "out_subblock_num_tiles: {}", out_subblock_num_tiles); - log_debug(tt::LogOp, "num_groups: {}", num_groups); - } - // parallelization config - const auto& p_config = parallelization_config; - uint32_t num_cores_x = p_config.grid_size.x; - uint32_t num_cores_y = p_config.grid_size.y; - uint32_t total_num_cores = num_cores_x * num_cores_y; - assert(num_cores_x < 13); - assert(num_cores_y < 10); - uint32_t per_core_out_matrix_height_ntiles = p_config.per_core_out_matrix_height_ntiles; - uint32_t per_core_out_matrix_width_ntiles = p_config.per_core_out_matrix_width_ntiles; - //cout << "per_core_weight_matrix_width_ntiles=" << per_core_weight_matrix_width_ntiles << endl; - // cout << "total_num_cores=" << total_num_cores << endl; - // cout << "per_core_out_matrix_height_ntiles=" << per_core_out_matrix_height_ntiles << endl; - // cout << "act_matrix_height_ntiles=" << act_matrix_height_ntiles << endl; - // cout << "act_block_h_datums=" << act_block_h_datums << endl; - // cout << "num_blocks_act_h=" << num_blocks_act_h << endl; - bool weight_width_sliced = per_core_out_matrix_width_ntiles < weight_matrix_width_ntiles; - assert(weight_matrix_width_ntiles % per_core_out_matrix_width_ntiles == 0); - assert(per_core_out_matrix_width_ntiles % 
weight_block_w_ntiles == 0); - uint32_t num_blocks_weight_w_per_core = per_core_out_matrix_width_ntiles / weight_block_w_ntiles; - if (not weight_width_sliced) { - assert(num_blocks_weight_w_per_core == num_blocks_weight_w); - } - uint32_t num_weight_slices_width = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; - assert(num_cores_y % num_weight_slices_width == 0); - uint32_t num_cores_y_per_weight_slice_width = num_cores_y / num_weight_slices_width; - uint32_t total_num_cores_per_weight_slice = num_cores_y_per_weight_slice_width * num_cores_x; - if (weight_width_sliced) { - assert(total_num_cores_per_weight_slice * per_core_out_matrix_height_ntiles == act_matrix_height_ntiles); - } - else { - assert(total_num_cores * per_core_out_matrix_height_ntiles >= act_matrix_height_ntiles); - } - assert(per_core_out_matrix_height_ntiles % act_block_h_ntiles == 0); - uint32_t num_blocks_act_h_per_core = per_core_out_matrix_height_ntiles / act_block_h_ntiles; - assert(per_core_out_matrix_height_ntiles % out_block_h_ntiles == 0); - uint32_t num_blocks_out_h_per_core = per_core_out_matrix_height_ntiles / out_block_h_ntiles; - bool act_height_sliced = per_core_out_matrix_height_ntiles < act_matrix_height_ntiles; - if (not act_height_sliced) { - assert(num_blocks_act_h_per_core == num_blocks_act_h); - assert(num_blocks_out_h_per_core == num_blocks_out_h); - assert(num_cores_x == 1); - } - // cout << "num_blocks_act_h_per_core=" << num_blocks_act_h_per_core << endl; - assert(act_matrix_height_ntiles % per_core_out_matrix_height_ntiles == 0); - uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; - assert(total_active_num_cores_per_weight_slice <= total_num_cores_per_weight_slice); - uint32_t total_noop_cores = total_num_cores_per_weight_slice - total_active_num_cores_per_weight_slice; - uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; - if (weight_width_sliced) { - assert(total_noop_cores == 0); - assert(total_active_num_cores == total_num_cores); - } - // cout << "act_matrix_height_ntiles=" << act_matrix_height_ntiles << endl; - // cout << "per_core_out_matrix_height_ntiles=" << per_core_out_matrix_height_ntiles << endl; - // cout << "total_active_num_cores_per_weight_slice="<< total_active_num_cores_per_weight_slice << endl; - // cout << "num weight slices = " << num_weight_slices_width << endl; - // cout << "total num active cores" << total_active_num_cores << endl; - if (has_bias) { - assert(bias_ntiles % num_weight_slices_width == 0); - assert(bias_ntiles == weight_matrix_width_ntiles); - } - uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width; - - bool act_block_w_equals_input_channels_x_filter_width = (act_block_w_datums == (conv_act_size_c * weight_size_w)); - if (rn50_first_conv) { - assert(not weight_width_sliced); // weight width slicing not supported for rn50 first conv - assert(act_block_w_equals_input_channels_x_filter_width); - } - - vector debug_cores; - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { - uint32_t core_x_i = core_i % num_cores_x; - uint32_t core_y_i = core_i / num_cores_x; - debug_cores.push_back({core_x_i+1, core_y_i+1}); - } - - CoreRange all_cores(CoreCoord(0, 0), CoreCoord(num_cores_x - 1, num_cores_y - 1)); - assert(total_active_num_cores >= num_cores_x); - uint32_t num_active_cores_x = num_cores_x; - uint32_t num_active_cores_y_with_full_x = total_active_num_cores / num_cores_x; - uint32_t 
num_active_cores_x_last_y = total_active_num_cores % num_cores_x; - assert((num_active_cores_x * num_active_cores_y_with_full_x) + num_active_cores_x_last_y == total_active_num_cores); - - // cout << "All active cores. Core Ranges:" << endl; - // cout << "Core range 1 - (0,0) to (" << num_active_cores_x - 1 << "," << num_active_cores_y_with_full_x - 1 << ")" << endl; - - std::set all_active_cores_set; - all_active_cores_set.insert(CoreRange(CoreCoord(0, 0), CoreCoord(num_active_cores_x - 1, num_active_cores_y_with_full_x - 1))); - if (num_active_cores_x_last_y > 0) { - all_active_cores_set.insert(CoreRange(CoreCoord(0, num_active_cores_y_with_full_x), CoreCoord(num_active_cores_x_last_y - 1, num_active_cores_y_with_full_x))); - // cout << "Core range 2 - (0," << num_active_cores_y_with_full_x << ") to (" << num_active_cores_x_last_y - 1 << "," << num_active_cores_y_with_full_x << ")" << endl; - } - CoreRangeSet all_active_cores(all_active_cores_set); - std::set noop_cores_set; - if (total_noop_cores > 0) { - assert(total_noop_cores == (num_cores_x - num_active_cores_x_last_y)); - noop_cores_set.insert(CoreRange(CoreCoord(num_active_cores_x_last_y, num_active_cores_y_with_full_x), CoreCoord(num_cores_x - 1, num_active_cores_y_with_full_x))); - // cout << "Noop core range - (" << num_active_cores_x_last_y << "," << num_active_cores_y_with_full_x << ") to (" << num_cores_x - 1 << "," << num_active_cores_y_with_full_x << ")" << endl; - - } - CoreRangeSet noop_cores(noop_cores_set); - - // Mcast cores - // If total_num_cores, there is no mcasting - CoreCoord top_left_core = {(std::size_t) 0, (std::size_t) 0}; - CoreCoord top_left_core_plus_one = {(std::size_t) 1, (std::size_t) 1}; - CoreCoord bottom_right_core = {(std::size_t) num_cores_x - 1, (std::size_t) num_cores_y - 1}; - auto top_left_core_physical = device->worker_core_from_logical_core(top_left_core); - auto top_left_core_plus_one_physical = device->worker_core_from_logical_core(top_left_core_plus_one); - auto bottom_right_core_physical = device->worker_core_from_logical_core(bottom_right_core); - - CoreRange mcast_sender_cores(top_left_core, top_left_core); // If single core, this kernel doesn't do mcasting - CoreRangeSet mcast_receiver_cores{{}}; - uint32_t weights_mcast_sender_semaphore_id{}; - uint32_t weights_mcast_receiver_semaphore_id{}; - uint32_t act_mcast_sender_semaphore_id = 0; - uint32_t act_mcast_receiver_semaphore_id = 0; - std::vector act_mcast_noc_y; - // 2D mcast - if (weight_width_sliced) { - mcast_sender_cores = CoreRange(top_left_core, CoreCoord(0, num_cores_y - 1)); - mcast_receiver_cores = {{CoreRange(CoreCoord(1, 0), bottom_right_core)}}; - weights_mcast_sender_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - weights_mcast_receiver_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - // 1D mcast - } else { - if (total_num_cores > 1) { - std::set mcast_receiver_set; - if (num_cores_x > 1) { - mcast_receiver_set.insert(CoreRange(CoreCoord(1, 0), CoreCoord(num_cores_x - 1, 0))); - } if (num_cores_y > 1) { - mcast_receiver_set.insert(CoreRange(CoreCoord(0, 1), bottom_right_core)); - } - mcast_receiver_cores = mcast_receiver_set; - weights_mcast_sender_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - weights_mcast_receiver_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - } - } - - bool read_3x3_window_in_inner_loop = false; - uint32_t num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles / conv_act_c_blocks; - 
uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks; - // TODO: This flag should be set in kernel logic but need this for create_CB - if (a.memory_config().is_sharded() && weight_size_h == 3 && weight_size_w == 3 && stride_h == 1 && weight_width_sliced) { - // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire 3x3 window before pushing in reader/writer - // TODO: Generalize this to not make this assumption - read_3x3_window_in_inner_loop = true; - num_weight_cb_tiles *= weight_size_h * weight_size_w; - num_act_cb_tiles *= weight_size_h * weight_size_w; - } - uint32_t num_cb0_tilized_tiles = num_act_cb_tiles; - - if (per_core_out_matrix_width_ntiles < 8) { - num_weight_cb_tiles = num_weight_cb_tiles * 2; - } - if (rn50_first_conv) { - num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles * num_blocks_weight_w * num_blocks_act_w; - } - if (conv_act_size_c / conv_act_c_blocks < 256) { - num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered - } - uint32_t writer_output_block_num_tiles = out_block_h_ntiles * weight_block_w_ntiles; - - // if (!(conv_output_size_w == 14 || conv_output_size_w == 7)) { - // writer_output_block_num_tiles = writer_output_block_num_tiles * 2; - // } - - // TODO: Moving this function call to after kernel logic causes pcc fails - // There are additional CBs and semaphores created in 2D conv in kernel logic, - // so does order of create_cb calls matter? - auto [cb_sharded_act, cb_output] = create_CBs( - program, - a, - all_cores, - num_act_cb_tiles, // row major act cb - num_weight_cb_tiles, // tiled weight cb - num_cb0_tilized_tiles, // tiled act cb - writer_output_block_num_tiles, // math output cb - weight_block_w_ntiles, // reblock cb - writer_output_block_num_tiles, // writer output cb, double bufferred - untilize_out, - act_df, - weight_df, - tilized_act_df, - out_df, - bias_df, - weight_width_sliced, - output, - bias_ntiles_per_core, - has_bias); - - string reader_kernel; - string compute_kernel; - string writer_mcast_sender_kernel; - string writer_mcast_receiver_kernel; - bool reader_with_indices = false; - if (rn50_first_conv) { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_resnet50_first_conv.cpp"; - compute_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/bmm_tilize_untilize_all_weights_in_l1_single_output_block_width_dim.cpp"; - writer_mcast_sender_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_sender_weights_resnet50_first_conv_tiled_out.cpp"; - writer_mcast_receiver_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_and_mcast_receiver_weights_resnet50_first_conv_tiled_out.cpp"; - } else { - compute_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp"; - writer_mcast_sender_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; - writer_mcast_receiver_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; - if (weight_size_h == 1 && weight_size_w == 1) { - // use custom 1x1 conv kernels - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv1x1_activations_fast_for_col_major_conv_out_blocks.cpp"; - assert(conv_act_size_c % act_block_w_datums == 0); - assert(num_blocks_act_w == (conv_act_size_c / act_block_w_datums)); - 
} - else { - // If sharded input, always use reader kernel for input shard with halo and padding - if (a.memory_config().is_sharded() && weight_size_h == 3 && weight_size_w == 3 && stride_h == 1) { - reader_with_indices = true; - if (weight_width_sliced) { - assert(read_3x3_window_in_inner_loop == true); - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights.cpp"; - writer_mcast_sender_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; - writer_mcast_receiver_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; - act_mcast_sender_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - act_mcast_receiver_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - - act_mcast_noc_y.reserve(num_cores_y); - for(uint32_t core_idx_y = 0; core_idx_y < num_cores_y; ++core_idx_y) { - act_mcast_noc_y.push_back(device->worker_core_from_logical_core({0, core_idx_y}).y); - } - } else { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights.cpp"; - } - - // Local L1 to store array for reader indices - CircularBufferConfig cb_for_reader_indices_config = CircularBufferConfig(act_block_h_datums * 4, {{cb_for_reader_indices, tt::DataFormat::Float16_b}}) - .set_page_size(cb_for_reader_indices, 4); - auto cb_for_reader_indices_id = tt_metal::CreateCircularBuffer(program, all_cores, cb_for_reader_indices_config); - - // Local L1 to store array for reader offsets - CircularBufferConfig cb_for_reader_offsets_config = CircularBufferConfig(weight_size_h * weight_size_w * 4, {{cb_for_reader_offsets, tt::DataFormat::Float16_b}}) - .set_page_size(cb_for_reader_offsets, 4); - auto cb_for_reader_offsets_id = tt_metal::CreateCircularBuffer(program, all_cores, cb_for_reader_offsets_config); - } else { - // non 1x1 conv - if (act_block_w_equals_input_channels_x_filter_width) { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_act_block_w_equals_channels_X_filter_width.cpp"; - } else { - assert(act_block_w_datums == conv_act_size_c); - assert(num_blocks_act_w == weight_size_w * weight_size_h); - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_fast_for_col_major_conv_out_blocks.cpp"; - } - } - } - } - TT_ASSERT(!(conv_act_size_c & (conv_act_size_c - 1))); // channel depth power of 2 is supported only - - std::vector reader_rt_args; - std::vector reader_compile_time_args; - std::vector writer_rt_args; - std::vector writer_compile_time_args; - - uint32_t conv_act_c_read_bytes = conv_act_size_c * a.element_size() / conv_act_c_blocks; - // For new reader_with_indices, this is used to calculate offset so use actual read_bytes along c - // For old readers, this is used for bank page size for interleaved; offset is from conv_act_c_read_bytes - uint32_t log_base_2_of_conv_act_size_c_bytes = reader_with_indices ? std::log2(conv_act_c_read_bytes) : std::log2(conv_act_size_c * a.element_size()); - reader_compile_time_args = {(uint32_t) - (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0), - (uint32_t) stride_h, - (uint32_t) stride_w, - (uint32_t) conv_act_size_w, - (uint32_t) conv_output_size_w, - (uint32_t) conv_act_c_read_bytes, - (uint32_t) log_base_2_of_conv_act_size_c_bytes, extra_padding_for_32B_alignment, - (uint32_t) (conv_act_size_c/act_block_w_datums), act_block_w_datums * a.element_size()}; - - // define for bias - std::map writer_defines; - std::map writer_mcast_sender_defines; - std::map compute_defines; - if (output.memory_config().is_sharded()) { - writer_defines["SHARDED_OUT"] = "1"; - writer_mcast_sender_defines["SHARDED_OUT"] = "1"; - } - if (total_num_cores == 1) { - writer_mcast_sender_defines["SKIP_MCAST"] = "1"; - } - if (has_bias) { - writer_defines["FUSE_BIAS"] = "1"; - writer_mcast_sender_defines["FUSE_BIAS"] = "1"; - compute_defines["FUSE_BIAS"] = "1"; - } - - if (fuse_relu) { - compute_defines["PACK_RELU"] = "1"; - } - - writer_compile_time_args = { - (uint32_t) (dst_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), - out0_cb, - weight_cb, - bias_cb, - (uint32_t) (bias_buffer == nullptr ? 0 : (bias_buffer->buffer_type() == BufferType::DRAM ? 1 : 0))}; - - uint32_t in0_block_w = act_block_w_ntiles / conv_act_c_blocks; - uint32_t in0_block_num_tiles = act_block_num_tiles / conv_act_c_blocks; - uint32_t in0_subblock_num_tiles = act_subblock_num_tiles / conv_act_c_blocks; - uint32_t in1_block_num_tiles = weight_block_num_tiles / conv_act_c_blocks; - uint32_t in0_num_blocks_w = num_blocks_act_w * conv_act_c_blocks; // Fold outer c_block loop together with weight_block_num_tiles = 9 - if (read_3x3_window_in_inner_loop) { - const uint32_t window_size = weight_size_h * weight_size_w; - in0_block_w *= window_size; - in0_block_num_tiles *= window_size; - in0_subblock_num_tiles *= window_size; - in1_block_num_tiles *= window_size; - in0_num_blocks_w /= window_size; - } - - vector compute_kernel_args = { - in0_block_w, - act_num_subblocks, - in0_block_num_tiles, - in0_subblock_num_tiles, - act_subblock_h_ntiles, - - weight_num_subblocks, - in1_block_num_tiles, - weight_block_w_ntiles, - - num_blocks_act_h_per_core, - in0_num_blocks_w, - num_blocks_weight_w_per_core, - - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - - true, - untilize_out, - - bias_ntiles_per_core - }; - - auto writer_mcast_noc = tt_metal::detail::GetPreferredNOCForDRAMWrite(device->arch()); - auto reader_noc = tt_metal::detail::GetPreferredNOCForDRAMRead(device->arch()); - auto writer_mcast_sender_id = CreateKernel( - program, - writer_mcast_sender_kernel, - mcast_sender_cores, - DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_0, - .noc = writer_mcast_noc, - .compile_args = writer_compile_time_args, - .defines = writer_mcast_sender_defines}); - - KernelHandle writer_mcast_receiver_id{}; - if (total_num_cores > 1) { - writer_mcast_receiver_id = CreateKernel( - program, - writer_mcast_receiver_kernel, - mcast_receiver_cores, - DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_0, - .noc = writer_mcast_noc, - .compile_args = writer_compile_time_args, - .defines = writer_defines}); - } - - auto reader_id = CreateKernel( - program, - reader_kernel, - all_cores, - DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_1, - .noc = reader_noc, - .compile_args = reader_compile_time_args, - .defines = reader_defines}); - - // Compile compute kernel for active cores only - // Compile blank kernel for noop cores - auto compute_id = CreateKernel( - program, - compute_kernel, - all_active_cores, - ComputeConfig{ - 
.math_fidelity = math_fidelity, - .compile_args = compute_kernel_args, - .defines = compute_defines}); - - if (total_noop_cores > 0) { - auto compute_id = CreateKernel( - program, - "tt_metal/kernels/compute/blank.cpp", - noop_cores, ComputeConfig{}); - } - - vector reader_ids; - vector writer_ids; - //tt_start_debug_print_server(); - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { - uint32_t core_x_i = core_i % num_cores_x; - uint32_t core_y_i = core_i / num_cores_x; - // cout << "core_x_i=" << core_x_i << ", core_y_i=" << core_y_i << endl; - CoreRange core(CoreCoord(core_x_i, core_y_i), CoreCoord(core_x_i, core_y_i)); - bool noop_core = false; - for (const auto & noop_core_range : noop_cores.ranges()) { - if (noop_core_range.contains(core)) { - // cout << "No op core" << endl; - // cout << "core_x_i=" << core_x_i << ", core_y_i=" << core_y_i << endl; - noop_core = true; - break; - } - } - // per core specific args - uint32_t act_slice_i = core_i % (num_cores_y_per_weight_slice_width * num_cores_x); - uint32_t weight_slice_i = core_i / (num_cores_y_per_weight_slice_width * num_cores_x); - uint32_t total_h_start = act_slice_i * per_core_out_matrix_height_ntiles * TILE_HEIGHT; - uint32_t n_start = total_h_start / (conv_output_size_h * conv_output_size_w); - uint32_t matrix_h_start = total_h_start % (conv_output_size_h * conv_output_size_w); - uint32_t out_h_start = matrix_h_start / conv_output_size_w; - uint32_t out_w_start = matrix_h_start % conv_output_size_w; - uint32_t in_h_start = (n_start * conv_act_size_h) + out_h_start * stride_h; - uint32_t last_start_in_h_curr_image = 222 + (n_start * conv_act_size_h); - uint32_t out_start_tile_id = (act_slice_i * per_core_out_matrix_height_ntiles * weight_matrix_width_ntiles) + (weight_slice_i * per_core_out_matrix_width_ntiles); - uint32_t out_start_tile_id_h = act_slice_i * per_core_out_matrix_height_ntiles; - uint32_t out_start_tile_id_w = weight_slice_i * per_core_out_matrix_width_ntiles; - uint32_t bias_tile_offset = weight_slice_i * per_core_out_matrix_width_ntiles; - if (has_bias) { - assert(bias_tile_offset < bias_ntiles); - } - // cout << "act_slice_i=" << act_slice_i << endl; - // cout << "weight_slice_i=" << weight_slice_i << endl; - // cout << "core_i=" << core_i << endl; - // cout << "num_blocks_act_h_per_core=" << num_blocks_act_h_per_core << endl; - // cout << "num_blocks_weight_w_per_core=" << num_blocks_weight_w_per_core << endl; - // cout << "bias_tile_offset=" << bias_tile_offset << endl; - // cout << "out_start_tile_id=" << out_start_tile_id << endl; - // cout << "out_start_tile_id_w=" << out_start_tile_id_w << endl; - // cout << "per_core_out_matrix_height_ntiles=" << per_core_out_matrix_height_ntiles << endl; - // cout << "weight_matrix_width_ntiles=" << weight_matrix_width_ntiles << endl; - // cout << "out_start_tile_id_h=" << out_start_tile_id_h << endl; - // cout << endl; - // cout << "total_h_start=" << total_h_start << endl; - // cout << "in_h_start=" << in_h_start << endl; - // cout << "out_h_start=" << out_h_start << endl; - // cout << "out_w_start=" << out_w_start << endl; - // cout << "matrix_h_start=" << matrix_h_start << endl; - // cout << "n_start=" << n_start << endl; - - if (rn50_first_conv) { - assert(pad_h == 0 && pad_w == 0); - reader_rt_args = { - act_dram_addr, - conv_act_size_c, - conv_output_size_w, - weight_size_w, - num_blocks_act_h_per_core, - num_blocks_act_w, - act_block_h_datums, - act_block_num_tiles, - in_h_start, - out_w_start, - last_start_in_h_curr_image, - (uint32_t) noop_core 
- }; - } else if (reader_with_indices) { - /* Logic to compute: - * NOTE: This logic is wrong if stride !=1 - * first_partial_right_aligned_row_width - * skip_after_partial_right_aligned_row - * first_partial_image_num_rows - * skip_after_first_partial_image_row - * num_full_images - * skip_after_full_image - * last_partial_image_num_rows - * last_partial_left_aligned_row_width - */ - - // If 2D, same image specs across a row - uint32_t start_stick = weight_width_sliced ? core_x_i * act_block_h_datums : core_i * act_block_h_datums; - uint32_t end_stick = start_stick + act_block_h_datums; - - ShardingConfig sharding_config = get_specs_for_sharding_partition(start_stick, end_stick, conv_act_size_h, conv_act_size_w, weight_size_w, pad_h, pad_w); - uint32_t first_partial_right_aligned_row_width = sharding_config.first_partial_right_aligned_row_width; - uint32_t skip_after_partial_right_aligned_row = sharding_config.skip_after_partial_right_aligned_row; - uint32_t first_partial_image_num_rows = sharding_config.first_partial_image_num_rows; - uint32_t skip_after_first_partial_image_row = sharding_config.skip_after_first_partial_image_row; - uint32_t num_full_images = sharding_config.num_full_images; - uint32_t skip_after_full_image = sharding_config.skip_after_full_image; - uint32_t last_partial_image_num_rows = sharding_config.last_partial_image_num_rows; - uint32_t last_partial_left_aligned_row_width = sharding_config.last_partial_left_aligned_row_width; - - if (weight_width_sliced) { - auto shard_shape = a.shard_spec().value().shape; - uint32_t shard_size_bytes = shard_shape[0] * shard_shape[1] * a.element_size(); - CoreCoord bottom_core = {(std::size_t) core_x_i, (std::size_t) num_cores_y - 1}; - auto bottom_core_physical = device->worker_core_from_logical_core(bottom_core); - - bool reader_is_noc_0 = reader_noc == NOC::NOC_0; - uint32_t act_mcast_dest_noc_start_x = bottom_core_physical.x; - uint32_t act_mcast_dest_noc_start_y = reader_is_noc_0 ? top_left_core_physical.y : bottom_core_physical.y; - uint32_t act_mcast_dest_noc_end_x = bottom_core_physical.x; - uint32_t act_mcast_dest_noc_end_y = reader_is_noc_0 ? bottom_core_physical.y : top_left_core_physical.y; - reader_rt_args = { - // arguments for act - act_dram_addr, - act_noc_x, - act_noc_y, - - conv_act_size_w, - conv_act_size_h, - conv_act_size_c, - weight_size_h, - weight_size_w, - stride_h, - stride_w, - pad_h, - pad_w, - conv_output_size_h, - conv_output_size_w, - num_blocks_act_h_per_core, // per core - num_blocks_act_w, - num_blocks_weight_w_per_core, - num_groups, - - act_matrix_height_unpadded, - act_matrix_width_unpadded, - act_matrix_height, - act_matrix_width, - act_matrix_height_ntiles, - act_matrix_width_ntiles, - act_block_h_datums, - act_block_w_datums, - act_block_h_ntiles, - act_block_w_ntiles, - in0_block_num_tiles, - conv_act_c_blocks, - - src_dram_act_buffer_size_bytes, - dst_l1_act_buffer_size_bytes, - - n_start, - out_h_start, - out_w_start, - total_h_start, - - // Specs for reader indices - first_partial_right_aligned_row_width, - skip_after_partial_right_aligned_row, - first_partial_image_num_rows, - skip_after_first_partial_image_row, - num_full_images, - skip_after_full_image, - last_partial_image_num_rows, - last_partial_left_aligned_row_width, - - // Specs for reader offsets - 1, // window_outer - 3, // window_inner = 9 / 3, ie. 
read 3 width coalesced - - (uint32_t) noop_core, - - // mcast args - act_mcast_dest_noc_start_x, - act_mcast_dest_noc_start_y, - act_mcast_dest_noc_end_x, - act_mcast_dest_noc_end_y, - num_cores_y - 1, - num_cores_y - 1, - act_mcast_sender_semaphore_id, - act_mcast_receiver_semaphore_id, - shard_size_bytes, - core_y_i, // act_mcast_sender_id (goes down the column) - (uint32_t) bottom_core_physical.x, // act_mcast_sender_noc_x - }; - reader_rt_args.insert(reader_rt_args.end(), act_mcast_noc_y.begin(), act_mcast_noc_y.end()); // act_mcast_sender_noc_y - } else { - reader_rt_args = { - // arguments for act - act_dram_addr, - act_noc_x, - act_noc_y, - - conv_act_size_w, - conv_act_size_h, - conv_act_size_c, - weight_size_h, - weight_size_w, - stride_h, - stride_w, - pad_h, - pad_w, - conv_output_size_h, - conv_output_size_w, - num_blocks_act_h_per_core, // per core - num_blocks_act_w, - num_blocks_weight_w_per_core, - num_groups, - - act_matrix_height_unpadded, - act_matrix_width_unpadded, - act_matrix_height, - act_matrix_width, - act_matrix_height_ntiles, - act_matrix_width_ntiles, - act_block_h_datums, - act_block_w_datums, - act_block_h_ntiles, - act_block_w_ntiles, - act_block_num_tiles / conv_act_c_blocks, - conv_act_c_blocks, - - src_dram_act_buffer_size_bytes, - dst_l1_act_buffer_size_bytes, - - n_start, - out_h_start, - out_w_start, - total_h_start, - - // Specs for reader indices - first_partial_right_aligned_row_width, - skip_after_partial_right_aligned_row, - first_partial_image_num_rows, - skip_after_first_partial_image_row, - num_full_images, - skip_after_full_image, - last_partial_image_num_rows, - last_partial_left_aligned_row_width, - - // Specs for reader offsets - num_blocks_act_w, // window_outer - weight_size_h * weight_size_w / num_blocks_act_w, // window_inner - - (uint32_t) noop_core - }; - } - } else { - reader_rt_args = { - // arguments for act - act_dram_addr, - act_noc_x, - act_noc_y, - - conv_act_size_w, - conv_act_size_h, - conv_act_size_c, - weight_size_h, - weight_size_w, - stride_h, - stride_w, - pad_h, - pad_w, - conv_output_size_h, - conv_output_size_w, - num_blocks_act_h_per_core, // per core - num_blocks_act_w, - num_blocks_weight_w_per_core, - num_groups, - - act_matrix_height_unpadded, - act_matrix_width_unpadded, - act_matrix_height, - act_matrix_width, - act_matrix_height_ntiles, - act_matrix_width_ntiles, - act_block_h_datums, - act_block_w_datums, - act_block_h_ntiles, - act_block_w_ntiles, - act_block_num_tiles / conv_act_c_blocks, - conv_act_c_blocks, - - src_dram_act_buffer_size_bytes, - dst_l1_act_buffer_size_bytes, - - n_start, - out_h_start, - out_w_start, - total_h_start, - - (uint32_t) noop_core - }; - } - - SetRuntimeArgs( - program, reader_id, core, - reader_rt_args - ); - reader_ids.push_back(reader_id); - - writer_rt_args = { - out_dram_addr, - weight_dram_addr, - bias_dram_addr, - - output_width_num_tiles, // out_next_tile_stride_h - 1, // out_next_tile_stride_w - out_subblock_h_ntiles * output_width_num_tiles, // out_next_subblock_stride_h - out_subblock_w_ntiles, // out_next_subblock_stride_w - act_block_h_ntiles * output_width_num_tiles, // out_next_block_stride_h - weight_block_w_ntiles, // out_next_block_stride_w - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - act_block_h_ntiles / out_subblock_h_ntiles, // out_num_subblocks_h - weight_block_w_ntiles / out_subblock_w_ntiles, // out_num_subblocks_w - num_blocks_act_h_per_core, // out_num_blocks_h - num_blocks_weight_w_per_core, // out_num_blocks_w - 
act_block_h_ntiles, // out_block_height_num_tiles - output_height_num_tiles, // out_height_num_tiles without block shape padding - output_width_num_tiles, // out_width_num_tiles withoug block shape padding - out_start_tile_id, - out_start_tile_id_h, - out_start_tile_id_w, - - num_blocks_act_w, // = number of blocks of weight in height dim - in1_block_num_tiles, - conv_act_c_blocks, - weight_block_h_ntiles / conv_act_c_blocks, - weight_block_w_ntiles, - weight_matrix_width_ntiles, // weight_stride_h - weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, - weight_block_w_ntiles, // weight_next_block_stride_w - - // bias - bias_ntiles_per_core, - bias_tile_offset, - - (uint32_t) noop_core - }; - - // Mcast sender - // 2D mcast - if (weight_width_sliced) { - CoreCoord right_core = {(std::size_t) num_cores_x - 1, (std::size_t) core_y_i}; - auto right_core_physical = device->worker_core_from_logical_core(right_core); - // sender - if (core_x_i == 0) { - if (writer_mcast_noc == NOC::NOC_0) { - writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y - } else { - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y - } - - writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_dests - writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_cores - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_sender_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_sender_id); - // receiver - } else { - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_sender_noc_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_sender_noc_y - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_receiver_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_receiver_id); - } - // 1D mcast - } else { - // sender - if (core_x_i == 0 and core_y_i == 0) { - if (writer_mcast_noc == NOC::NOC_0) { - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_end_y - } else { - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_dest_noc_end_y - } - writer_rt_args.push_back(total_active_num_cores - 1); // weights_mcast_num_dests - 
writer_rt_args.push_back(total_num_cores - 1); // weights_mcast_num_cores - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_sender_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_sender_id); - // receiver - } else { - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_sender_noc_x - writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_sender_noc_y - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_receiver_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_receiver_id); - } - } - - } // for num_cores - - auto override_runtime_arguments_callback = [ - reader_kernel_ids=reader_ids, - writer_kernel_ids=writer_ids, - cb_sharded_act=cb_sharded_act, - cb_output=cb_output, - total_num_cores=total_num_cores, - num_cores_x=num_cores_x, - num_cores_y=num_cores_y, - has_bias=has_bias - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors - ) { - - TT_ASSERT(input_tensors.size() + optional_input_tensors.size() == 4); - TT_ASSERT(output_tensors.size() == 1); - - auto src_buffer_a = input_tensors.at(0).buffer(); - auto src_buffer_b = input_tensors.at(1).buffer(); - auto src_a_is_sharded = input_tensors.at(0).memory_config().is_sharded(); - - auto dst_buffer = output_tensors.at(0).buffer(); - bool out_sharded = output_tensors.at(0).memory_config().is_sharded(); - - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { - uint32_t core_x_i = core_i % num_cores_x; - uint32_t core_y_i = core_i / num_cores_x; - CoreCoord core = {core_x_i, core_y_i}; - - if (!src_a_is_sharded) { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_ids[core_i], core); - runtime_args[0] = src_buffer_a->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_ids[core_i], core); - runtime_args[0] = dst_buffer->address(); - runtime_args[1] = src_buffer_b->address(); - if (has_bias) { - auto src_buffer_c = optional_input_tensors.at(0).value().buffer(); - TT_ASSERT(src_buffer_c != nullptr); - runtime_args[2] = src_buffer_c->address(); - } - } - } - - if (src_a_is_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_sharded_act, *src_buffer_a); - } - - if (out_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); - } - }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; -} - -} // namespace tt_metal - -} // namespace tt diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded.cpp deleted file mode 100644 index a0bcabe8876..00000000000 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded.cpp +++ /dev/null @@ -1,1184 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp" -#include "tt_metal/host_api.hpp" -#include "tt_metal/detail/tt_metal.hpp" -#include "tt_metal/detail/util.hpp" -#include "tt_metal/common/constants.hpp" - -#include "tt_stl/reflection.hpp" - -#include "tt_metal/common/work_split.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/sharding_utilities.hpp" -#include "ttnn/operations/experimental/auto_format/auto_format.hpp" - -#include "ttnn/tensor/tensor_utils.hpp" - -using namespace tt::constants; - -namespace ttnn::operations::conv { -namespace conv2d { - -using namespace tt; - -const uint32_t act_cb = CB::c_in0; -const uint32_t weight_cb = CB::c_in1; -const uint32_t bias_cb = CB::c_in2; -const uint32_t sharded_act_cb = CB::c_in3; -const uint32_t cb_for_reader_indices = CB::c_in4; -const uint32_t cb_for_reader_offsets = CB::c_in5; -const uint32_t act_cb_row_major_bfloat16 = CB::c_in6; -const uint32_t matmul_partials_cb = CB::c_intermed0; -const uint32_t tilize_mode_tilized_act_cb = CB::c_intermed1; -const uint32_t untilize_mode_reblock_cb = CB::c_intermed2; -const uint32_t out0_cb = CB::c_out0; - - -std::tuple create_CBs_for_sharded_input( - tt_metal::Program &program, - const Tensor& input, - CoreRange core, - uint32_t num_cb0_tiles, - uint32_t num_cb1_tiles, - uint32_t num_cb0_tilized_tiles, - uint32_t num_output_tiles, - uint32_t num_reblock_cb_tiles, - uint32_t num_writer_output_tiles, - bool untilize_out, - tt::DataFormat act_df, - tt::DataFormat weight_df, - tt::DataFormat tilized_act_df, - tt::DataFormat out_df, - tt::DataFormat bias_df, - bool weight_width_sliced, - const Tensor& output, - uint32_t bias_ntiles = 0, - bool with_bias = false -) { - - uint32_t act_tile_size = tt_metal::detail::TileSize(act_df); - uint32_t weight_tile_size = tt_metal::detail::TileSize(weight_df); - uint32_t tilized_act_tile_size = tt_metal::detail::TileSize(tilized_act_df); - uint32_t out_tile_size = tt_metal::detail::TileSize(out_df); - - CBHandle cb_sharded_act = 0; - if (input.memory_config().is_sharded()) { - uint32_t num_bytes_for_df = datum_size(act_df); - auto shard_shape = input.shard_spec().value().shape; - // 2D-sys-conv already has uint16_t indicies, TODO: do the same for 1D-sys-conv - TT_ASSERT(shard_shape[0] <= (1<<16), "Shard height must be less than 2^16, read pattern indicies are uint16_t"); - CircularBufferConfig cb_sharded_act_config = CircularBufferConfig(shard_shape[0] * shard_shape[1] * num_bytes_for_df, {{sharded_act_cb, act_df}}) - .set_page_size(sharded_act_cb, shard_shape[1] * num_bytes_for_df); - // incoming data is the input cb instead of raw l1/dram addr - cb_sharded_act_config.set_globally_allocated_address(*input.buffer()); - cb_sharded_act = tt_metal::CreateCircularBuffer(program, core, cb_sharded_act_config); - - if (weight_width_sliced) { - // For 2D convs, each core creates and tilizes full input matrix then mcasts round robin style - // Each core receives input into act_cb, so won't need a separate cb to receive - // However, we need a separate cb to push ROW_MAJOR BFLOAT16 data for tilizing and configure act cb to be output df - - // num_cb0_tiles is double buffered - CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * tilized_act_tile_size, {{act_cb, tilized_act_df}}) - .set_page_size(act_cb, tilized_act_tile_size); - auto cb_act = tt_metal::CreateCircularBuffer(program, core, cb_act_config); - - // num_cb0_tilized_tiles is single buffered - CircularBufferConfig 
cb_act_row_major_bfloat16_config = CircularBufferConfig(num_cb0_tilized_tiles * act_tile_size, {{act_cb_row_major_bfloat16, act_df}}) - .set_page_size(act_cb_row_major_bfloat16, act_tile_size); - auto cb_act_row_major_bfloat16 = tt_metal::CreateCircularBuffer(program, core, cb_act_row_major_bfloat16_config); - } else { - // For 1D convs, locally create act matrix in act_cb, which is always ROW_MAJOR BFLOAT16 - // Then, tilize input in compute - CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * act_tile_size, {{act_cb, act_df}}) - .set_page_size(act_cb, act_tile_size); - auto cb_act = tt_metal::CreateCircularBuffer(program, core, cb_act_config); - } - } else { - TT_ASSERT(false, "Input must be sharded!"); - } - - - CircularBufferConfig cb_weight_config = CircularBufferConfig(num_cb1_tiles * weight_tile_size, {{weight_cb, weight_df}}) - .set_page_size(weight_cb, weight_tile_size); - auto cb_weight = tt_metal::CreateCircularBuffer(program, core, cb_weight_config); - - // Used for placing tilized activations - CircularBufferConfig cb_src0_tilized_config = CircularBufferConfig(num_cb0_tilized_tiles * tilized_act_tile_size, {{tilize_mode_tilized_act_cb, tilized_act_df}}) - .set_page_size(tilize_mode_tilized_act_cb, tilized_act_tile_size); - auto cb_src0_tilized = tt_metal::CreateCircularBuffer(program, core, cb_src0_tilized_config); - - CBHandle cb_output = 0; - if (untilize_out) { - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * out_tile_size, {{matmul_partials_cb, out_df}}) - .set_page_size(matmul_partials_cb, out_tile_size); - auto cb_matmul_partials = tt_metal::CreateCircularBuffer(program, core, cb_matmul_partials_config); - - // Supposed to be a small CB only responsible for reorganizing - // the output blocks to fill the whole "per core output block width" - CircularBufferConfig cb_reblock_config = CircularBufferConfig(num_reblock_cb_tiles * out_tile_size, {{untilize_mode_reblock_cb, out_df}}) - .set_page_size(untilize_mode_reblock_cb, out_tile_size); - auto cb_reblock = tt_metal::CreateCircularBuffer(program, core, cb_reblock_config); - - CircularBufferConfig cb_output_config = CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{out0_cb, out_df}}) - .set_page_size(out0_cb, out_tile_size); - if (output.is_sharded()) { - cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); - } - cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); - } else { - CoreRangeSet cores(std::set({core})); - std::map cb_output_data_format_spec = { - {out0_cb, out_df}, - {matmul_partials_cb, out_df} - }; - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * out_tile_size, cb_output_data_format_spec) - .set_page_size(out0_cb, out_tile_size) - .set_page_size(matmul_partials_cb, out_tile_size); - if (output.is_sharded()) { - cb_matmul_partials_config = cb_matmul_partials_config.set_globally_allocated_address(*output.buffer()); - } - cb_output = tt_metal::CreateCircularBuffer(program, cores, cb_matmul_partials_config); - } - - if (with_bias) { - uint32_t bias_tile_size = tt_metal::detail::TileSize(bias_df); - // bias input - uint32_t bias_pagesize = bias_tile_size; - CircularBufferConfig cb_bias_config = CircularBufferConfig(bias_ntiles * bias_pagesize, {{bias_cb, bias_df}}) - .set_page_size(bias_cb, bias_pagesize); - auto cb_bias = tt_metal::CreateCircularBuffer(program, core, cb_bias_config); - - log_debug("BIAS CBs: {} {} {}", bias_cb, bias_ntiles, 
bias_pagesize); - } - - return {cb_sharded_act, cb_output}; -} - -operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_(const Tensor& a, const Tensor &b, const Shape& ashape, std::optional bias, vector conv_params, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, const MathFidelity math_fidelity, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, Tensor &output) { - bool pass = true; - tt_metal::Device *device = a.device(); - TT_ASSERT(a.get_layout() == Layout::ROW_MAJOR, "Conv activation should be in row major layout"); - TT_ASSERT(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. Incorrect weight tensor."); - uint32_t act_block_h_ntiles = block_config.act_block_h_ntiles; - uint32_t act_block_w_ntiles = block_config.act_block_w_ntiles; - uint32_t weight_block_w_ntiles = parallelization_config.per_core_out_matrix_width_ntiles; - uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles; - uint32_t out_subblock_h_ntiles = block_config.out_subblock_h_ntiles; - uint32_t out_subblock_w_ntiles = block_config.out_subblock_w_ntiles; - //assert(out_block_h_ntiles == act_block_h_ntiles); // TODO: fix output block sizing - TT_ASSERT(out_block_h_ntiles >= act_block_h_ntiles, "Output block height (in # of tiles) should be greater than or equal to activation block height (in # of tiles)"); - - // Partitions conv inner dim into blocks to support sharding along this dim - // TODO: Only 2D convs with sharded input use this, but we can uplift to support generically - // TODO: Only updated variables which is affected, but there may be more that needs to account for this - // TODO: Loop naming in reader, writer, and compute kernels could also be cleaned up - // TODO: Can conv_act_c_blocks be same as num_blocks_act_w? 
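For orientation, a minimal standalone sketch of the conv-to-matmul shape mapping that the blocking parameters above assume (the real helper is compute_opt_conv_activation_as_mm_shape, which additionally applies tile padding and the extra 32B-alignment padding; every name below is local to the sketch, not part of the deleted file):

#include <cstdint>
#include <utility>

// One matmul row per output pixel (n, oh, ow); one column per (channel, kh, kw) filter tap.
static std::pair<uint32_t, uint32_t> conv_as_mm_shape(
    uint32_t n, uint32_t h, uint32_t w, uint32_t c,
    uint32_t kh, uint32_t kw,
    uint32_t stride_h, uint32_t stride_w,
    uint32_t pad_h, uint32_t pad_w) {
    uint32_t out_h = (h - kh + 2 * pad_h) / stride_h + 1;
    uint32_t out_w = (w - kw + 2 * pad_w) / stride_w + 1;
    uint32_t mm_height = n * out_h * out_w;  // act_matrix_height before tile padding
    uint32_t mm_width  = c * kh * kw;        // act_matrix_width before tile padding
    return {mm_height, mm_width};
}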
- - uint32_t conv_act_size_h = ashape[1]; - uint32_t conv_act_size_w = ashape[2]; - uint32_t conv_act_size_c = ashape[3]; - uint32_t weight_size_h = (uint32_t) conv_params[0]; - uint32_t weight_size_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - - bool rn50_first_conv = (conv_act_size_h == 230 && conv_act_size_w == (231 + extra_padding_for_32B_alignment) && - weight_size_h == 7 && weight_size_w == 8 && - stride_h == 2 && stride_w == 2); - // Compute the 2d matrix shape - auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(ashape.value, conv_params, out_block_h_ntiles, extra_padding_for_32B_alignment); - assert(act_matrix_shape.size() == 3); - assert(act_matrix_shape[0] == 1); - uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1]; - uint32_t act_matrix_width = (uint32_t) act_matrix_shape[2]; - uint32_t act_matrix_height_unpadded = (uint32_t) act_matrix_shape_unpadded[1]; - uint32_t act_matrix_width_unpadded = (uint32_t) act_matrix_shape_unpadded[2]; - - // Tensor b has weights and it should be tiled layout after converting conv weights into weight matrix - TT_ASSERT(b.get_layout() == Layout::TILE, "Conv weights should be in tiled layout"); - TT_ASSERT(b.get_legacy_shape()[0] == 1, "Conv weight matrix shape is invalid"); - TT_ASSERT(b.get_legacy_shape()[1] == 1, "Conv weight matrix shape is invalid"); - uint32_t weight_matrix_height = b.get_legacy_shape()[2]; - uint32_t weight_matrix_width = b.get_legacy_shape()[3]; - - if (has_bias) { - // Tensor bias is of shape {output_channels} - TT_ASSERT(bias.has_value()); - TT_ASSERT(bias.value().buffer() != nullptr); - auto bias_shape_without_padding = bias.value().get_legacy_shape().without_padding(); - TT_ASSERT(bias_shape_without_padding[0] == 1, "Bias should have batch == 1"); - // TT_ASSERT(bias_shape_without_padding[1] == 1 && bias_shape_without_padding[2] == 1, "Bias should have H == W == 1"); - TT_ASSERT(bias_shape_without_padding[3] == output_channels, "Bias should have output_channels"); - } - - // Normal matrix shape check - TT_ASSERT(act_matrix_width == weight_matrix_height, "The width of tensor a needs to match the height of tensor b"); - - // Tile size divisibility checks - TT_ASSERT(act_matrix_height % TILE_HEIGHT == 0, "Height of activation matrix needs to be divisible by 32"); - TT_ASSERT(act_matrix_width % TILE_WIDTH == 0, "Width of activation matrix needs to be divisible by 32"); - TT_ASSERT(weight_matrix_height % TILE_HEIGHT == 0, "Height of weight matrix needs to be divisible by 32"); - TT_ASSERT(weight_matrix_width % TILE_WIDTH == 0, "Width of weight matrix needs to be divisible by 32"); - - // Device compatibility checks - TT_ASSERT(a.storage_type() == StorageType::DEVICE && - b.storage_type() == StorageType::DEVICE && - "Operands to large matmul need to be on device!"); - TT_ASSERT(a.device() == b.device(), "Operands to conv need to be on the same device!"); - TT_ASSERT(a.buffer() != nullptr && b.buffer() != nullptr, "Operands to conv need to be allocated in buffers on device!"); - if (has_bias) { - TT_ASSERT(bias.value().storage_type() == StorageType::DEVICE, "Bias should be on device"); - TT_ASSERT(bias.value().device() == a.device(), "Bias should be on the same device as act tensor"); - } - - // Convert tensor dims to tile dims - uint32_t act_matrix_height_ntiles = 
act_matrix_height / TILE_HEIGHT; - uint32_t act_matrix_width_ntiles = act_matrix_width / TILE_WIDTH; - uint32_t weight_matrix_height_ntiles = weight_matrix_height / TILE_HEIGHT; - uint32_t weight_matrix_width_ntiles = weight_matrix_width / TILE_WIDTH; - - uint32_t conv_act_c_blocks = weight_matrix_width_ntiles / parallelization_config.per_core_out_matrix_width_ntiles; - - assert(act_matrix_height_ntiles % act_block_h_ntiles == 0); - assert(act_matrix_width_ntiles % act_block_w_ntiles == 0); - assert(weight_matrix_width_ntiles % weight_block_w_ntiles == 0); - assert(act_matrix_height_ntiles % out_block_h_ntiles == 0); - - uint32_t num_blocks_act_h = act_matrix_height_ntiles / act_block_h_ntiles; - uint32_t num_blocks_out_h = act_matrix_height_ntiles / out_block_h_ntiles; - uint32_t num_blocks_act_w = act_matrix_width_ntiles / act_block_w_ntiles; - uint32_t num_blocks_weight_w = weight_matrix_width_ntiles / weight_block_w_ntiles; - - if (rn50_first_conv) { - assert(num_blocks_weight_w == 1); - } - - // act block info - uint32_t act_block_w_datums = act_matrix_width / num_blocks_act_w; - uint32_t act_block_h_datums = act_matrix_height / num_blocks_act_h; - TT_ASSERT((act_block_w_datums == conv_act_size_c * weight_size_w) || ((act_block_w_datums <= conv_act_size_c) && (conv_act_size_c % act_block_w_datums == 0))); - - - // weight block info - uint32_t weight_block_w_datums = weight_matrix_width / num_blocks_weight_w; - assert(weight_block_w_ntiles % out_subblock_w_ntiles == 0); - uint32_t weight_num_subblocks = weight_block_w_ntiles / out_subblock_w_ntiles; - uint32_t weight_block_h_ntiles = act_block_w_ntiles; - uint32_t weight_block_num_tiles = weight_block_w_ntiles * weight_block_h_ntiles; - - uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w; - // writer of conv op partially removes padding on the width - // it removes the padding done for block width but it doesn't remove padding done for tiled width - uint32_t output_channels_padded_to_tile_width = round_up(output_channels, TILE_WIDTH); - assert(output_channels_padded_to_tile_width <= weight_matrix_width); - uint32_t output_width_num_tiles = output_channels_padded_to_tile_width / TILE_WIDTH; - uint32_t num_blocks_output_w = (uint32_t) std::ceil((double) output_channels_padded_to_tile_width / (double) weight_block_w_datums); - uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums == 0) ? weight_block_w_datums : (output_channels_padded_to_tile_width % weight_block_w_datums); - assert(last_block_width_datums % TILE_WIDTH == 0); - - // sanity check - assert(num_blocks_output_w == num_blocks_weight_w); - - uint32_t out_block_h_datums = out_block_h_ntiles * TILE_HEIGHT; - - tt_metal::Program program = tt_metal::CreateProgram(); - //CoreCoord core_coord = {0, 0}; // TODO: avoid another var here. Find a way to use core range instead. - //CoreRange core = {{0, 0}, {0, 0}}; - - tt::DataFormat act_df = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::DataFormat weight_df = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); - tt::DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - tt::DataFormat bias_df = has_bias ? 
tt_metal::datatype_to_dataformat_converter(bias.value().get_dtype()) : tt::DataFormat::Float16_b; - tt::DataFormat tilized_act_df = out_df; - - tt_metal::Buffer *src0_dram_buffer = a.buffer(); - tt_metal::Buffer *src1_dram_buffer = b.buffer(); - - tt_metal::Buffer *dst_dram_buffer = output.buffer(); - TT_ASSERT(dst_dram_buffer != nullptr, "Output buffer should be allocated on device!"); - - // out - uint32_t out_dram_addr = dst_dram_buffer->address(); - uint32_t out_subblock_num_tiles = out_subblock_h_ntiles * out_subblock_w_ntiles; - TT_ASSERT(out_subblock_num_tiles <= 8, "Need to ensure that matmul partials fit in dst"); - - // act - uint32_t act_dram_addr = src0_dram_buffer->address(); - auto act_dram_noc_xy = src0_dram_buffer->noc_coordinates(); - uint32_t act_noc_x = act_dram_noc_xy.x; - uint32_t act_noc_y = act_dram_noc_xy.y; - - assert(act_matrix_width_ntiles % act_block_w_ntiles == 0); - assert(act_block_h_ntiles % out_subblock_h_ntiles == 0); - assert(out_block_h_ntiles % out_subblock_h_ntiles == 0); - uint32_t act_num_subblocks = act_block_h_ntiles / out_subblock_h_ntiles; - uint32_t act_block_num_tiles = act_block_h_ntiles * act_block_w_ntiles; - uint32_t act_subblock_h_ntiles = out_subblock_h_ntiles; - uint32_t act_subblock_num_tiles = act_subblock_h_ntiles * act_block_w_ntiles; - - // weight - uint32_t weight_dram_addr = src1_dram_buffer->address(); - auto weight_dram_noc_xy = src1_dram_buffer->noc_coordinates(); - uint32_t weight_noc_x = weight_dram_noc_xy.x; - uint32_t weight_noc_y = weight_dram_noc_xy.y; - - // bias - tt_metal::Buffer *bias_buffer = nullptr; - uint32_t bias_dram_addr = 0; - uint32_t bias_ntiles = 0; - if (has_bias) { - bias_buffer = bias.value().buffer(); - bias_dram_addr = bias_buffer->address(); - bias_ntiles = bias.value().get_legacy_shape()[3] / constants::TILE_WIDTH; // TODO: support non tile multiple sizes - } - - //uint32_t conv_output_size_h = ((conv_act_size_h - weight_size_h + (2 * pad_h)) / stride_h) + 1; - //uint32_t conv_output_size_w = ((conv_act_size_w - weight_size_w + (2 * pad_w)) / stride_w) + 1; - - auto [conv_output_size_h, conv_output_size_w] = optimized_conv_op_utils::compute_opt_conv_output_face_shape(conv_act_size_h, conv_act_size_w, weight_size_h, weight_size_w, stride_h, stride_w, pad_h, pad_w, extra_padding_for_32B_alignment); - - std::map reader_defines; - - if (act_matrix_height_unpadded < act_matrix_height) { - reader_defines["ACT_BLOCK_HEIGHT_PADDING"] = "1"; - } - - if (conv_act_c_blocks > 1) { - reader_defines["ACT_W_OUTER_BLOCKS"] = "1"; - } - - uint32_t output_height_padded_to_tile_height = round_up(act_matrix_height_unpadded, TILE_HEIGHT); - uint32_t output_height_num_tiles = output_height_padded_to_tile_height / TILE_HEIGHT; - assert(output_height_num_tiles <= act_matrix_height_ntiles); - - uint32_t src_dram_act_buffer_size_bytes = src0_dram_buffer->size(); - uint32_t src_dram_weight_buffer_size_bytes = src1_dram_buffer->size(); - uint32_t dst_l1_act_buffer_size_bytes = out_block_h_ntiles * act_block_w_ntiles * tt::tt_metal::detail::TileSize(act_df); - uint32_t dst_l1_weight_buffer_size_bytes = weight_block_h_ntiles * weight_block_w_ntiles * tt::tt_metal::detail::TileSize(weight_df); - - - // For debug - { - log_debug(tt::LogOp, "conv_act_size_c: {}", conv_act_size_c); - log_debug(tt::LogOp, "conv_act_size_h: {}", conv_act_size_h); - log_debug(tt::LogOp, "conv_act_size_w: {}", conv_act_size_w); - log_debug(tt::LogOp, "act_matrix_height: {}", act_matrix_height); - log_debug(tt::LogOp, "act_matrix_width: {}", 
act_matrix_width); - log_debug(tt::LogOp, "act_matrix_height_unpadded: {}", act_matrix_height_unpadded); - log_debug(tt::LogOp, "act_matrix_width_unpadded: {}", act_matrix_width_unpadded); - log_debug(tt::LogOp, "act_matrix_height_ntiles: {}", act_matrix_height_ntiles); - log_debug(tt::LogOp, "act_matrix_width_ntiles: {}", act_matrix_width_ntiles); - log_debug(tt::LogOp, "weight_matrix_width_ntiles: {}", weight_matrix_width_ntiles); - log_debug(tt::LogOp, "num_blocks_act_h: {}", num_blocks_act_h); - log_debug(tt::LogOp, "num_blocks_act_w: {}", num_blocks_act_w); - log_debug(tt::LogOp, "num_blocks_weight_w: {}", num_blocks_weight_w); - log_debug(tt::LogOp, "num_blocks_out_h: {}", num_blocks_out_h); - log_debug(tt::LogOp, "act_dram_addr: {}", act_dram_addr); - log_debug(tt::LogOp, "act_block_h_ntiles: {}", act_block_h_ntiles); - log_debug(tt::LogOp, "act_block_h_datums: {}", act_block_h_datums); - log_debug(tt::LogOp, "act_block_w_ntiles: {}", act_block_w_ntiles); - log_debug(tt::LogOp, "act_block_w_datums: {}", act_block_w_datums); - log_debug(tt::LogOp, "out_block_h_ntiles: {}", out_block_h_ntiles); - log_debug(tt::LogOp, "act_num_subblocks: {}", act_num_subblocks); - log_debug(tt::LogOp, "act_block_num_tiles: {}", act_block_num_tiles); - log_debug(tt::LogOp, "act_subblock_h_ntiles: {}", act_subblock_h_ntiles); - log_debug(tt::LogOp, "act_subblock_num_tiles: {}", act_subblock_num_tiles); - log_debug(tt::LogOp, "out_subblock_num_tiles: {}", out_subblock_num_tiles); - log_debug(tt::LogOp, "weight_dram_addr: {}", weight_dram_addr); - log_debug(tt::LogOp, "weight_num_subblocks: {}", weight_num_subblocks); - log_debug(tt::LogOp, "weight_block_num_tiles: {}", weight_block_num_tiles); - log_debug(tt::LogOp, "weight_block_w_ntiles: {}", weight_block_w_ntiles); - log_debug(tt::LogOp, "weight_block_h_ntiles: {}", weight_block_h_ntiles); - log_debug(tt::LogOp, "has_bias: {}", has_bias); - log_debug(tt::LogOp, "bias_dram_addr: {}", bias_dram_addr); - log_debug(tt::LogOp, "bias_ntiles: {}", bias_ntiles); - log_debug(tt::LogOp, "out_dram_addr: {}", out_dram_addr); - log_debug(tt::LogOp, "out_subblock_h_ntiles: {}", out_subblock_h_ntiles); - log_debug(tt::LogOp, "out_subblock_w_ntiles: {}", out_subblock_w_ntiles); - log_debug(tt::LogOp, "out_subblock_num_tiles: {}", out_subblock_num_tiles); - log_debug(tt::LogOp, "num_groups: {}", num_groups); - } - // parallelization config - const auto& p_config = parallelization_config; - uint32_t num_cores_x = p_config.grid_size.x; - uint32_t num_cores_y = p_config.grid_size.y; - uint32_t total_num_cores = num_cores_x * num_cores_y; - assert(num_cores_x < 13); - assert(num_cores_y < 10); - uint32_t per_core_out_matrix_height_ntiles = p_config.per_core_out_matrix_height_ntiles; - uint32_t per_core_out_matrix_width_ntiles = p_config.per_core_out_matrix_width_ntiles; - //cout << "per_core_weight_matrix_width_ntiles=" << per_core_weight_matrix_width_ntiles << endl; - // cout << "total_num_cores=" << total_num_cores << endl; - // cout << "per_core_out_matrix_height_ntiles=" << per_core_out_matrix_height_ntiles << endl; - // cout << "act_matrix_height_ntiles=" << act_matrix_height_ntiles << endl; - // cout << "act_block_h_datums=" << act_block_h_datums << endl; - // cout << "num_blocks_act_h=" << num_blocks_act_h << endl; - - // weight_width_sliced determines is 1d-sysarr-conv or 2d-sysarr-conv - bool weight_width_sliced = per_core_out_matrix_width_ntiles < weight_matrix_width_ntiles; - uint32_t window_outer; - uint32_t window_inner; - if (weight_width_sliced) { - 
window_outer = 1; // window_outer = 1 becasue all of filter window is processed in the inner loop - window_inner = 3; // window_inner = 9 / 3, ie. read 3 width coalesced - } else { - window_outer = num_blocks_act_w; // window_outer - window_inner = weight_size_h * weight_size_w / num_blocks_act_w; // window_inner - } - reader_defines["WINDOW_INNER"] = std::to_string(window_inner); - log_debug("window_outer: {}, window_inner: {}", window_outer, window_inner); - - assert(weight_matrix_width_ntiles % per_core_out_matrix_width_ntiles == 0); - assert(per_core_out_matrix_width_ntiles % weight_block_w_ntiles == 0); - uint32_t num_blocks_weight_w_per_core = per_core_out_matrix_width_ntiles / weight_block_w_ntiles; - if (not weight_width_sliced) { - assert(num_blocks_weight_w_per_core == num_blocks_weight_w); - } - uint32_t num_weight_slices_width = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; - assert(num_cores_y % num_weight_slices_width == 0); - uint32_t num_cores_y_per_weight_slice_width = num_cores_y / num_weight_slices_width; - uint32_t total_num_cores_per_weight_slice = num_cores_y_per_weight_slice_width * num_cores_x; - if (weight_width_sliced) { - assert(total_num_cores_per_weight_slice * per_core_out_matrix_height_ntiles == act_matrix_height_ntiles); - } - else { - assert(total_num_cores * per_core_out_matrix_height_ntiles >= act_matrix_height_ntiles); - } - assert(per_core_out_matrix_height_ntiles % act_block_h_ntiles == 0); - uint32_t num_blocks_act_h_per_core = per_core_out_matrix_height_ntiles / act_block_h_ntiles; - assert(per_core_out_matrix_height_ntiles % out_block_h_ntiles == 0); - uint32_t num_blocks_out_h_per_core = per_core_out_matrix_height_ntiles / out_block_h_ntiles; - bool act_height_sliced = per_core_out_matrix_height_ntiles < act_matrix_height_ntiles; - if (not act_height_sliced) { - assert(num_blocks_act_h_per_core == num_blocks_act_h); - assert(num_blocks_out_h_per_core == num_blocks_out_h); - assert(num_cores_x == 1); - } - // cout << "num_blocks_act_h_per_core=" << num_blocks_act_h_per_core << endl; - assert(act_matrix_height_ntiles % per_core_out_matrix_height_ntiles == 0); - uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; - assert(total_active_num_cores_per_weight_slice <= total_num_cores_per_weight_slice); - uint32_t total_noop_cores = total_num_cores_per_weight_slice - total_active_num_cores_per_weight_slice; - uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; - if (weight_width_sliced) { - assert(total_noop_cores == 0); - assert(total_active_num_cores == total_num_cores); - } - // cout << "act_matrix_height_ntiles=" << act_matrix_height_ntiles << endl; - // cout << "per_core_out_matrix_height_ntiles=" << per_core_out_matrix_height_ntiles << endl; - // cout << "total_active_num_cores_per_weight_slice="<< total_active_num_cores_per_weight_slice << endl; - // cout << "num weight slices = " << num_weight_slices_width << endl; - // cout << "total num active cores" << total_active_num_cores << endl; - if (has_bias) { - assert(bias_ntiles % num_weight_slices_width == 0); - assert(bias_ntiles == weight_matrix_width_ntiles); - } - uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width; - - bool act_block_w_equals_input_channels_x_filter_width = (act_block_w_datums == (conv_act_size_c * weight_size_w)); - if (rn50_first_conv) { - assert(not weight_width_sliced); // weight width slicing not supported for rn50 first conv - 
assert(act_block_w_equals_input_channels_x_filter_width); - } - - vector debug_cores; - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { - uint32_t core_x_i = core_i % num_cores_x; - uint32_t core_y_i = core_i / num_cores_x; - debug_cores.push_back({core_x_i+1, core_y_i+1}); - } - - CoreRange all_cores(CoreCoord(0, 0), CoreCoord(num_cores_x - 1, num_cores_y - 1)); - assert(total_active_num_cores >= num_cores_x); - uint32_t num_active_cores_x = num_cores_x; - uint32_t num_active_cores_y_with_full_x = total_active_num_cores / num_cores_x; - uint32_t num_active_cores_x_last_y = total_active_num_cores % num_cores_x; - assert((num_active_cores_x * num_active_cores_y_with_full_x) + num_active_cores_x_last_y == total_active_num_cores); - - // cout << "All active cores. Core Ranges:" << endl; - // cout << "Core range 1 - (0,0) to (" << num_active_cores_x - 1 << "," << num_active_cores_y_with_full_x - 1 << ")" << endl; - - std::set all_active_cores_set; - all_active_cores_set.insert(CoreRange(CoreCoord(0, 0), CoreCoord(num_active_cores_x - 1, num_active_cores_y_with_full_x - 1))); - if (num_active_cores_x_last_y > 0) { - all_active_cores_set.insert(CoreRange(CoreCoord(0, num_active_cores_y_with_full_x), CoreCoord(num_active_cores_x_last_y - 1, num_active_cores_y_with_full_x))); - // cout << "Core range 2 - (0," << num_active_cores_y_with_full_x << ") to (" << num_active_cores_x_last_y - 1 << "," << num_active_cores_y_with_full_x << ")" << endl; - } - CoreRangeSet all_active_cores(all_active_cores_set); - std::set noop_cores_set; - if (total_noop_cores > 0) { - assert(total_noop_cores == (num_cores_x - num_active_cores_x_last_y)); - noop_cores_set.insert(CoreRange(CoreCoord(num_active_cores_x_last_y, num_active_cores_y_with_full_x), CoreCoord(num_cores_x - 1, num_active_cores_y_with_full_x))); - // cout << "Noop core range - (" << num_active_cores_x_last_y << "," << num_active_cores_y_with_full_x << ") to (" << num_cores_x - 1 << "," << num_active_cores_y_with_full_x << ")" << endl; - - } - CoreRangeSet noop_cores(noop_cores_set); - - // Mcast cores - // If total_num_cores, there is no mcasting - CoreCoord top_left_core = {(std::size_t) 0, (std::size_t) 0}; - CoreCoord top_left_core_plus_one = {(std::size_t) 1, (std::size_t) 1}; - CoreCoord bottom_right_core = {(std::size_t) num_cores_x - 1, (std::size_t) num_cores_y - 1}; - auto top_left_core_physical = device->worker_core_from_logical_core(top_left_core); - auto top_left_core_plus_one_physical = device->worker_core_from_logical_core(top_left_core_plus_one); - auto bottom_right_core_physical = device->worker_core_from_logical_core(bottom_right_core); - - CoreRange mcast_sender_cores(top_left_core, top_left_core); // If single core, this kernel doesn't do mcasting - CoreRangeSet mcast_receiver_cores{{}}; - uint32_t weights_mcast_sender_semaphore_id{}; - uint32_t weights_mcast_receiver_semaphore_id{}; - uint32_t act_mcast_sender_semaphore_id = 0; - uint32_t act_mcast_receiver_semaphore_id = 0; - std::vector act_mcast_noc_y; - // 2D mcast - if (weight_width_sliced) { - mcast_sender_cores = CoreRange(top_left_core, CoreCoord(0, num_cores_y - 1)); - mcast_receiver_cores = {{CoreRange(CoreCoord(1, 0), bottom_right_core)}}; - weights_mcast_sender_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - weights_mcast_receiver_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - // 1D mcast - } else { - if (total_num_cores > 1) { - std::set mcast_receiver_set; - if (num_cores_x > 1) { - 
mcast_receiver_set.insert(CoreRange(CoreCoord(1, 0), CoreCoord(num_cores_x - 1, 0))); - } if (num_cores_y > 1) { - mcast_receiver_set.insert(CoreRange(CoreCoord(0, 1), bottom_right_core)); - } - mcast_receiver_cores = mcast_receiver_set; - weights_mcast_sender_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - weights_mcast_receiver_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - } - } - - bool read_3x3_window_in_inner_loop = false; - uint32_t num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles / conv_act_c_blocks; - bool fully_buffer_weights = false; - uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks; - // TODO: This flag should be set in kernel logic but need this for create_CB - if (a.memory_config().is_sharded() && weight_size_h == 3 && weight_size_w == 3 && stride_h == 1 && weight_width_sliced) { - // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire 3x3 window before pushing in reader/writer - // TODO: Generalize this to not make this assumption - read_3x3_window_in_inner_loop = true; - num_weight_cb_tiles *= weight_size_h * weight_size_w; - num_act_cb_tiles *= weight_size_h * weight_size_w; - } else if (num_blocks_act_h_per_core > 1) { - fully_buffer_weights = true; - } - uint32_t num_cb0_tilized_tiles = num_act_cb_tiles; - - if (fully_buffer_weights) { - num_weight_cb_tiles *= window_outer; - } else if (per_core_out_matrix_width_ntiles < 8) { - num_weight_cb_tiles = num_weight_cb_tiles * 2; - } - if (rn50_first_conv) { - num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles * num_blocks_weight_w * num_blocks_act_w; - } - // std::cout << "num_act_cb_tiles = " << num_act_cb_tiles << std::endl; - if (conv_act_size_c / conv_act_c_blocks < 256) { - num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered - // std::cout << "num_act_cb_tiles (post DB) = " << num_act_cb_tiles << std::endl; - } - uint32_t writer_output_block_num_tiles = out_block_h_ntiles * weight_block_w_ntiles; - - // if (!(conv_output_size_w == 14 || conv_output_size_w == 7)) { - // writer_output_block_num_tiles = writer_output_block_num_tiles * 2; - // } - - // TODO: Moving this function call to after kernel logic causes pcc fails - // There are additional CBs and semaphores created in 2D conv in kernel logic, - // so does order of create_cb calls matter? 
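For orientation, a condensed standalone sketch of the activation CB tile-count arithmetic above (no tt_metal calls; the multipliers mirror the full-filter-window buffering and the double buffering applied when the per-block channel depth is below 256; names are local to the sketch):

#include <cstdint>

static uint32_t act_cb_num_tiles(
    uint32_t act_block_h_ntiles, uint32_t act_block_w_ntiles,
    uint32_t conv_act_c_blocks, uint32_t conv_act_size_c,
    uint32_t weight_size_h, uint32_t weight_size_w,
    bool read_window_in_inner_loop) {
    uint32_t n = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks;
    if (read_window_in_inner_loop) {
        n *= weight_size_h * weight_size_w;   // buffer the whole filter window (e.g. 3x3)
    }
    if (conv_act_size_c / conv_act_c_blocks < 256) {
        n *= 2;                               // double buffer for small per-block channel depth
    }
    return n;
}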
- auto [cb_sharded_act, cb_output] = create_CBs_for_sharded_input( - program, - a, - all_cores, - num_act_cb_tiles, // row major act cb - num_weight_cb_tiles, // tiled weight cb - num_cb0_tilized_tiles, // tiled act cb - writer_output_block_num_tiles, // math output cb - weight_block_w_ntiles, // reblock cb - writer_output_block_num_tiles, // writer output cb, double bufferred - untilize_out, - act_df, - weight_df, - tilized_act_df, - out_df, - bias_df, - weight_width_sliced, - output, - bias_ntiles_per_core, - has_bias); - - string reader_kernel; - string compute_kernel; - string writer_mcast_sender_kernel; - string writer_mcast_receiver_kernel; - bool tilize_in0 = true; - bool reader_with_indices = false; - if (rn50_first_conv) { - // TODO: Add support for sharded rn50_first_conv - TT_ASSERT(false, "Sharded input is not supported for resnet-50 first conv yet!"); - } else { - compute_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp"; - // Input should always be sharded in this conv; always use reader kernel for input shard with halo and padding - if (weight_size_h == 3 && weight_size_w == 3 && stride_h == 1) { - reader_with_indices = true; - // 2D conv - if (weight_width_sliced) { - assert(read_3x3_window_in_inner_loop == true); - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights.cpp"; - writer_mcast_sender_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; - writer_mcast_receiver_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; - - act_mcast_sender_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - act_mcast_receiver_semaphore_id = tt_metal::CreateSemaphore(program, all_cores, INVALID); - - act_mcast_noc_y.reserve(num_cores_y); - for(uint32_t core_idx_y = 0; core_idx_y < num_cores_y; ++core_idx_y) { - act_mcast_noc_y.push_back(device->worker_core_from_logical_core({0, core_idx_y}).y); - } - - // For 2D convs, pre-tilize input and round robin self-mcast tilized act matrix to other cores - tilize_in0 = false; - } - // 1D conv - else { - reader_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights.cpp"; - writer_mcast_sender_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; - writer_mcast_receiver_kernel = "ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; - } - - // Local L1 to store array for reader indices - // TODO: once 1D-sys-conv is uint16_t indicies (2D-sys-conv already is), then each entry can be 2B (not 4) - CircularBufferConfig cb_for_reader_indices_config = CircularBufferConfig(out_block_h_datums * 4, {{cb_for_reader_indices, tt::DataFormat::Float16_b}}) - .set_page_size(cb_for_reader_indices, 4); - auto cb_for_reader_indices_id = tt_metal::CreateCircularBuffer(program, all_cores, cb_for_reader_indices_config); - - // Local L1 to store array for reader offsets - // TODO: this is not used in 2D-sys-conv, remove also from 1D-sys-conv - CircularBufferConfig cb_for_reader_offsets_config = CircularBufferConfig(weight_size_h * weight_size_w * 4, {{cb_for_reader_offsets, tt::DataFormat::Float16_b}}) - 
.set_page_size(cb_for_reader_offsets, 4); - auto cb_for_reader_offsets_id = tt_metal::CreateCircularBuffer(program, all_cores, cb_for_reader_offsets_config); - } else { - TT_ASSERT(false, "Sharded input not supported for this conv yet!"); - } - } - TT_ASSERT(!(conv_act_size_c & (conv_act_size_c - 1))); // channel depth power of 2 is supported only - - std::vector reader_rt_args; - std::vector reader_compile_time_args; - std::vector writer_rt_args; - std::vector writer_compile_time_args; - - - uint32_t conv_act_c_read_bytes = conv_act_size_c * a.element_size() / conv_act_c_blocks; - // For new reader_with_indices, this is used to calculate offset so use actual read_bytes along c - // For old readers, this is used for bank page size for interleaved; offset is from conv_act_c_read_bytes - uint32_t log_base_2_of_conv_act_size_c_bytes = reader_with_indices ? std::log2(conv_act_c_read_bytes) : std::log2(conv_act_size_c * a.element_size()); - reader_compile_time_args = {(uint32_t) - (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), - (uint32_t) stride_h, - (uint32_t) stride_w, - (uint32_t) conv_act_size_w, - (uint32_t) conv_output_size_w, - (uint32_t) conv_act_c_read_bytes, - (uint32_t) log_base_2_of_conv_act_size_c_bytes, - - // unused, TODO: delete - (uint32_t) extra_padding_for_32B_alignment, - (uint32_t) (conv_act_size_c/act_block_w_datums), - (uint32_t) act_block_w_datums * a.element_size(), - - (uint32_t) window_outer, - (uint32_t) window_inner, - (uint32_t) act_block_h_datums}; - - // define for bias - std::map writer_defines; - std::map writer_mcast_sender_defines; - std::map compute_defines; - if (output.memory_config().is_sharded()) { - writer_defines["SHARDED_OUT"] = "1"; - writer_mcast_sender_defines["SHARDED_OUT"] = "1"; - } - if (total_num_cores == 1) { - writer_mcast_sender_defines["SKIP_MCAST"] = "1"; - } - if (has_bias) { - writer_defines["FUSE_BIAS"] = "1"; - writer_mcast_sender_defines["FUSE_BIAS"] = "1"; - compute_defines["FUSE_BIAS"] = "1"; - } - - if (fuse_relu) { - compute_defines["PACK_RELU"] = "1"; - } - - if (!tilize_in0) { - compute_defines["PRE_TILIZE"] = "1"; - } - - writer_compile_time_args = { - (uint32_t) (dst_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), - out0_cb, - weight_cb, - bias_cb, - (uint32_t) (bias_buffer == nullptr ? 0 : (bias_buffer->buffer_type() == BufferType::DRAM ? 
1 : 0))}; - - uint32_t in0_block_w = act_block_w_ntiles / conv_act_c_blocks; - uint32_t in0_block_num_tiles = act_block_num_tiles / conv_act_c_blocks; - uint32_t in0_subblock_num_tiles = act_subblock_num_tiles / conv_act_c_blocks; - uint32_t in1_block_num_tiles = weight_block_num_tiles / conv_act_c_blocks; - uint32_t in0_num_blocks_w = num_blocks_act_w * conv_act_c_blocks; // Fold outer c_block loop together with weight_block_num_tiles = 9 - if (read_3x3_window_in_inner_loop) { - const uint32_t window_size = weight_size_h * weight_size_w; - in0_block_w *= window_size; - in0_block_num_tiles *= window_size; - in0_subblock_num_tiles *= window_size; - in1_block_num_tiles *= window_size; - in0_num_blocks_w /= window_size; - } - - vector compute_kernel_args = { - in0_block_w, - act_num_subblocks, - in0_block_num_tiles, - in0_subblock_num_tiles, - act_subblock_h_ntiles, - - weight_num_subblocks, - in1_block_num_tiles, - weight_block_w_ntiles, - - num_blocks_act_h_per_core, - in0_num_blocks_w, - num_blocks_weight_w_per_core, - - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - - tilize_in0, - untilize_out, - - bias_ntiles_per_core - }; - - auto writer_mcast_noc = tt_metal::detail::GetPreferredNOCForDRAMWrite(device->arch()); - auto reader_noc = tt_metal::detail::GetPreferredNOCForDRAMRead(device->arch()); - auto writer_mcast_sender_id = CreateKernel( - program, - writer_mcast_sender_kernel, - mcast_sender_cores, - DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_0, - .noc = writer_mcast_noc, - .compile_args = writer_compile_time_args, - .defines = writer_mcast_sender_defines}); - - KernelHandle writer_mcast_receiver_id{}; - if (total_num_cores > 1) { - writer_mcast_receiver_id = CreateKernel( - program, - writer_mcast_receiver_kernel, - mcast_receiver_cores, - DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_0, - .noc = writer_mcast_noc, - .compile_args = writer_compile_time_args, - .defines = writer_defines}); - } - - auto reader_id = CreateKernel( - program, - reader_kernel, - all_cores, - DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_1, - .noc = reader_noc, - .compile_args = reader_compile_time_args, - .defines = reader_defines}); - - // Compile compute kernel for active cores only - // Compile blank kernel for noop cores - auto compute_id = CreateKernel( - program, - compute_kernel, - all_active_cores, - ComputeConfig{ - .math_fidelity = math_fidelity, - .compile_args = compute_kernel_args, - .defines = compute_defines}); - - if (total_noop_cores > 0) { - auto compute_id = CreateKernel( - program, - "tt_metal/kernels/compute/blank.cpp", - noop_cores, ComputeConfig{}); - } - - vector reader_ids; - vector writer_ids; - //tt_start_debug_print_server(); - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { - uint32_t core_x_i = core_i % num_cores_x; - uint32_t core_y_i = core_i / num_cores_x; - // cout << "core_x_i=" << core_x_i << ", core_y_i=" << core_y_i << endl; - CoreRange core(CoreCoord(core_x_i, core_y_i), CoreCoord(core_x_i, core_y_i)); - bool noop_core = false; - for (const auto & noop_core_range : noop_cores.ranges()) { - if (noop_core_range.contains(core)) { - // cout << "No op core" << endl; - // cout << "core_x_i=" << core_x_i << ", core_y_i=" << core_y_i << endl; - noop_core = true; - break; - } - } - // per core specific args - uint32_t act_slice_i = core_i % (num_cores_y_per_weight_slice_width * num_cores_x); - uint32_t weight_slice_i = core_i / (num_cores_y_per_weight_slice_width * num_cores_x); 
- uint32_t total_h_start = act_slice_i * per_core_out_matrix_height_ntiles * TILE_HEIGHT; - uint32_t n_start = total_h_start / (conv_output_size_h * conv_output_size_w); - uint32_t matrix_h_start = total_h_start % (conv_output_size_h * conv_output_size_w); - uint32_t out_h_start = matrix_h_start / conv_output_size_w; - uint32_t out_w_start = matrix_h_start % conv_output_size_w; - uint32_t in_h_start = (n_start * conv_act_size_h) + out_h_start * stride_h; - uint32_t last_start_in_h_curr_image = 222 + (n_start * conv_act_size_h); - uint32_t out_start_tile_id = (act_slice_i * per_core_out_matrix_height_ntiles * weight_matrix_width_ntiles) + (weight_slice_i * per_core_out_matrix_width_ntiles); - uint32_t out_start_tile_id_h = act_slice_i * per_core_out_matrix_height_ntiles; - uint32_t out_start_tile_id_w = weight_slice_i * per_core_out_matrix_width_ntiles; - uint32_t bias_tile_offset = weight_slice_i * per_core_out_matrix_width_ntiles; - if (has_bias) { - assert(bias_tile_offset < bias_ntiles); - } - // cout << "act_slice_i=" << act_slice_i << endl; - // cout << "weight_slice_i=" << weight_slice_i << endl; - // cout << "core_i=" << core_i << endl; - // cout << "num_blocks_act_h_per_core=" << num_blocks_act_h_per_core << endl; - // cout << "num_blocks_weight_w_per_core=" << num_blocks_weight_w_per_core << endl; - // cout << "bias_tile_offset=" << bias_tile_offset << endl; - // cout << "out_start_tile_id=" << out_start_tile_id << endl; - // cout << "out_start_tile_id_w=" << out_start_tile_id_w << endl; - // cout << "per_core_out_matrix_height_ntiles=" << per_core_out_matrix_height_ntiles << endl; - // cout << "weight_matrix_width_ntiles=" << weight_matrix_width_ntiles << endl; - // cout << "out_start_tile_id_h=" << out_start_tile_id_h << endl; - // cout << endl; - // cout << "total_h_start=" << total_h_start << endl; - // cout << "in_h_start=" << in_h_start << endl; - // cout << "out_h_start=" << out_h_start << endl; - // cout << "out_w_start=" << out_w_start << endl; - // cout << "matrix_h_start=" << matrix_h_start << endl; - // cout << "n_start=" << n_start << endl; - - if (rn50_first_conv) { - // TODO: Add support for sharded rn50_first_conv - TT_ASSERT(false, "Sharded input is not supported for resnet-50 first conv yet!"); - } else { - TT_ASSERT(reader_with_indices, "Input must be sharded for this conv!"); - /* Logic to compute: - * NOTE: This logic is wrong if stride !=1 - * first_partial_right_aligned_row_width - * skip_after_partial_right_aligned_row - * first_partial_image_num_rows - * skip_after_first_partial_image_row - * num_full_images - * skip_after_full_image - * last_partial_image_num_rows - * last_partial_left_aligned_row_width - */ - - // If 2D, same image specs across a row - uint32_t start_stick = weight_width_sliced ? 
core_x_i * out_block_h_datums : core_i * out_block_h_datums; - uint32_t end_stick = start_stick + out_block_h_datums; - - ShardingConfig sharding_config = get_specs_for_sharding_partition(start_stick, end_stick, conv_act_size_h, conv_act_size_w, weight_size_w, pad_h, pad_w); - uint32_t first_partial_right_aligned_row_width = sharding_config.first_partial_right_aligned_row_width; - uint32_t skip_after_partial_right_aligned_row = sharding_config.skip_after_partial_right_aligned_row; - uint32_t first_partial_image_num_rows = sharding_config.first_partial_image_num_rows; - uint32_t skip_after_first_partial_image_row = sharding_config.skip_after_first_partial_image_row; - uint32_t num_full_images = sharding_config.num_full_images; - uint32_t skip_after_full_image = sharding_config.skip_after_full_image; - uint32_t last_partial_image_num_rows = sharding_config.last_partial_image_num_rows; - uint32_t last_partial_left_aligned_row_width = sharding_config.last_partial_left_aligned_row_width; - - if (weight_width_sliced) { - auto shard_shape = a.shard_spec().value().shape; - uint32_t tilized_act_tile_size = tt_metal::detail::TileSize(tilized_act_df); - CoreCoord bottom_core = {(std::size_t) core_x_i, (std::size_t) num_cores_y - 1}; - auto bottom_core_physical = device->worker_core_from_logical_core(bottom_core); - - bool reader_is_noc_0 = reader_noc == NOC::NOC_0; - uint32_t act_mcast_dest_noc_start_x = bottom_core_physical.x; - uint32_t act_mcast_dest_noc_start_y = reader_is_noc_0 ? top_left_core_physical.y : bottom_core_physical.y; - uint32_t act_mcast_dest_noc_end_x = bottom_core_physical.x; - uint32_t act_mcast_dest_noc_end_y = reader_is_noc_0 ? bottom_core_physical.y : top_left_core_physical.y; - reader_rt_args = { - conv_act_size_w, - conv_act_size_h, - weight_size_h, - weight_size_w, - - act_block_h_datums, - in0_block_num_tiles, - conv_act_c_blocks, - - // Specs for reader indices - first_partial_right_aligned_row_width, - skip_after_partial_right_aligned_row, - first_partial_image_num_rows, - skip_after_first_partial_image_row, - num_full_images, - skip_after_full_image, - last_partial_image_num_rows, - last_partial_left_aligned_row_width, - - // Specs for reader offsets - window_outer, // window_outer - window_inner, // window_inner = 9 / 3, ie. 
read 3 width coalesced - - (uint32_t) noop_core, - - // mcast args - act_mcast_dest_noc_start_x, - act_mcast_dest_noc_start_y, - act_mcast_dest_noc_end_x, - act_mcast_dest_noc_end_y, - num_cores_y - 1, - num_cores_y - 1, - act_mcast_sender_semaphore_id, - act_mcast_receiver_semaphore_id, - in0_block_num_tiles * tilized_act_tile_size, // act_mcast_sender_size_bytes - core_y_i, // act_mcast_sender_id (goes down the column) - (uint32_t) bottom_core_physical.x, // act_mcast_sender_noc_x - }; - reader_rt_args.insert(reader_rt_args.end(), act_mcast_noc_y.begin(), act_mcast_noc_y.end()); // act_mcast_sender_noc_y - } else { - reader_rt_args = { - conv_act_size_w, - conv_act_size_h, - weight_size_h, - weight_size_w, - num_blocks_act_h_per_core, - // act_block_h_datums, - act_block_num_tiles / conv_act_c_blocks, - - // Specs for reader indices - first_partial_right_aligned_row_width, - skip_after_partial_right_aligned_row, - first_partial_image_num_rows, - skip_after_first_partial_image_row, - num_full_images, - skip_after_full_image, - last_partial_image_num_rows, - last_partial_left_aligned_row_width, - - // Specs for reader offsets - window_outer, // window_outer - window_inner, // window_inner - - (uint32_t) noop_core - }; - } - } - - SetRuntimeArgs( - program, reader_id, core, - reader_rt_args - ); - reader_ids.push_back(reader_id); - - writer_rt_args = { - out_dram_addr, - weight_dram_addr, - bias_dram_addr, - - output_width_num_tiles, // out_next_tile_stride_h - 1, // out_next_tile_stride_w - out_subblock_h_ntiles * output_width_num_tiles, // out_next_subblock_stride_h - out_subblock_w_ntiles, // out_next_subblock_stride_w - act_block_h_ntiles * output_width_num_tiles, // out_next_block_stride_h - weight_block_w_ntiles, // out_next_block_stride_w - out_subblock_h_ntiles, - out_subblock_w_ntiles, - out_subblock_num_tiles, - act_block_h_ntiles / out_subblock_h_ntiles, // out_num_subblocks_h - weight_block_w_ntiles / out_subblock_w_ntiles, // out_num_subblocks_w - num_blocks_act_h_per_core, // out_num_blocks_h - num_blocks_weight_w_per_core, // out_num_blocks_w - act_block_h_ntiles, // out_block_height_num_tiles - output_height_num_tiles, // out_height_num_tiles without block shape padding - output_width_num_tiles, // out_width_num_tiles without block shape padding - out_start_tile_id, - out_start_tile_id_h, - out_start_tile_id_w, - - num_blocks_act_w, // = number of blocks of weight in height dim - in1_block_num_tiles, - conv_act_c_blocks, - weight_block_h_ntiles / conv_act_c_blocks, - weight_block_w_ntiles, - weight_matrix_width_ntiles, // weight_stride_h - weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, - weight_block_w_ntiles, // weight_next_block_stride_w - - // bias - bias_ntiles_per_core, - bias_tile_offset, - - (uint32_t) noop_core - }; - - // Mcast sender - // 2D mcast - if (weight_width_sliced) { - CoreCoord right_core = {(std::size_t) num_cores_x - 1, (std::size_t) core_y_i}; - auto right_core_physical = device->worker_core_from_logical_core(right_core); - // sender - if (core_x_i == 0) { - if (writer_mcast_noc == NOC::NOC_0) { - writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y - } else { - writer_rt_args.push_back(bottom_right_core_physical.x); //
weights_mcast_dest_noc_start_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y - } - - writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_dests - writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_cores - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_sender_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_sender_id); - // receiver - } else { - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_sender_noc_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_sender_noc_y - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_receiver_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_receiver_id); - } - // 1D mcast - } else { - // sender - if (core_x_i == 0 and core_y_i == 0) { - if (writer_mcast_noc == NOC::NOC_0) { - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_end_y - } else { - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_dest_noc_end_y - } - writer_rt_args.push_back(total_active_num_cores - 1); // weights_mcast_num_dests - writer_rt_args.push_back(total_num_cores - 1); // weights_mcast_num_cores - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_sender_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_sender_id); - // receiver - } else { - writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_sender_noc_x - writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_sender_noc_y - writer_rt_args.push_back(weights_mcast_sender_semaphore_id); - writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); - - SetRuntimeArgs( - program, writer_mcast_receiver_id, core, - writer_rt_args - ); - writer_ids.push_back(writer_mcast_receiver_id); - } - } - - } // for num_cores - - auto override_runtime_arguments_callback = [ - reader_kernel_ids=reader_ids, - writer_kernel_ids=writer_ids, - cb_sharded_act=cb_sharded_act, - cb_output=cb_output, - total_num_cores=total_num_cores, - num_cores_x=num_cores_x, - num_cores_y=num_cores_y, - has_bias=has_bias - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors - ) { - - TT_ASSERT(input_tensors.size() + optional_input_tensors.size() == 3); - TT_ASSERT(output_tensors.size() == 1); - - auto src_buffer_a = input_tensors.at(0).buffer(); - auto src_buffer_b 
= input_tensors.at(1).buffer(); - auto src_a_is_sharded = input_tensors.at(0).memory_config().is_sharded(); - - auto dst_buffer = output_tensors.at(0).buffer(); - bool out_sharded = output_tensors.at(0).memory_config().is_sharded(); - - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { - uint32_t core_x_i = core_i % num_cores_x; - uint32_t core_y_i = core_i / num_cores_x; - CoreCoord core = {core_x_i, core_y_i}; - - if (!src_a_is_sharded) { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_ids[core_i], core); - runtime_args[0] = src_buffer_a->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_ids[core_i], core); - runtime_args[0] = dst_buffer->address(); - runtime_args[1] = src_buffer_b->address(); - if (has_bias) { - auto src_buffer_c = optional_input_tensors.at(0).value().buffer(); - TT_ASSERT(src_buffer_c != nullptr); - runtime_args[2] = src_buffer_c->address(); - } - } - } - - if (src_a_is_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_sharded_act, *src_buffer_a); - } - - if (out_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); - } - }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; -} - -} // namespace tt_metal - -} // namespace tt diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp index 65e65663bad..6c8785f2b2f 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp @@ -1674,79 +1674,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_( - const Tensor& a, - const Tensor& b, - const Shape& ashape, - std::optional bias, - const std::optional conv_reader_indices, - sliding_window::SlidingWindowConfig sliding_window_config, - uint32_t output_channels, - uint32_t groups, - bool untilize_out, - bool has_bias, - bool fuse_relu, - const OptimizedConvParallelizationConfig& parallelization_config, - const OptimizedConvBlockConfig& block_config, - uint32_t extra_padding_for_32B_alignment, - bool use_shallow_conv_variant, - bool transpose_mcast, - Tensor& output, - DeviceComputeKernelConfig compute_kernel_config, - bool enable_act_double_buffer, - bool enable_split_reader, - bool enable_subblock_padding) { - tt_metal::Program program = tt_metal::CreateProgram(); - if(a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { - return multi_core_optimized_conv_width_sharded_v2_impl( - program, - a, - b, - ashape, - bias, - conv_reader_indices, - sliding_window_config, - output_channels, - groups, - untilize_out, - has_bias, - fuse_relu, - parallelization_config, - block_config, - extra_padding_for_32B_alignment, - use_shallow_conv_variant, - transpose_mcast, - output, - compute_kernel_config, - enable_act_double_buffer, - enable_split_reader, - enable_subblock_padding); - } - return multi_core_optimized_conv_sharded_v2_impl( - program, - a, - b, - ashape, - bias, - conv_reader_indices, - sliding_window_config, - output_channels, - groups, - 
untilize_out, - has_bias, - fuse_relu, - parallelization_config, - block_config, - extra_padding_for_32B_alignment, - use_shallow_conv_variant, - transpose_mcast, - output, - compute_kernel_config, - enable_act_double_buffer, - enable_split_reader, - enable_subblock_padding); -} - operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( const Tensor& a, const Tensor& b, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp index 5b1b8608947..93b0e33fdb9 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op.hpp @@ -43,9 +43,6 @@ struct OptimizedConvBlockConfig { uint32_t out_subblock_w_ntiles; }; -operation::ProgramWithCallbacks multi_core_optimized_conv_(const Tensor& a, const Tensor &b, const Shape& ashape, std::optional bias, vector conv_params, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, const MathFidelity math_fidelity, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, Tensor &output); -operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_(const Tensor& a, const Tensor &b, const Shape& ashape, std::optional bias, vector conv_params, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, const MathFidelity math_fidelity, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, Tensor &output); -operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_(const Tensor& a, const Tensor &b, const Shape& ashape, std::optional bias, const std::optional conv_reader_indices, sliding_window::SlidingWindowConfig sliding_window_config, uint32_t output_channels, uint32_t groups, bool untilize_out, bool has_bias, bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, bool use_shallow_conv_variant, bool transpose_mcast, Tensor &output, DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding); operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const Tensor& a, const Tensor &b, std::optional bias, sliding_window::SlidingWindowConfig sliding_window_config, uint32_t output_channels, @@ -62,109 +59,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T bool enable_split_reader, bool enable_subblock_padding); -struct OptimizedConv { - OptimizedConvParallelizationConfig parallelization_config; - OptimizedConvBlockConfig block_config; - - const std::vector conv_params; - const uint32_t output_channels; - bool untilize_out, has_bias, fuse_relu; - MathFidelity math_fidelity; - uint32_t extra_padding_for_32B_alignment; - const MemoryConfig memory_config; - const DataType dtype; - Shape input_tensor_shape; // For sharded input, input tensor shape is nonsense - bool use_shallow_conv_variant; - bool transpose_mcast; // default for GS = true, WH = false - const DeviceComputeKernelConfig compute_kernel_config; - bool enable_act_double_buffer; - bool enable_split_reader; - bool enable_subblock_padding; - OptimizedConv(const std::vector&c_params, - uint32_t output_channels, bool untile_out, - bool 
has_bias, bool fuse_relu, - MathFidelity mfidelity, const OptimizedConvParallelizationConfig& p_config, - const OptimizedConvBlockConfig& b_config, - uint32_t e_padding_for_32B_alignment, - MemoryConfig memory_config, DataType dtype, const Shape& input_tensor_shape, bool use_shallow_conv_variant, bool transpose_mcast, const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding) : - output_channels(output_channels), - conv_params(c_params), - untilize_out(untile_out), - has_bias(has_bias), - fuse_relu(fuse_relu), - math_fidelity(mfidelity), - parallelization_config(p_config), - block_config(b_config), - extra_padding_for_32B_alignment(e_padding_for_32B_alignment), - memory_config(memory_config), dtype(dtype), input_tensor_shape(input_tensor_shape), - use_shallow_conv_variant(use_shallow_conv_variant), - transpose_mcast(transpose_mcast), - compute_kernel_config(compute_kernel_config), - enable_act_double_buffer(enable_act_double_buffer), - enable_split_reader(enable_split_reader), - enable_subblock_padding(enable_subblock_padding) {} - - void validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector &output_tensors) const; - - operation::OpPerformanceModel create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const; - - static constexpr auto attribute_names = std::make_tuple( - "parallelization_config", - "block_config", - "conv_params", - "output_channels", - "untilize_out", - "has_bias", - "fuse_relu", - "math_fidelity", - "extra_padding_for_32B_alignment", - "memory_config", - "dtype", - "input_tensor_shape", - "use_shallow_conv_variant", - "enable_act_double_buffer", - "enable_split_reader", - "enable_subblock_padding"); - const auto attribute_values() const { - return std::make_tuple( - std::cref(this->parallelization_config), - std::cref(this->block_config), - std::cref(this->conv_params), - std::cref(this->output_channels), - std::cref(this->untilize_out), - std::cref(this->has_bias), - std::cref(this->fuse_relu), - std::cref(this->math_fidelity), - std::cref(this->extra_padding_for_32B_alignment), - std::cref(this->memory_config), - std::cref(this->dtype), - std::cref(this->input_tensor_shape), - std::cref(this->use_shallow_conv_variant), - std::cref(this->enable_act_double_buffer), - std::cref(this->enable_split_reader), - std::cref(this->enable_subblock_padding)); - } -}; - -Tensor optimized_conv(const Tensor& a, const Tensor &b, std::optional bias, const std::optional conv_reader_indices, - const vector conv_params, uint32_t output_channels, - bool untilize_out, bool has_bias, bool fuse_relu, MathFidelity math_fidelity, - const OptimizedConvParallelizationConfig& parallelization_config, - const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment=0, - std::optional memory_config = std::nullopt, - std::optional dtype=std::nullopt, - std::optional> input_tensor_shape = std::nullopt, - bool use_shallow_conv_variant = false, - bool tranpose_mcast = true, - std::optional compute_kernel_config = std::nullopt, - bool enable_act_double_buffer = false, - bool 
enable_split_reader = false, - bool enable_subblock_padding = false -); - // new micro op struct OptimizedConvNew { OptimizedConvParallelizationConfig parallelization_config; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp index 13851fe7be4..b70378026eb 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp @@ -70,234 +70,6 @@ pair, vector> compute_opt_conv_activation_as_mm_shape namespace ttnn::operations::conv { namespace conv2d { -Tensor optimized_conv(const Tensor& a, - const Tensor &b, - std::optional bias, - const std::optional conv_reader_indices, - const vector conv_params, - uint32_t output_channels, - bool untilize_out, - bool has_bias, - bool fuse_relu, - MathFidelity math_fidelity, - const OptimizedConvParallelizationConfig& parallelization_config, - const OptimizedConvBlockConfig& block_config, - uint32_t extra_padding_for_32B_alignment, - std::optional memory_config, - std::optional dtype, - std::optional> input_tensor_shape, - bool use_shallow_conv_variant, - bool transpose_mcast, - std::optional compute_kernel_config, - bool enable_act_double_buffer, - bool enable_split_reader, - bool enable_subblock_padding -) { - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))}; - operation::launch_op( - [conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, transpose_mcast, compute_kernel_config, enable_act_double_buffer, enable_split_reader, enable_subblock_padding] - (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - using ttnn::operations::experimental::auto_format::FormatParams; - auto& a = input_tensors.at(0); - auto& b = input_tensors.at(1); - auto& bias = optional_input_tensors.at(0); - //TT_ASSERT(!untilize_out, "Optimized conv only supports tiled out"); - TT_ASSERT(b.get_layout() == Layout::TILE); // Weights should already be formatted - const Shape ashape = Shape(input_tensor_shape.has_value() ? input_tensor_shape.value() : a.get_legacy_shape()); - auto padded_a_shape = Shape(std::vector{ashape[0], ashape[1], ashape[2], tt::round_up(ashape[3], 16)}); - FormatParams input_a_format_params = {.pad_shape=padded_a_shape.value, .pad_value=0.0, .target_layout=Layout::ROW_MAJOR}; - FormatParams input_b_format_params = {.pad_shape=b.get_legacy_shape(), .pad_value=0.0, .target_layout=Layout::TILE}; - FormatParams input_bias_format_params = {}; - if (has_bias) { - input_bias_format_params = {.pad_shape=bias.value().get_legacy_shape(), .pad_value=0, .target_layout=Layout::TILE}; - } - auto output_layout = untilize_out ? Layout::ROW_MAJOR : Layout::TILE; - if (memory_config.has_value()) { - TT_ASSERT((memory_config.value().is_sharded() || memory_config.value().memory_layout == TensorMemoryLayout::INTERLEAVED)); - } - auto arch = a.storage_type() == StorageType::DEVICE ? a.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - bool fp32_accum = a.device()->arch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? 
compute_kernel_config.value().fp32_dest_acc_en : false; - auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false); - return operation::run_without_autoformat( - OptimizedConv(conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, memory_config.value_or(a.memory_config()), dtype.value_or(a.get_dtype()), ashape, use_shallow_conv_variant, transpose_mcast, kernel_config_val, enable_act_double_buffer, enable_split_reader, enable_subblock_padding - ), - input_tensors, - optional_input_tensors); - }, {a, b}, output_tensors, {bias, conv_reader_indices}); - return output_tensors.at(0); -} - -void OptimizedConv::validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - // TODO: ... - TT_FATAL(!input_tensor_b.memory_config().is_sharded()); - if (this->untilize_out) { - TT_FATAL((this->dtype == DataType::BFLOAT16) || (this->dtype == DataType::FLOAT32)); - } - if (this->memory_config.is_sharded()) { - uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles; - auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(input_tensor_a.get_legacy_shape(), conv_params, out_block_h_ntiles, extra_padding_for_32B_alignment); - uint32_t out_width_ntiles = this->compute_output_shapes(input_tensors).at(0)[-1] / TILE_WIDTH; - if(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - TT_FATAL(this->parallelization_config.per_core_out_matrix_width_ntiles == out_width_ntiles); - TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1); - } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - // For block sharded, out_width per core is shard width, and this is split along row - // TODO: We should clean this up and relax constraints on out_subblock h and w - if (transpose_mcast) { - out_width_ntiles /= this->parallelization_config.grid_size.y; - } else { - out_width_ntiles /= this->parallelization_config.grid_size.x; - } - TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1); - } - } -} - -std::vector OptimizedConv::compute_output_shapes(const std::vector& input_tensors) const { - const auto& input_tensor_a_shape = this->input_tensor_shape; - uint32_t batch_size = input_tensor_a_shape[0]; - uint32_t conv_activation_h = input_tensor_a_shape[1]; - uint32_t conv_activation_w = input_tensor_a_shape[2]; - // TODO: clean up here - uint32_t filter_h = (uint32_t) conv_params[0]; - uint32_t filter_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - auto [conv_output_h, conv_output_w] = optimized_conv_op_utils::compute_opt_conv_output_face_shape(conv_activation_h, conv_activation_w, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, extra_padding_for_32B_alignment); - - // Tiled output shape is padded shape. Padded to tile shape. 
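// ---------------------------------------------------------------------------
// Illustrative sketch (editorial note, not part of the removed code): the
// padded output shape computed below rounds the flattened N*H*W dimension up
// to a whole number of per-core tile blocks and the channel dimension up to a
// full tile width. The concrete numbers are assumptions picked so the
// arithmetic is easy to follow; they do not come from the deleted
// implementation.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <cstdio>

static uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
    const uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;
    const uint32_t batch = 8, out_h = 112, out_w = 112, out_channels = 50;   // assumed conv output
    const uint32_t num_cores_nhw = 98, per_core_out_height_ntiles = 32;      // assumed parallelization
    const uint32_t shape_w = batch * out_h * out_w;                          // 100352 output sticks
    const uint32_t padded_shape_w =
        num_cores_nhw * per_core_out_height_ntiles * TILE_HEIGHT;            // 98 * 32 * 32 = 100352
    const uint32_t padded_shape_c = round_up(out_channels, TILE_WIDTH);      // 50 -> 64 (tile aligned)
    std::printf("NHW %u -> padded %u, C %u -> padded %u\n",
                shape_w, padded_shape_w, out_channels, padded_shape_c);
    return 0;
}
// ---------------------------------------------------------------------------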
- auto shape_w = batch_size * conv_output_h * conv_output_w; - auto shape_c = output_channels; - auto padded_shape_w = - parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT; - auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH); - auto output_padding = Padding( - {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero); - auto output_tensor_shape = Shape(tt::tt_metal::Shape({1, 1, padded_shape_w, padded_shape_c}, output_padding)); - return {output_tensor_shape.value}; -} - -std::vector OptimizedConv::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - const auto& weight_tensor = input_tensors.at(1); - auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE; - if (this->memory_config.is_sharded()) { - auto output_shape = this->compute_output_shapes(input_tensors).at(0); - if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT; - uint32_t num_cores = total_height_tiles / this->parallelization_config.per_core_out_matrix_height_ntiles; - CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerange_set(num_cores, this->parallelization_config.grid_size, true); - - std::array shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, output_shape[-1]}; - auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; - auto mem_config = this->memory_config; - mem_config.shard_spec = shard_spec; - return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)}; - } else { - auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(this->input_tensor_shape.value, conv_params, this->parallelization_config.per_core_out_matrix_height_ntiles, extra_padding_for_32B_alignment); - uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1]; - uint32_t act_matrix_height_ntiles = act_matrix_height / TILE_HEIGHT; - uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / this->parallelization_config.per_core_out_matrix_height_ntiles; - uint32_t weight_matrix_width = weight_tensor.get_legacy_shape()[-1]; - uint32_t weight_matrix_width_ntiles = weight_matrix_width / TILE_WIDTH; - uint32_t num_weight_slices_width = weight_matrix_width_ntiles / this->parallelization_config.per_core_out_matrix_width_ntiles ; - uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; - CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerange_set(total_active_num_cores, this->parallelization_config.grid_size, true); - std::array shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, this->parallelization_config.per_core_out_matrix_width_ntiles * TILE_WIDTH}; - auto shard_spec = ShardSpec{shard_grid, shard_shape, transpose_mcast ? 
ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR}; - auto mem_config = this->memory_config; - mem_config.shard_spec = shard_spec; - return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)}; - } - - } - return operation::generic_create_output_tensors(*this, input_tensors, this->dtype, output_layout, this->memory_config); -} - -operation::ProgramWithCallbacks OptimizedConv::create_program(const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - std::vector& output_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - const auto& input_tensor_bias = optional_input_tensors.at(0); - const auto& conv_reader_indices = optional_input_tensors.at(1); - auto& output_tensor = output_tensors.at(0); - // TODO: Clean up split between different conv types - if (input_tensor_a.memory_config().is_sharded()) { - // If conv_reader_indices is passed in, use v2 where we don't generate indices locally - if (conv_reader_indices.has_value()) { - uint32_t groups = 1; - const auto& input_tensor_a_shape = this->input_tensor_shape; - sliding_window::SlidingWindowConfig sliding_window_config{ - .batch_size = input_tensor_a_shape[0], - .input_hw = {input_tensor_a_shape[1], input_tensor_a_shape[2]}, - .window_hw = {conv_params[0], conv_params[1]}, - .stride_hw = {conv_params[2], conv_params[3]}, - .pad_hw = {conv_params[4], conv_params[5]}, - }; - return multi_core_optimized_conv_sharded_v2_(input_tensor_a, input_tensor_b, this->input_tensor_shape, input_tensor_bias, conv_reader_indices, sliding_window_config, output_channels, groups, untilize_out, has_bias, fuse_relu, parallelization_config, block_config, extra_padding_for_32B_alignment, this->use_shallow_conv_variant, transpose_mcast, output_tensor, this->compute_kernel_config, this->enable_act_double_buffer, this->enable_split_reader, this->enable_subblock_padding); - } else { - return multi_core_optimized_conv_sharded_(input_tensor_a, input_tensor_b, this->input_tensor_shape, input_tensor_bias, conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, output_tensor); - } - } else { - return multi_core_optimized_conv_(input_tensor_a, input_tensor_b, this->input_tensor_shape, input_tensor_bias, conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, output_tensor); - } -} - -operation::OpPerformanceModel OptimizedConv::create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const { - const auto& input_tensor_a_shape = this->input_tensor_shape; - uint32_t batch_size = input_tensor_a_shape[0]; - uint32_t conv_activation_h = input_tensor_a_shape[1]; - uint32_t conv_activation_w = input_tensor_a_shape[2]; - uint32_t conv_activation_c = input_tensor_a_shape[3]; - - uint32_t filter_h = (uint32_t) conv_params[0]; - uint32_t filter_w = (uint32_t) conv_params[1]; - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; - - const auto& t = output_tensors.at(0); - if(t.storage_type() != StorageType::DEVICE) { - tt::log_warning(tt::LogOp, "Output tensor not on DEVICE?!"); - } - - auto arch = t.storage_type() == StorageType::DEVICE ? 
t.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - const int num_cores = (arch == tt::ARCH::WORMHOLE_B0) ? 8 * 8 : 9 * 12; - const int tensix_mul_adds_per_cycle_lofi = (arch == tt::ARCH::WORMHOLE_B0) ? 4096 : 2048; - - // Calculate output dimensions: relevant for window/stride based OPs (conv, maxpool, downsample) - int output_height = std::floor((conv_activation_h - filter_h + 2 * pad_h) / stride_h + 1); - int output_width = std::floor((conv_activation_w - filter_w + 2 * pad_w) / stride_w + 1); - - // Calculate number of mul/add operations - // TODO: add bias modeling - int64_t num_mul_adds_per_elem = conv_activation_c * filter_h * filter_w * 2; // 1 multiply and 1 add per element - int64_t num_mul_adds = num_mul_adds_per_elem * output_height * output_width * this->output_channels * batch_size; - - int ideal_dev_clock_cycles = std::ceil(((float)num_mul_adds / (float)(num_cores * tensix_mul_adds_per_cycle_lofi)) * (float)operation::OpPerformanceModel::fidelity_multiplier(this->math_fidelity)); - - operation::OpPerformanceModel result(input_tensors, output_tensors, ideal_dev_clock_cycles); - -#if 0 - tt::log_info(tt::LogOp, "OptimizedConv PerfModel:"); - tt::log_info(tt::LogOp, "\t Batch: {}", batch_size); - tt::log_info(tt::LogOp, "\t In (H, W, C): ({}, {}, {})", conv_activation_h, conv_activation_w, conv_activation_c); - tt::log_info(tt::LogOp, "\t Filter (H, W): ({}, {})", filter_h, filter_w); - tt::log_info(tt::LogOp, "\t Filter Stride (H, W): ({}, {})", stride_h, stride_w); - tt::log_info(tt::LogOp, "\t Pad (H, W): ({}, {})", pad_h, pad_w); - tt::log_info(tt::LogOp, "\t Out (H, W, C): ({}, {}, {})", output_height, output_width, this->output_channels); - tt::log_info(tt::LogOp, "\t ideal_dev_clock_cycles: {}", ideal_dev_clock_cycles); -#endif - - return result; -} - Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional bias, sliding_window::SlidingWindowConfig sliding_window_config, uint32_t output_channels,