From 2f4a4e23eed616bb5736d66a815ba2d77b9d66a1 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Wed, 22 May 2024 22:25:11 +0000 Subject: [PATCH] #8837: Remove no-op cores from convs and refactor setting RTAs on cache hit --- .../optimized_conv_op_sharded_v2.cpp | 1007 ++++++++++------- 1 file changed, 576 insertions(+), 431 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp b/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp index f3bba445c94..e4e2e855f50 100644 --- a/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp +++ b/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp @@ -2,46 +2,41 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tensor/tensor_utils.hpp" +#include "tt_dnn/op_library/auto_format.hpp" #include "tt_dnn/op_library/conv/optimized_conv_op.hpp" #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" - -#include "tt_metal/host_api.hpp" +#include "tt_dnn/op_library/sharding_utilities.hpp" +#include "tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/detail/util.hpp" -#include "tt_metal/common/constants.hpp" - +#include "tt_metal/host_api.hpp" #include "tt_stl/reflection.hpp" -#include "tt_dnn/op_library/work_split.hpp" -#include "tt_dnn/op_library/sharding_utilities.hpp" -#include "tt_dnn/op_library/auto_format.hpp" - -#include "tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" -#include "tensor/tensor_utils.hpp" - using namespace tt::constants; namespace tt { namespace tt_metal { -const uint32_t act_cb = CB::c_in0; -const uint32_t weight_cb = CB::c_in1; -const uint32_t bias_cb = CB::c_in2; -const uint32_t sharded_act_cb = CB::c_in3; -const uint32_t cb_for_reader_indices = CB::c_in4; -const uint32_t cb_for_l1_array = CB::c_in5; -const uint32_t act_cb_row_major_bfloat16 = CB::c_in6; -const uint32_t act_cb_second_reader = CB::c_in7; -const uint32_t matmul_partials_cb = CB::c_intermed0; -const uint32_t tilize_mode_tilized_act_cb = CB::c_intermed1; -const uint32_t untilize_mode_reblock_cb = CB::c_intermed2; -const uint32_t out0_cb = CB::c_out0; - +const uint32_t act_cb = CB::c_in0; +const uint32_t weight_cb = CB::c_in1; +const uint32_t bias_cb = CB::c_in2; +const uint32_t sharded_act_cb = CB::c_in3; +const uint32_t cb_for_reader_indices = CB::c_in4; +const uint32_t cb_for_l1_array = CB::c_in5; +const uint32_t act_cb_row_major_bfloat16 = CB::c_in6; +const uint32_t act_cb_second_reader = CB::c_in7; +const uint32_t matmul_partials_cb = CB::c_intermed0; +const uint32_t tilize_mode_tilized_act_cb = CB::c_intermed1; +const uint32_t untilize_mode_reblock_cb = CB::c_intermed2; +const uint32_t out0_cb = CB::c_out0; // TODO: Add namespace for utilities? tuple create_CBs_for_sharded_input_v2( - tt_metal::Program &program, + tt_metal::Program& program, const Tensor& input, CoreRange core, uint32_t num_cb0_tiles, @@ -62,10 +57,9 @@ tuple create_CBs_for_sharded_input_v2( bool with_bias, bool split_reader, bool fp32_dest_acc_en, - bool packer_l1_acc_en -) { - - tt::DataFormat interm0_df = packer_l1_acc_en ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : out_df; + bool packer_l1_acc_en) { + tt::DataFormat interm0_df = + packer_l1_acc_en ? (fp32_dest_acc_en ? 
tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : out_df; uint32_t act_tile_size = tt_metal::detail::TileSize(act_df); uint32_t weight_tile_size = tt_metal::detail::TileSize(weight_df); @@ -78,9 +72,11 @@ tuple create_CBs_for_sharded_input_v2( uint32_t num_bytes_for_df = datum_size(act_df); auto shard_shape = input.shard_spec().value().shape; // 2D-sys-conv already has uint16_t indicies, TODO: do the same for 1D-sys-conv - TT_ASSERT(shard_shape[0] <= (1<<16), "Shard height must be less than 2^16, read pattern indicies are uint16_t"); - CircularBufferConfig cb_sharded_act_config = CircularBufferConfig(shard_shape[0] * shard_shape[1] * num_bytes_for_df, {{sharded_act_cb, act_df}}) - .set_page_size(sharded_act_cb, shard_shape[1] * num_bytes_for_df); + TT_ASSERT( + shard_shape[0] <= (1 << 16), "Shard height must be less than 2^16, read pattern indicies are uint16_t"); + CircularBufferConfig cb_sharded_act_config = + CircularBufferConfig(shard_shape[0] * shard_shape[1] * num_bytes_for_df, {{sharded_act_cb, act_df}}) + .set_page_size(sharded_act_cb, shard_shape[1] * num_bytes_for_df); // incoming data is the input cb instead of raw l1/dram addr cb_sharded_act_config.set_globally_allocated_address(*input.buffer()); cb_sharded_act = tt_metal::CreateCircularBuffer(program, core, cb_sharded_act_config); @@ -88,17 +84,21 @@ tuple create_CBs_for_sharded_input_v2( if (weight_width_sliced) { // For 2D convs, each core creates and tilizes full input matrix then mcasts round robin style // Each core receives input into act_cb, so won't need a separate cb to receive - // However, we need a separate cb to push ROW_MAJOR BFLOAT16 data for tilizing and configure act cb to be output df + // However, we need a separate cb to push ROW_MAJOR BFLOAT16 data for tilizing and configure act cb to be + // output df // num_cb0_tiles is double buffered - CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * tilized_act_tile_size, {{act_cb, tilized_act_df}}) - .set_page_size(act_cb, tilized_act_tile_size); + CircularBufferConfig cb_act_config = + CircularBufferConfig(num_cb0_tiles * tilized_act_tile_size, {{act_cb, tilized_act_df}}) + .set_page_size(act_cb, tilized_act_tile_size); auto cb_act = tt_metal::CreateCircularBuffer(program, core, cb_act_config); // num_cb0_tilized_tiles is single buffered - CircularBufferConfig cb_act_row_major_bfloat16_config = CircularBufferConfig(num_cb0_tilized_tiles * act_tile_size, {{act_cb_row_major_bfloat16, act_df}}) - .set_page_size(act_cb_row_major_bfloat16, act_tile_size); - auto cb_act_row_major_bfloat16 = tt_metal::CreateCircularBuffer(program, core, cb_act_row_major_bfloat16_config); + CircularBufferConfig cb_act_row_major_bfloat16_config = + CircularBufferConfig(num_cb0_tilized_tiles * act_tile_size, {{act_cb_row_major_bfloat16, act_df}}) + .set_page_size(act_cb_row_major_bfloat16, act_tile_size); + auto cb_act_row_major_bfloat16 = + tt_metal::CreateCircularBuffer(program, core, cb_act_row_major_bfloat16_config); } else { // For 1D convs, locally create act matrix in act_cb, which is always ROW_MAJOR BFLOAT16 // Then, tilize input in compute @@ -108,69 +108,77 @@ tuple create_CBs_for_sharded_input_v2( if (split_reader) { num_cb0_tiles /= 2; - CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * act_tile_size, {{act_cb_second_reader, act_df}}) - .set_page_size(act_cb_second_reader, act_tile_size); + CircularBufferConfig cb_act_config = + CircularBufferConfig(num_cb0_tiles * act_tile_size, {{act_cb_second_reader, act_df}}) + 
.set_page_size(act_cb_second_reader, act_tile_size); auto cb_act = tt_metal::CreateCircularBuffer(program, core, cb_act_config); } CircularBufferConfig cb_act_config = CircularBufferConfig(num_cb0_tiles * act_tile_size, {{act_cb, act_df}}) - .set_page_size(act_cb, act_tile_size); + .set_page_size(act_cb, act_tile_size); auto cb_act = tt_metal::CreateCircularBuffer(program, core, cb_act_config); } } else { TT_ASSERT(false, "Input must be sharded!"); } - CircularBufferConfig cb_weight_config = CircularBufferConfig(num_cb1_tiles * weight_tile_size, {{weight_cb, weight_df}}) - .set_page_size(weight_cb, weight_tile_size); + CircularBufferConfig cb_weight_config = + CircularBufferConfig(num_cb1_tiles * weight_tile_size, {{weight_cb, weight_df}}) + .set_page_size(weight_cb, weight_tile_size); auto cb_weight = tt_metal::CreateCircularBuffer(program, core, cb_weight_config); // Used for placing tilized activations - CircularBufferConfig cb_src0_tilized_config = CircularBufferConfig(num_cb0_tilized_tiles * tilized_act_tile_size, {{tilize_mode_tilized_act_cb, tilized_act_df}}) - .set_page_size(tilize_mode_tilized_act_cb, tilized_act_tile_size); + CircularBufferConfig cb_src0_tilized_config = + CircularBufferConfig( + num_cb0_tilized_tiles * tilized_act_tile_size, {{tilize_mode_tilized_act_cb, tilized_act_df}}) + .set_page_size(tilize_mode_tilized_act_cb, tilized_act_tile_size); auto cb_src0_tilized = tt_metal::CreateCircularBuffer(program, core, cb_src0_tilized_config); CBHandle cb_output = 0; if (untilize_out) { - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * interm0_single_tile_size, {{matmul_partials_cb, interm0_df}}) - .set_page_size(matmul_partials_cb, interm0_single_tile_size); + CircularBufferConfig cb_matmul_partials_config = + CircularBufferConfig(num_output_tiles * interm0_single_tile_size, {{matmul_partials_cb, interm0_df}}) + .set_page_size(matmul_partials_cb, interm0_single_tile_size); auto cb_matmul_partials = tt_metal::CreateCircularBuffer(program, core, cb_matmul_partials_config); // Supposed to be a small CB only responsible for reorganizing // the output blocks to fill the whole "per core output block width" - CircularBufferConfig cb_reblock_config = CircularBufferConfig(num_reblock_cb_tiles * out_tile_size, {{untilize_mode_reblock_cb, out_df}}) - .set_page_size(untilize_mode_reblock_cb, out_tile_size); + CircularBufferConfig cb_reblock_config = + CircularBufferConfig(num_reblock_cb_tiles * out_tile_size, {{untilize_mode_reblock_cb, out_df}}) + .set_page_size(untilize_mode_reblock_cb, out_tile_size); auto cb_reblock = tt_metal::CreateCircularBuffer(program, core, cb_reblock_config); - CircularBufferConfig cb_output_config = CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{out0_cb, out_df}}) - .set_page_size(out0_cb, out_tile_size); + CircularBufferConfig cb_output_config = + CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{out0_cb, out_df}}) + .set_page_size(out0_cb, out_tile_size); if (output.is_sharded()) { cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); } cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); } else { - //Share buffer if same data format - if(interm0_df == out_df) { + // Share buffer if same data format + if (interm0_df == out_df) { CoreRangeSet cores(std::set({core})); std::map cb_output_data_format_spec = { - {out0_cb, out_df}, - {matmul_partials_cb, out_df} - }; - CircularBufferConfig cb_matmul_partials_config = 
CircularBufferConfig(num_output_tiles * out_tile_size, cb_output_data_format_spec) - .set_page_size(out0_cb, out_tile_size) - .set_page_size(matmul_partials_cb, out_tile_size); + {out0_cb, out_df}, {matmul_partials_cb, out_df}}; + CircularBufferConfig cb_matmul_partials_config = + CircularBufferConfig(num_output_tiles * out_tile_size, cb_output_data_format_spec) + .set_page_size(out0_cb, out_tile_size) + .set_page_size(matmul_partials_cb, out_tile_size); if (output.is_sharded()) { cb_matmul_partials_config = cb_matmul_partials_config.set_globally_allocated_address(*output.buffer()); } cb_output = tt_metal::CreateCircularBuffer(program, cores, cb_matmul_partials_config); } else { - //Separate buffer if not same data format - CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * interm0_single_tile_size, {{matmul_partials_cb, interm0_df}}) - .set_page_size(matmul_partials_cb, interm0_single_tile_size); + // Separate buffer if not same data format + CircularBufferConfig cb_matmul_partials_config = + CircularBufferConfig(num_output_tiles * interm0_single_tile_size, {{matmul_partials_cb, interm0_df}}) + .set_page_size(matmul_partials_cb, interm0_single_tile_size); auto cb_matmul_partials = tt_metal::CreateCircularBuffer(program, core, cb_matmul_partials_config); - CircularBufferConfig cb_output_config = CircularBufferConfig(num_output_tiles * out_tile_size, {{out0_cb, out_df}}) - .set_page_size(out0_cb, out_tile_size); + CircularBufferConfig cb_output_config = + CircularBufferConfig(num_output_tiles * out_tile_size, {{out0_cb, out_df}}) + .set_page_size(out0_cb, out_tile_size); if (output.is_sharded()) { cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); } @@ -183,7 +191,7 @@ tuple create_CBs_for_sharded_input_v2( // bias input uint32_t bias_pagesize = bias_tile_size; CircularBufferConfig cb_bias_config = CircularBufferConfig(bias_ntiles * bias_pagesize, {{bias_cb, bias_df}}) - .set_page_size(bias_cb, bias_pagesize); + .set_page_size(bias_cb, bias_pagesize); auto cb_bias = tt_metal::CreateCircularBuffer(program, core, cb_bias_config); log_debug(LogOp, "Bias CB: {}, npages: {}, pagesize: {}", bias_cb, bias_ntiles, bias_pagesize); @@ -192,9 +200,27 @@ tuple create_CBs_for_sharded_input_v2( return {cb_sharded_act, cb_output}; } -operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_metal::Program& program, const Tensor& a, const Tensor &b, const Shape& ashape, std::optional bias, const std::optional conv_reader_indices, vector conv_params, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, bool use_shallow_conv_variant, bool transpose_mcast, Tensor &output, DeviceComputeKernelConfig compute_kernel_config) { +operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( + tt_metal::Program& program, + const Tensor& a, + const Tensor& b, + const Shape& ashape, + std::optional bias, + const std::optional conv_reader_indices, + vector conv_params, + uint32_t output_channels, + bool untilize_out, + bool has_bias, + bool fuse_relu, + const OptimizedConvParallelizationConfig& parallelization_config, + const OptimizedConvBlockConfig& block_config, + uint32_t extra_padding_for_32B_alignment, + bool use_shallow_conv_variant, + bool transpose_mcast, + Tensor& output, + DeviceComputeKernelConfig compute_kernel_config) { bool pass 
= true;
-    tt_metal::Device *device = a.device();
+    tt_metal::Device* device = a.device();
     TT_ASSERT(a.get_layout() == Layout::ROW_MAJOR, "Conv activation should be in row major layout");
     TT_ASSERT(a.memory_config().is_sharded(), "Conv activation must be sharded.");
     TT_ASSERT(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. Incorrect weight tensor.");
@@ -208,7 +234,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     DataFormat act_df = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
     DataFormat weight_df = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
     DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
-    DataFormat bias_df = has_bias ? tt_metal::datatype_to_dataformat_converter(bias.value().get_dtype()) : DataFormat::Float16_b;
+    DataFormat bias_df =
+        has_bias ? tt_metal::datatype_to_dataformat_converter(bias.value().get_dtype()) : DataFormat::Float16_b;
     DataFormat tilized_act_df = out_df;

     log_debug(LogOp, "act_df: {}", act_df);
@@ -222,25 +249,26 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     bool fp32_dest_acc_en;
     bool packer_l1_acc;

-    std::visit([&](auto&& compute_kernel_config) {
-        using T = std::decay_t<decltype(compute_kernel_config)>;
-        if constexpr (std::is_same_v<T, GrayskullComputeKernelConfig>) {
-            TT_ASSERT(device->arch() == ARCH::GRAYSKULL, "kernel config is not for graykull");
-            math_fidelity = compute_kernel_config.math_fidelity;
-            math_approx_mode = compute_kernel_config.math_approx_mode;
-            fp32_dest_acc_en = false;
-            packer_l1_acc = false;
-        } else if constexpr (std::is_same_v<T, WormholeComputeKernelConfig>) {
-            TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "kernel config is not for wormhole_b0");
-            math_fidelity = compute_kernel_config.math_fidelity;
-            math_approx_mode = compute_kernel_config.math_approx_mode;
-            fp32_dest_acc_en = compute_kernel_config.fp32_dest_acc_en;
-            packer_l1_acc = compute_kernel_config.packer_l1_acc;
-        } else {
-            TT_FATAL("arch not supported");
-        }
-
-    }, compute_kernel_config);
+    std::visit(
+        [&](auto&& compute_kernel_config) {
+            using T = std::decay_t<decltype(compute_kernel_config)>;
+            if constexpr (std::is_same_v<T, GrayskullComputeKernelConfig>) {
+                TT_ASSERT(device->arch() == ARCH::GRAYSKULL, "kernel config is not for grayskull");
+                math_fidelity = compute_kernel_config.math_fidelity;
+                math_approx_mode = compute_kernel_config.math_approx_mode;
+                fp32_dest_acc_en = false;
+                packer_l1_acc = false;
+            } else if constexpr (std::is_same_v<T, WormholeComputeKernelConfig>) {
+                TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "kernel config is not for wormhole_b0");
+                math_fidelity = compute_kernel_config.math_fidelity;
+                math_approx_mode = compute_kernel_config.math_approx_mode;
+                fp32_dest_acc_en = compute_kernel_config.fp32_dest_acc_en;
+                packer_l1_acc = compute_kernel_config.packer_l1_acc;
+            } else {
+                TT_FATAL("arch not supported");
+            }
+        },
+        compute_kernel_config);

     if (fp32_dest_acc_en and (out_subblock_h_ntiles * out_subblock_w_ntiles > 4)) {
         if (out_subblock_w_ntiles >= 4) {
@@ -248,13 +276,16 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
             out_subblock_w_ntiles = find_max_block_size(out_subblock_w_ntiles, 4);
         } else {
             while (out_subblock_h_ntiles * out_subblock_w_ntiles > 4) {
-                uint32_t div = find_max_divisor(out_subblock_h_ntiles, out_subblock_h_ntiles-1);
+                uint32_t div = find_max_divisor(out_subblock_h_ntiles, out_subblock_h_ntiles - 1);
                 out_subblock_h_ntiles = find_max_block_size(out_subblock_h_ntiles, div);
             }
         }
     }
-    //assert(out_block_h_ntiles == act_block_h_ntiles); // TODO: fix output block sizing
-    TT_ASSERT(out_block_h_ntiles >= 
act_block_h_ntiles, "Output block height (in # of tiles) should be greater than or equal to activation block height (in # of tiles)"); + // assert(out_block_h_ntiles == act_block_h_ntiles); // TODO: fix output block sizing + TT_ASSERT( + out_block_h_ntiles >= act_block_h_ntiles, + "Output block height (in # of tiles) should be greater than or equal to activation block height (in # of " + "tiles)"); // Tensor b has weights and it should be tiled layout after converting conv weights into weight matrix TT_ASSERT(b.get_layout() == Layout::TILE, "Conv weights should be in tiled layout"); @@ -299,33 +330,42 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met } TT_FATAL(input_channels_padded >= ashape[3], "Incorrect padding of input channels!"); // check is for 16-byte alignment - TT_FATAL(input_channels_padded % 16 == 0, "Expected input channels to be padded for 16 byte alignment in L1"); // TODO: For bfp16, check if its divisible by 8 not 16. + TT_FATAL( + input_channels_padded % 16 == 0, + "Expected input channels to be padded for 16 byte alignment in L1"); // TODO: For bfp16, check if its divisible + // by 8 not 16. // Always use split reader for first conv in resnet which has input channels = 16 // TODO: Expose option to split readers for 1D convs to python? bool split_reader = use_shallow_conv_variant; if (split_reader) { - TT_FATAL(block_config.act_block_h_ntiles % block_config.out_subblock_h_ntiles == 0, "Out_block_h must be divisible by out_subblock_h!"); - TT_FATAL((block_config.act_block_h_ntiles / block_config.out_subblock_h_ntiles) % 2 == 0, "Number of out_subblock_h must be divisible by 2 for split reader!"); + TT_FATAL( + block_config.act_block_h_ntiles % block_config.out_subblock_h_ntiles == 0, + "Out_block_h must be divisible by out_subblock_h!"); + TT_FATAL( + (block_config.act_block_h_ntiles / block_config.out_subblock_h_ntiles) % 2 == 0, + "Number of out_subblock_h must be divisible by 2 for split reader!"); } Shape ashape_with_channels_padded = {ashape[0], ashape[1], ashape[2], input_channels_padded}; uint32_t conv_act_size_h = ashape_with_channels_padded[1]; uint32_t conv_act_size_w = ashape_with_channels_padded[2]; uint32_t conv_act_size_c = ashape_with_channels_padded[3]; - uint32_t weight_size_h = (uint32_t) conv_params[0]; // filter_h - uint32_t weight_size_w = (uint32_t) conv_params[1]; // filter_W - uint32_t stride_h = (uint32_t) conv_params[2]; - uint32_t stride_w = (uint32_t) conv_params[3]; - uint32_t pad_h = (uint32_t) conv_params[4]; - uint32_t pad_w = (uint32_t) conv_params[5]; + uint32_t weight_size_h = (uint32_t)conv_params[0]; // filter_h + uint32_t weight_size_w = (uint32_t)conv_params[1]; // filter_W + uint32_t stride_h = (uint32_t)conv_params[2]; + uint32_t stride_w = (uint32_t)conv_params[3]; + uint32_t pad_h = (uint32_t)conv_params[4]; + uint32_t pad_w = (uint32_t)conv_params[5]; // Compute the 2d matrix shape - auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(ashape_with_channels_padded, conv_params, out_block_h_ntiles, extra_padding_for_32B_alignment); + auto [act_matrix_shape, act_matrix_shape_unpadded] = + optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape( + ashape_with_channels_padded, conv_params, out_block_h_ntiles, extra_padding_for_32B_alignment); assert(act_matrix_shape.size() == 3); assert(act_matrix_shape[0] == 1); - uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1]; - uint32_t act_matrix_width = (uint32_t) 
act_matrix_shape[2]; - uint32_t act_matrix_height_unpadded = (uint32_t) act_matrix_shape_unpadded[1]; - uint32_t act_matrix_width_unpadded = (uint32_t) act_matrix_shape_unpadded[2]; + uint32_t act_matrix_height = (uint32_t)act_matrix_shape[1]; + uint32_t act_matrix_width = (uint32_t)act_matrix_shape[2]; + uint32_t act_matrix_height_unpadded = (uint32_t)act_matrix_shape_unpadded[1]; + uint32_t act_matrix_width_unpadded = (uint32_t)act_matrix_shape_unpadded[2]; // TODO: Move all these asserts/checks to validate? @@ -347,11 +387,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met TT_ASSERT(weight_matrix_width % TILE_WIDTH == 0, "Width of weight matrix needs to be divisible by 32"); // Device compatibility checks - TT_ASSERT(a.storage_type() == StorageType::DEVICE && - b.storage_type() == StorageType::DEVICE && - "Operands to large matmul need to be on device!"); + TT_ASSERT( + a.storage_type() == StorageType::DEVICE && b.storage_type() == StorageType::DEVICE && + "Operands to large matmul need to be on device!"); TT_ASSERT(a.device() == b.device(), "Operands to conv need to be on the same device!"); - TT_ASSERT(a.buffer() != nullptr && b.buffer() != nullptr, "Operands to conv need to be allocated in buffers on device!"); + TT_ASSERT( + a.buffer() != nullptr && b.buffer() != nullptr, "Operands to conv need to be allocated in buffers on device!"); if (has_bias) { TT_ASSERT(bias.value().storage_type() == StorageType::DEVICE, "Bias should be on device"); TT_ASSERT(bias.value().device() == a.device(), "Bias should be on the same device as act tensor"); @@ -374,7 +415,9 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met // act block info uint32_t act_block_w_datums = act_matrix_width / num_blocks_act_w; uint32_t act_block_h_datums = act_matrix_height / num_blocks_act_h; - TT_ASSERT((act_block_w_datums == round_up(conv_act_size_c * weight_size_w, TILE_WIDTH)) || ((act_block_w_datums <= conv_act_size_c) && (conv_act_size_c % act_block_w_datums == 0))); + TT_ASSERT( + (act_block_w_datums == round_up(conv_act_size_c * weight_size_w, TILE_WIDTH)) || + ((act_block_w_datums <= conv_act_size_c) && (conv_act_size_c % act_block_w_datums == 0))); // weight block info uint32_t weight_block_w_datums = weight_matrix_width / num_blocks_weight_w; @@ -389,8 +432,11 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met uint32_t output_channels_padded_to_tile_width = round_up(output_channels, TILE_WIDTH); assert(output_channels_padded_to_tile_width <= weight_matrix_width); uint32_t output_width_num_tiles = output_channels_padded_to_tile_width / TILE_WIDTH; - uint32_t num_blocks_output_w = (uint32_t) ceil((double) output_channels_padded_to_tile_width / (double) weight_block_w_datums); - uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums == 0) ? weight_block_w_datums : (output_channels_padded_to_tile_width % weight_block_w_datums); + uint32_t num_blocks_output_w = + (uint32_t)ceil((double)output_channels_padded_to_tile_width / (double)weight_block_w_datums); + uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums == 0) + ? 
weight_block_w_datums
+                                           : (output_channels_padded_to_tile_width % weight_block_w_datums);
     assert(last_block_width_datums % TILE_WIDTH == 0);

     // sanity check
@@ -398,10 +444,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met

     uint32_t out_block_h_datums = out_block_h_ntiles * TILE_HEIGHT;

-    tt_metal::Buffer *src0_dram_buffer = a.buffer();
-    tt_metal::Buffer *src1_dram_buffer = b.buffer();
+    tt_metal::Buffer* src0_dram_buffer = a.buffer();
+    tt_metal::Buffer* src1_dram_buffer = b.buffer();

-    tt_metal::Buffer *dst_dram_buffer = output.buffer();
+    tt_metal::Buffer* dst_dram_buffer = output.buffer();
     TT_ASSERT(dst_dram_buffer != nullptr, "Output buffer should be allocated on device!");

     // out
@@ -427,16 +473,26 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     uint32_t weight_dram_addr = src1_dram_buffer->address();

     // bias
-    Buffer *bias_buffer = nullptr;
+    Buffer* bias_buffer = nullptr;
     uint32_t bias_dram_addr = 0;
     uint32_t bias_ntiles = 0;
     if (has_bias) {
         bias_buffer = bias.value().buffer();
         bias_dram_addr = bias_buffer->address();
-        bias_ntiles = bias.value().get_legacy_shape()[3] / constants::TILE_WIDTH; // TODO: support non tile multiple sizes
+        bias_ntiles =
+            bias.value().get_legacy_shape()[3] / constants::TILE_WIDTH;  // TODO: support non tile multiple sizes
     }

-    auto [conv_output_size_h, conv_output_size_w] = optimized_conv_op_utils::compute_opt_conv_output_face_shape(conv_act_size_h, conv_act_size_w, weight_size_h, weight_size_w, stride_h, stride_w, pad_h, pad_w, extra_padding_for_32B_alignment);
+    auto [conv_output_size_h, conv_output_size_w] = optimized_conv_op_utils::compute_opt_conv_output_face_shape(
+        conv_act_size_h,
+        conv_act_size_w,
+        weight_size_h,
+        weight_size_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        extra_padding_for_32B_alignment);

     std::map<string, string> reader_defines;

@@ -454,8 +510,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     uint32_t src_dram_act_buffer_size_bytes = src0_dram_buffer->size();
     uint32_t src_dram_weight_buffer_size_bytes = src1_dram_buffer->size();
-    uint32_t dst_l1_act_buffer_size_bytes = out_block_h_ntiles * act_block_w_ntiles * tt::tt_metal::detail::TileSize(act_df);
-    uint32_t dst_l1_weight_buffer_size_bytes = weight_block_h_ntiles * weight_block_w_ntiles * tt::tt_metal::detail::TileSize(weight_df);
+    uint32_t dst_l1_act_buffer_size_bytes =
+        out_block_h_ntiles * act_block_w_ntiles * tt::tt_metal::detail::TileSize(act_df);
+    uint32_t dst_l1_weight_buffer_size_bytes =
+        weight_block_h_ntiles * weight_block_w_ntiles * tt::tt_metal::detail::TileSize(weight_df);

     // For debug
     {
@@ -508,11 +566,11 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     uint32_t window_outer;
     uint32_t window_inner;
     if (weight_width_sliced) {
-        window_outer = 1; // window_outer = 1 becasue all of filter window is processed in the inner loop
-        window_inner = 3; // window_inner = 9 / 3, ie. read 3 width coalesced
+        window_outer = 1;  // window_outer = 1 because all of filter window is processed in the inner loop
+        window_inner = 3;  // window_inner = 9 / 3, ie. 
read 3 width coalesced } else { - window_outer = num_blocks_act_w; // window_outer - window_inner = weight_size_h * weight_size_w / num_blocks_act_w; // window_inner + window_outer = num_blocks_act_w; // window_outer + window_inner = weight_size_h * weight_size_w / num_blocks_act_w; // window_inner } reader_defines["WINDOW_INNER"] = std::to_string(window_inner); log_debug(LogOp, "window_outer: {}, window_inner: {}", window_outer, window_inner); @@ -525,7 +583,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met } uint32_t num_weight_slices_width = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; uint32_t total_num_cores_per_weight_slice = 0; - uint32_t total_num_cores_per_act_slice = 0; // only used when (BLOCK_SHARDING && !transpose_mcast) + uint32_t total_num_cores_per_act_slice = 0; // only used when (BLOCK_SHARDING && !transpose_mcast) if (weight_width_sliced) { if (transpose_mcast) { assert(num_cores_y % num_weight_slices_width == 0); @@ -581,33 +639,37 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met assert(total_active_num_cores >= num_cores_x); uint32_t num_active_cores_x = num_cores_x; uint32_t num_active_cores_y_with_full_x = total_active_num_cores / num_cores_x; - uint32_t num_active_cores_x_last_y = total_active_num_cores % num_cores_x; + uint32_t num_active_cores_x_last_y = total_active_num_cores % num_cores_x; assert((num_active_cores_x * num_active_cores_y_with_full_x) + num_active_cores_x_last_y == total_active_num_cores); std::set all_active_cores_set; - all_active_cores_set.insert(CoreRange(CoreCoord(0, 0), CoreCoord(num_active_cores_x - 1, num_active_cores_y_with_full_x - 1))); + all_active_cores_set.insert( + CoreRange(CoreCoord(0, 0), CoreCoord(num_active_cores_x - 1, num_active_cores_y_with_full_x - 1))); if (num_active_cores_x_last_y > 0) { - all_active_cores_set.insert(CoreRange(CoreCoord(0, num_active_cores_y_with_full_x), CoreCoord(num_active_cores_x_last_y - 1, num_active_cores_y_with_full_x))); + all_active_cores_set.insert(CoreRange( + CoreCoord(0, num_active_cores_y_with_full_x), + CoreCoord(num_active_cores_x_last_y - 1, num_active_cores_y_with_full_x))); } CoreRangeSet all_active_cores(all_active_cores_set); std::set noop_cores_set; if (total_noop_cores > 0) { assert(total_noop_cores == (num_cores_x - num_active_cores_x_last_y)); - noop_cores_set.insert(CoreRange(CoreCoord(num_active_cores_x_last_y, num_active_cores_y_with_full_x), CoreCoord(num_cores_x - 1, num_active_cores_y_with_full_x))); - + noop_cores_set.insert(CoreRange( + CoreCoord(num_active_cores_x_last_y, num_active_cores_y_with_full_x), + CoreCoord(num_cores_x - 1, num_active_cores_y_with_full_x))); } CoreRangeSet noop_cores(noop_cores_set); // Mcast cores // If total_num_cores, there is no mcasting - CoreCoord top_left_core = {(std::size_t) 0, (std::size_t) 0}; - CoreCoord top_left_core_plus_one = {(std::size_t) 1, (std::size_t) 1}; - CoreCoord bottom_right_core = {(std::size_t) num_cores_x - 1, (std::size_t) num_cores_y - 1}; + CoreCoord top_left_core = {(std::size_t)0, (std::size_t)0}; + CoreCoord top_left_core_plus_one = {(std::size_t)1, (std::size_t)1}; + CoreCoord bottom_right_core = {(std::size_t)num_cores_x - 1, (std::size_t)num_cores_y - 1}; auto top_left_core_physical = device->worker_core_from_logical_core(top_left_core); auto top_left_core_plus_one_physical = device->worker_core_from_logical_core(top_left_core_plus_one); auto bottom_right_core_physical = 
device->worker_core_from_logical_core(bottom_right_core); - CoreRange mcast_sender_cores(top_left_core, top_left_core); // If single core, this kernel doesn't do mcasting + CoreRange mcast_sender_cores(top_left_core, top_left_core); // If single core, this kernel doesn't do mcasting CoreRangeSet mcast_receiver_cores{{}}; uint32_t weights_mcast_sender_semaphore; uint32_t weights_mcast_receiver_semaphore; @@ -630,9 +692,16 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met if (total_num_cores > 1) { std::set mcast_receiver_set; if (num_cores_x > 1) { - mcast_receiver_set.insert(CoreRange(CoreCoord(1, 0), CoreCoord(num_cores_x - 1, 0))); - } if (num_cores_y > 1) { - mcast_receiver_set.insert(CoreRange(CoreCoord(0, 1), bottom_right_core)); + mcast_receiver_set.insert(CoreRange(CoreCoord(1, 0), CoreCoord(num_active_cores_x - 1, 0))); + } + if (num_cores_y > 1) { + mcast_receiver_set.insert( + CoreRange(CoreCoord(0, 1), CoreCoord(num_active_cores_x - 1, num_active_cores_y_with_full_x - 1))); + if (num_active_cores_x_last_y > 0) { + mcast_receiver_set.insert(CoreRange( + CoreCoord(0, num_active_cores_y_with_full_x), + CoreCoord(num_active_cores_x_last_y - 1, num_active_cores_y_with_full_x))); + } } mcast_receiver_cores = mcast_receiver_set; weights_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); @@ -645,8 +714,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met bool fully_buffer_weights = false; uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks; // TODO: This flag should be set in kernel logic but need this for create_CB - if (a.memory_config().is_sharded() and weight_size_h == 3 and weight_size_w == 3 and (stride_h == 1 or stride_h == 2) and weight_width_sliced) { - // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire 3x3 window before pushing in reader/writer + if (a.memory_config().is_sharded() and weight_size_h == 3 and weight_size_w == 3 and + (stride_h == 1 or stride_h == 2) and weight_width_sliced) { + // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire 3x3 window before + // pushing in reader/writer // TODO: Generalize this to not make this assumption read_3x3_window_in_inner_loop = true; num_weight_cb_tiles *= weight_size_h * weight_size_w; @@ -658,12 +729,14 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met if (fully_buffer_weights) { num_weight_cb_tiles *= window_outer; - } else if (per_core_out_matrix_width_ntiles < 5 && per_core_out_matrix_height_ntiles < 22) { // Q: where are these numbers from? + } else if (per_core_out_matrix_width_ntiles < 5 && per_core_out_matrix_height_ntiles < 22) { // Q: where are these + // numbers from? num_weight_cb_tiles = num_weight_cb_tiles * 2; } - if (conv_act_size_c / conv_act_c_blocks < 160 && per_core_out_matrix_height_ntiles < 22) { // Q: where are these numbers from? - num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered + if (conv_act_size_c / conv_act_c_blocks < 160 && + per_core_out_matrix_height_ntiles < 22) { // Q: where are these numbers from? 
+ num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered } uint32_t writer_output_block_num_tiles = out_block_h_ntiles * weight_block_w_ntiles; @@ -674,49 +747,50 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met std::vector writer_compile_time_args; uint32_t conv_act_c_read_bytes = conv_act_size_c * a.element_size() / conv_act_c_blocks; - uint32_t act_block_w_extra_align_bytes = (round_up(conv_act_size_c * weight_size_w, TILE_WIDTH) - (conv_act_size_c * weight_size_w)) * a.element_size(); + uint32_t act_block_w_extra_align_bytes = + (round_up(conv_act_size_c * weight_size_w, TILE_WIDTH) - (conv_act_size_c * weight_size_w)) * a.element_size(); uint32_t in0_block_w = act_block_w_ntiles / conv_act_c_blocks; uint32_t in0_block_num_tiles = act_block_num_tiles / conv_act_c_blocks; uint32_t in0_subblock_num_tiles = act_subblock_num_tiles / conv_act_c_blocks; uint32_t in1_block_num_tiles = weight_block_num_tiles / conv_act_c_blocks; - uint32_t in0_num_blocks_w = num_blocks_act_w * conv_act_c_blocks; // Fold outer c_block loop together with weight_block_num_tiles = 9 + uint32_t in0_num_blocks_w = + num_blocks_act_w * conv_act_c_blocks; // Fold outer c_block loop together with weight_block_num_tiles = 9 uint32_t tilized_act_tile_size = tt_metal::detail::TileSize(tilized_act_df); - //Only enable packer l1 accumulation when there are in0_num_blocks_w > 2, otherwise - //unnecessary overhead for reconfigs are added. Last iteration of l1 accumulation - //does a spill and reload, so need more than 2 blocks to use l1 acc for packer - //For bias, last iteration of l1 acc remains in intermediate buffer, does not spill and reload + // Only enable packer l1 accumulation when there are in0_num_blocks_w > 2, otherwise + // unnecessary overhead for reconfigs are added. Last iteration of l1 accumulation + // does a spill and reload, so need more than 2 blocks to use l1 acc for packer + // For bias, last iteration of l1 acc remains in intermediate buffer, does not spill and reload bool packer_l1_acc_en = packer_l1_acc && ((has_bias && in0_num_blocks_w > 1) || (in0_num_blocks_w > 2)); // TODO: Moving this function call to after kernel logic causes pcc fails // There are additional CBs and semaphores created in 2D conv in kernel logic, // so does order of create_cb calls matter? 
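For reference, the packer L1 accumulation gating above can be read on its own. Below is a minimal standalone sketch of the two decisions just made: whether accumulating partials in L1 pays off, and which data format the matmul-partials CB then needs. The helper names and the trimmed DataFormat enum are illustrative stand-ins, not part of this patch; the logic mirrors the expressions in the surrounding diff.

#include <cstdint>

enum class DataFormat { Float32, Float16_b };  // illustrative stand-in for tt::DataFormat

// L1 accumulation spills and reloads its last block, so it only pays off when
// more than two in0 blocks are accumulated; with fused bias the last
// accumulation stays in the intermediate CB (no spill/reload), so more than
// one block is already enough.
inline bool enable_packer_l1_acc(bool packer_l1_acc, bool has_bias, uint32_t in0_num_blocks_w) {
    return packer_l1_acc && ((has_bias && in0_num_blocks_w > 1) || (in0_num_blocks_w > 2));
}

// When accumulating in L1, the partials keep the precision of the dest
// registers (fp32 or bfloat16); otherwise the intermediate CB simply shares
// the output format.
inline DataFormat interm0_format(bool packer_l1_acc_en, bool fp32_dest_acc_en, DataFormat out_df) {
    return packer_l1_acc_en ? (fp32_dest_acc_en ? DataFormat::Float32 : DataFormat::Float16_b) : out_df;
}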
auto [cb_sharded_act, cb_output] = create_CBs_for_sharded_input_v2(
-        program,
-        a,
-        all_cores,
-        num_act_cb_tiles, // row major act cb
-        num_weight_cb_tiles, // tiled weight cb
-        num_cb0_tilized_tiles, // tiled act cb
-        writer_output_block_num_tiles, // math output cb
-        weight_block_w_ntiles, // reblock cb
-        writer_output_block_num_tiles, // writer output cb, double bufferred
-        untilize_out,
-        act_df,
-        weight_df,
-        tilized_act_df,
-        out_df,
-        bias_df,
-        weight_width_sliced,
-        output,
-        bias_ntiles_per_core,
-        has_bias,
-        split_reader,
-        fp32_dest_acc_en,
-        packer_l1_acc_en
-    );
+        program,
+        a,
+        all_cores,
+        num_act_cb_tiles,               // row major act cb
+        num_weight_cb_tiles,            // tiled weight cb
+        num_cb0_tilized_tiles,          // tiled act cb
+        writer_output_block_num_tiles,  // math output cb
+        weight_block_w_ntiles,          // reblock cb
+        writer_output_block_num_tiles,  // writer output cb, double buffered
+        untilize_out,
+        act_df,
+        weight_df,
+        tilized_act_df,
+        out_df,
+        bias_df,
+        weight_width_sliced,
+        output,
+        bias_ntiles_per_core,
+        has_bias,
+        split_reader,
+        fp32_dest_acc_en,
+        packer_l1_acc_en);

     string reader_kernel;
     string compute_kernel;
@@ -730,22 +804,28 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     if (weight_width_sliced) {
         // 2D conv
         assert(read_3x3_window_in_inner_loop == true);
-        reader_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp";
-        writer_mcast_sender_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp";
-        writer_mcast_receiver_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp";
+        reader_kernel =
+            "tt_eager/tt_dnn/op_library/conv/kernels/"
+            "reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp";
+        writer_mcast_sender_kernel =
+            "tt_eager/tt_dnn/op_library/conv/kernels/"
+            "writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp";
+        writer_mcast_receiver_kernel =
+            "tt_eager/tt_dnn/op_library/conv/kernels/"
+            "writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp";

         act_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
         act_mcast_receiver_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);

         if (transpose_mcast) {
             act_mcast_noc_y.reserve(num_cores_y);
-            for(uint32_t core_idx_y = 0; core_idx_y < num_cores_y; ++core_idx_y) {
+            for (uint32_t core_idx_y = 0; core_idx_y < num_cores_y; ++core_idx_y) {
                 act_mcast_noc_y.push_back(device->worker_core_from_logical_core({0, core_idx_y}).y);
             }
         } else {
             // NOTE: using same var for x as well, this is intentional
             act_mcast_noc_y.reserve(num_cores_x);
-            for(int32_t core_idx_x = 0; core_idx_x < num_cores_x; ++core_idx_x) {
-                act_mcast_noc_y.push_back(device->worker_core_from_logical_core({(uint32_t) core_idx_x, 0}).x);
+            for (int32_t core_idx_x = 0; core_idx_x < num_cores_x; ++core_idx_x) {
+                act_mcast_noc_y.push_back(device->worker_core_from_logical_core({(uint32_t)core_idx_x, 0}).x);
             }
         }

@@ -755,26 +835,38 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
         // 1D conv
         TT_ASSERT(act_block_w_datums == round_up(conv_act_size_c * weight_size_w, TILE_WIDTH));

-        reader_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp";
+        reader_kernel = 
"tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp"; if (split_reader) { - writer_mcast_sender_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; - writer_mcast_receiver_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; + writer_mcast_sender_kernel = + "tt_eager/tt_dnn/op_library/conv/kernels/" + "reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; + writer_mcast_receiver_kernel = + "tt_eager/tt_dnn/op_library/conv/kernels/" + "reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; } else { - writer_mcast_sender_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; - writer_mcast_receiver_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; + writer_mcast_sender_kernel = + "tt_eager/tt_dnn/op_library/conv/kernels/" + "writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp"; + writer_mcast_receiver_kernel = + "tt_eager/tt_dnn/op_library/conv/kernels/" + "writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp"; } } // Local L1 to store array for reader indices // All convs use packed uint16 indices, so each entry can be 2B (not 4) - CircularBufferConfig cb_for_reader_indices_config = CircularBufferConfig(out_block_h_datums * 2, {{cb_for_reader_indices, tt::DataFormat::Float16_b}}) - .set_page_size(cb_for_reader_indices, out_block_h_datums * 2); + CircularBufferConfig cb_for_reader_indices_config = + CircularBufferConfig(out_block_h_datums * 2, {{cb_for_reader_indices, tt::DataFormat::Float16_b}}) + .set_page_size(cb_for_reader_indices, out_block_h_datums * 2); cb_for_reader_indices_config.set_globally_allocated_address(*conv_reader_indices.value().buffer()); - auto cb_for_reader_indices_id = tt_metal::CreateCircularBuffer(program, all_cores, cb_for_reader_indices_config); + auto cb_for_reader_indices_id = + tt_metal::CreateCircularBuffer(program, all_cores, cb_for_reader_indices_config); // Local L1 to store temp vars - CircularBufferConfig cb_for_l1_array_config = CircularBufferConfig(32 * 2, {{cb_for_l1_array, tt::DataFormat::Float16_b}}) - .set_page_size(cb_for_l1_array, 32 * 2); + CircularBufferConfig cb_for_l1_array_config = + CircularBufferConfig(32 * 2, {{cb_for_l1_array, tt::DataFormat::Float16_b}}) + .set_page_size(cb_for_l1_array, 32 * 2); auto cb_for_l1_array_id = tt_metal::CreateCircularBuffer(program, all_cores, cb_for_l1_array_config); } else { TT_ASSERT(false, "Sharded input not supported for this conv yet!"); @@ -790,29 +882,29 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met } reader_compile_time_args = { - (uint32_t) (src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0), - (uint32_t) stride_h, - (uint32_t) stride_w, - (uint32_t) conv_act_size_w, - (uint32_t) conv_output_size_w, // conv_output_w_last_index - (uint32_t) conv_act_c_read_bytes, - (uint32_t) window_outer, - (uint32_t) window_inner, - (uint32_t) act_block_h_datums, - (uint32_t) act_block_num_tiles / conv_act_c_blocks, - (uint32_t) weight_size_w, - (uint32_t) conv_act_size_w + (2 * pad_w), - (uint32_t) act_block_w_extra_align_bytes, // only used for 1d systolic variant - (uint32_t) weight_size_h, - (uint32_t) num_blocks_act_h_per_core, // act_num_blocks_h - (uint32_t) in0_block_num_tiles, // act_block_num_tiles - (uint32_t) conv_act_c_blocks, // act_w_num_outer - (uint32_t) (transpose_mcast ? num_cores_y - 1 : num_cores_x - 1), // act_mcast_num_dests - (uint32_t) (transpose_mcast ? num_cores_y - 1 : num_cores_x - 1), // act_mcast_num_cores - (uint32_t) act_mcast_sender_semaphore, - (uint32_t) act_mcast_receiver_semaphore, - (uint32_t) in0_block_num_tiles * tilized_act_tile_size, // act_mcast_sender_size_bytes - (uint32_t) (transpose_mcast ? 1 : 0), + (uint32_t)(src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), + (uint32_t)stride_h, + (uint32_t)stride_w, + (uint32_t)conv_act_size_w, + (uint32_t)conv_output_size_w, // conv_output_w_last_index + (uint32_t)conv_act_c_read_bytes, + (uint32_t)window_outer, + (uint32_t)window_inner, + (uint32_t)act_block_h_datums, + (uint32_t)act_block_num_tiles / conv_act_c_blocks, + (uint32_t)weight_size_w, + (uint32_t)conv_act_size_w + (2 * pad_w), + (uint32_t)act_block_w_extra_align_bytes, // only used for 1d systolic variant + (uint32_t)weight_size_h, + (uint32_t)num_blocks_act_h_per_core, // act_num_blocks_h + (uint32_t)in0_block_num_tiles, // act_block_num_tiles + (uint32_t)conv_act_c_blocks, // act_w_num_outer + (uint32_t)(transpose_mcast ? num_cores_y - 1 : num_cores_x - 1), // act_mcast_num_dests + (uint32_t)(transpose_mcast ? num_cores_y - 1 : num_cores_x - 1), // act_mcast_num_cores + (uint32_t)act_mcast_sender_semaphore, + (uint32_t)act_mcast_receiver_semaphore, + (uint32_t)in0_block_num_tiles * tilized_act_tile_size, // act_mcast_sender_size_bytes + (uint32_t)(transpose_mcast ? 1 : 0), }; // define for bias @@ -850,39 +942,39 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met } writer_compile_time_args = { - (uint32_t) (dst_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), + (uint32_t)(dst_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), out0_cb, weight_cb, bias_cb, - (uint32_t) (bias_buffer == nullptr ? 0 : (bias_buffer->buffer_type() == BufferType::DRAM ? 1 : 0)), - num_blocks_act_w, // = number of blocks of weight in height dim + (uint32_t)(bias_buffer == nullptr ? 0 : (bias_buffer->buffer_type() == BufferType::DRAM ? 
1 : 0)),
+        num_blocks_act_w,  // = number of blocks of weight in height dim
         in1_block_num_tiles,
         conv_act_c_blocks,
         weight_block_h_ntiles / conv_act_c_blocks,
         weight_block_w_ntiles,
-        weight_matrix_width_ntiles, // weight_stride_h
-        weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h,
-        weight_block_w_ntiles, // weight_next_block_stride_w
+        weight_matrix_width_ntiles,                          // weight_stride_h
+        weight_matrix_width_ntiles * weight_block_h_ntiles,  // weight_next_block_stride_h,
+        weight_block_w_ntiles,                               // weight_next_block_stride_w

         // bias
         bias_ntiles_per_core,

-        output_width_num_tiles, // out_next_tile_stride_h
-        1, // out_next_tile_stride_w
-        out_subblock_h_ntiles * output_width_num_tiles, // out_next_subblock_stride_h
-        out_subblock_w_ntiles, // out_next_subblock_stride_w
-        act_block_h_ntiles * output_width_num_tiles, // out_next_block_stride_h
-        //weight_block_w_ntiles, // out_next_block_stride_w
+        output_width_num_tiles,  // out_next_tile_stride_h
+        1,                       // out_next_tile_stride_w
+        out_subblock_h_ntiles * output_width_num_tiles,  // out_next_subblock_stride_h
+        out_subblock_w_ntiles,                           // out_next_subblock_stride_w
+        act_block_h_ntiles * output_width_num_tiles,     // out_next_block_stride_h
+        // weight_block_w_ntiles, // out_next_block_stride_w
         out_subblock_h_ntiles,
         out_subblock_w_ntiles,
         out_subblock_num_tiles,
-        act_block_h_ntiles / out_subblock_h_ntiles, // out_num_subblocks_h
-        weight_block_w_ntiles / out_subblock_w_ntiles, // out_num_subblocks_w
-        num_blocks_act_h_per_core, // out_num_blocks_h
-        num_blocks_weight_w_per_core, // out_num_blocks_w
-        act_block_h_ntiles, // out_block_height_num_tiles
-        output_height_num_tiles, // out_height_num_tiles without block shape padding
-        output_width_num_tiles, // out_width_num_tiles withoug block shape padding
+        act_block_h_ntiles / out_subblock_h_ntiles,     // out_num_subblocks_h
+        weight_block_w_ntiles / out_subblock_w_ntiles,  // out_num_subblocks_w
+        num_blocks_act_h_per_core,                      // out_num_blocks_h
+        num_blocks_weight_w_per_core,                   // out_num_blocks_w
+        act_block_h_ntiles,                             // out_block_height_num_tiles
+        output_height_num_tiles,                        // out_height_num_tiles without block shape padding
+        output_width_num_tiles,                         // out_width_num_tiles without block shape padding

         out_dram_addr,
         weight_dram_addr,
@@ -890,14 +982,15 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
     };
     if (split_reader) {
         std::vector<uint32_t> split_reader_args = {
-            (uint32_t) act_block_h_datums,
-            (uint32_t) act_block_num_tiles / conv_act_c_blocks,
-            (uint32_t) conv_act_c_read_bytes,
-            (uint32_t) weight_size_w * conv_act_c_read_bytes, // coalesced_read_bytes
-            (uint32_t) (conv_act_size_w + 2 * pad_w) * conv_act_c_read_bytes, // window_outer_offset
-            (uint32_t) act_block_w_extra_align_bytes, // only used for 1d systolic variant
+            (uint32_t)act_block_h_datums,
+            (uint32_t)act_block_num_tiles / conv_act_c_blocks,
+            (uint32_t)conv_act_c_read_bytes,
+            (uint32_t)weight_size_w * conv_act_c_read_bytes,                  // coalesced_read_bytes
+            (uint32_t)(conv_act_size_w + 2 * pad_w) * conv_act_c_read_bytes,  // window_outer_offset
+            (uint32_t)act_block_w_extra_align_bytes,  // only used for 1d systolic variant
         };
-        writer_compile_time_args.insert(writer_compile_time_args.end(), split_reader_args.begin(), split_reader_args.end());
+        writer_compile_time_args.insert(
+            writer_compile_time_args.end(), split_reader_args.begin(), split_reader_args.end());
     }

     vector<uint32_t> compute_kernel_args = {
@@ -922,8 +1015,7 @@ operation::ProgramWithCallbacks 
multi_core_optimized_conv_sharded_v2_impl(tt_met tilize_in0, untilize_out, - bias_ntiles_per_core - }; + bias_ntiles_per_core}; auto writer_mcast_noc = NOC::NOC_0; auto reader_noc = writer_mcast_noc == NOC::NOC_0 ? NOC::NOC_1 : NOC::NOC_0; @@ -937,7 +1029,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met .compile_args = writer_compile_time_args, .defines = writer_mcast_sender_defines}); - KernelHandle writer_mcast_receiver_id; + KernelHandle writer_mcast_receiver_id = -1; if (total_num_cores > 1) { writer_mcast_receiver_id = CreateKernel( program, @@ -953,7 +1045,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met auto reader_id = CreateKernel( program, reader_kernel, - all_cores, + all_active_cores, DataMovementConfig{ .processor = DataMovementProcessor::RISCV_1, .noc = reader_noc, @@ -972,26 +1064,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met .compile_args = compute_kernel_args, .defines = compute_defines}); - if (total_noop_cores > 0) { - auto compute_id = CreateKernel( - program, - "tt_metal/kernels/compute/blank.cpp", - noop_cores, ComputeConfig{}); - } - - vector reader_ids; - vector writer_ids; - for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) { + for (uint32_t core_i = 0; core_i < total_active_num_cores; core_i++) { uint32_t core_x_i = core_i % num_cores_x; uint32_t core_y_i = core_i / num_cores_x; CoreRange core(CoreCoord(core_x_i, core_y_i), CoreCoord(core_x_i, core_y_i)); bool noop_core = false; - for (const auto & noop_core_range : noop_cores.ranges()) { - if (noop_core_range.contains(core)) { - noop_core = true; - break; - } - } + // per core specific args uint32_t act_slice_i; uint32_t weight_slice_i; @@ -1002,7 +1080,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met act_slice_i = core_i / total_num_cores_per_act_slice; weight_slice_i = core_i % total_num_cores_per_act_slice; } - uint32_t out_start_tile_id = (act_slice_i * per_core_out_matrix_height_ntiles * weight_matrix_width_ntiles) + (weight_slice_i * per_core_out_matrix_width_ntiles); + uint32_t out_start_tile_id = (act_slice_i * per_core_out_matrix_height_ntiles * weight_matrix_width_ntiles) + + (weight_slice_i * per_core_out_matrix_width_ntiles); uint32_t out_start_tile_id_h = act_slice_i * per_core_out_matrix_height_ntiles; uint32_t out_start_tile_id_w = weight_slice_i * per_core_out_matrix_width_ntiles; uint32_t bias_tile_offset = weight_slice_i * per_core_out_matrix_width_ntiles; @@ -1019,158 +1098,148 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met TT_ASSERT(!reader_is_noc_0); if (transpose_mcast) { - CoreCoord bottom_core = {(std::size_t) core_x_i, (std::size_t) num_cores_y - 1}; + CoreCoord bottom_core = {(std::size_t)core_x_i, (std::size_t)num_cores_y - 1}; auto bottom_core_physical = device->worker_core_from_logical_core(bottom_core); uint32_t act_mcast_dest_noc_start_x = bottom_core_physical.x; - uint32_t act_mcast_dest_noc_start_y = reader_is_noc_0 ? top_left_core_physical.y : bottom_core_physical.y; + uint32_t act_mcast_dest_noc_start_y = + reader_is_noc_0 ? top_left_core_physical.y : bottom_core_physical.y; uint32_t act_mcast_dest_noc_end_x = bottom_core_physical.x; uint32_t act_mcast_dest_noc_end_y = reader_is_noc_0 ? 
bottom_core_physical.y : top_left_core_physical.y; reader_rt_args = { - (uint32_t) noop_core, + (uint32_t)noop_core, // mcast args act_mcast_dest_noc_start_x, act_mcast_dest_noc_start_y, act_mcast_dest_noc_end_x, act_mcast_dest_noc_end_y, - core_y_i, // act_mcast_sender_id (goes down the column) - (uint32_t) bottom_core_physical.x, // act_mcast_sender_noc_x + core_y_i, // act_mcast_sender_id (goes down the column) + (uint32_t)bottom_core_physical.x, // act_mcast_sender_noc_x }; - reader_rt_args.insert(reader_rt_args.end(), act_mcast_noc_y.begin(), act_mcast_noc_y.end()); // act_mcast_sender_noc_y + reader_rt_args.insert( + reader_rt_args.end(), act_mcast_noc_y.begin(), act_mcast_noc_y.end()); // act_mcast_sender_noc_y } else { - CoreCoord core = { core_x_i, core_y_i }; + CoreCoord core = {core_x_i, core_y_i}; auto core_physical = device->worker_core_from_logical_core(core); - CoreCoord bottom_right_core = {(std::size_t) num_cores_x - 1, (std::size_t) num_cores_y - 1}; + CoreCoord bottom_right_core = {(std::size_t)num_cores_x - 1, (std::size_t)num_cores_y - 1}; auto bottom_right_core_physical = device->worker_core_from_logical_core(bottom_right_core); - uint32_t act_mcast_dest_noc_start_x = reader_is_noc_0 ? top_left_core_physical.x : bottom_right_core_physical.x; + uint32_t act_mcast_dest_noc_start_x = + reader_is_noc_0 ? top_left_core_physical.x : bottom_right_core_physical.x; uint32_t act_mcast_dest_noc_start_y = core_physical.y; - uint32_t act_mcast_dest_noc_end_x = reader_is_noc_0 ? bottom_right_core_physical.x : top_left_core_physical.x; + uint32_t act_mcast_dest_noc_end_x = + reader_is_noc_0 ? bottom_right_core_physical.x : top_left_core_physical.x; uint32_t act_mcast_dest_noc_end_y = core_physical.y; reader_rt_args = { - (uint32_t) noop_core, + (uint32_t)noop_core, // mcast args act_mcast_dest_noc_start_x, act_mcast_dest_noc_start_y, act_mcast_dest_noc_end_x, act_mcast_dest_noc_end_y, - core_x_i, // act_mcast_sender_id (goes along the row) - (uint32_t) core_physical.y, // act_mcast_sender_noc_x + core_x_i, // act_mcast_sender_id (goes along the row) + (uint32_t)core_physical.y, // act_mcast_sender_noc_x }; - reader_rt_args.insert(reader_rt_args.end(), act_mcast_noc_y.begin(), act_mcast_noc_y.end()); // act_mcast_sender_noc_y + reader_rt_args.insert( + reader_rt_args.end(), act_mcast_noc_y.begin(), act_mcast_noc_y.end()); // act_mcast_sender_noc_y } } else { - reader_rt_args = { - (uint32_t) noop_core - }; + reader_rt_args = {(uint32_t)noop_core}; } // log_debug("Core: {}, READER RT ARGS: {}", core, reader_rt_args.size()); - SetRuntimeArgs( - program, reader_id, core, - reader_rt_args - ); - reader_ids.push_back(reader_id); + SetRuntimeArgs(program, reader_id, core, reader_rt_args); writer_rt_args = { out_dram_addr, weight_dram_addr, bias_dram_addr, - output_width_num_tiles, // out_next_tile_stride_h - 1, // out_next_tile_stride_w - out_subblock_h_ntiles * output_width_num_tiles, // out_next_subblock_stride_h - out_subblock_w_ntiles, // out_next_subblock_stride_w - act_block_h_ntiles * output_width_num_tiles, // out_next_block_stride_h - weight_block_w_ntiles, // out_next_block_stride_w + output_width_num_tiles, // out_next_tile_stride_h + 1, // out_next_tile_stride_w + out_subblock_h_ntiles * output_width_num_tiles, // out_next_subblock_stride_h + out_subblock_w_ntiles, // out_next_subblock_stride_w + act_block_h_ntiles * output_width_num_tiles, // out_next_block_stride_h + weight_block_w_ntiles, // out_next_block_stride_w out_subblock_h_ntiles, out_subblock_w_ntiles, 
out_subblock_num_tiles, - act_block_h_ntiles / out_subblock_h_ntiles, // out_num_subblocks_h - weight_block_w_ntiles / out_subblock_w_ntiles, // out_num_subblocks_w - num_blocks_act_h_per_core, // out_num_blocks_h - num_blocks_weight_w_per_core, // out_num_blocks_w - act_block_h_ntiles, // out_block_height_num_tiles - output_height_num_tiles, // out_height_num_tiles without block shape padding - output_width_num_tiles, // out_width_num_tiles withoug block shape padding + act_block_h_ntiles / out_subblock_h_ntiles, // out_num_subblocks_h + weight_block_w_ntiles / out_subblock_w_ntiles, // out_num_subblocks_w + num_blocks_act_h_per_core, // out_num_blocks_h + num_blocks_weight_w_per_core, // out_num_blocks_w + act_block_h_ntiles, // out_block_height_num_tiles + output_height_num_tiles, // out_height_num_tiles without block shape padding + output_width_num_tiles, // out_width_num_tiles withoug block shape padding out_start_tile_id, out_start_tile_id_h, out_start_tile_id_w, - num_blocks_act_w, // = number of blocks of weight in height dim + num_blocks_act_w, // = number of blocks of weight in height dim in1_block_num_tiles, conv_act_c_blocks, weight_block_h_ntiles / conv_act_c_blocks, weight_block_w_ntiles, - weight_matrix_width_ntiles, // weight_stride_h - weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, - weight_block_w_ntiles, // weight_next_block_stride_w + weight_matrix_width_ntiles, // weight_stride_h + weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, + weight_block_w_ntiles, // weight_next_block_stride_w // bias bias_ntiles_per_core, bias_tile_offset, - (uint32_t) noop_core - }; + (uint32_t)noop_core}; // Mcast sender if (weight_width_sliced) { // 2D mcast if (transpose_mcast) { - CoreCoord right_core = {(std::size_t) num_cores_x - 1, (std::size_t) core_y_i}; + CoreCoord right_core = {(std::size_t)num_cores_x - 1, (std::size_t)core_y_i}; auto right_core_physical = device->worker_core_from_logical_core(right_core); if (core_x_i == 0) { // sender if (writer_mcast_noc == NOC::NOC_0) { - writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y + writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_start_x + writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y + writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x + writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y } else { - writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_start_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y - writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_end_x - writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y + writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_start_x + writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_start_y + writer_rt_args.push_back(top_left_core_plus_one_physical.x); // weights_mcast_dest_noc_end_x + writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y } - 
-                    writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_dests
-                    writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_cores
+                    writer_rt_args.push_back(num_cores_x - 1);  // weights_mcast_num_dests
+                    writer_rt_args.push_back(num_cores_x - 1);  // weights_mcast_num_cores
                     writer_rt_args.push_back(weights_mcast_sender_semaphore);
                     writer_rt_args.push_back(weights_mcast_receiver_semaphore);
-                    SetRuntimeArgs(
-                        program, writer_mcast_sender_id, core,
-                        writer_rt_args
-                    );
-                    writer_ids.push_back(writer_mcast_sender_id);
+                    SetRuntimeArgs(program, writer_mcast_sender_id, core, writer_rt_args);
                 } else {
                     // receiver
-                    writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_sender_noc_x
-                    writer_rt_args.push_back(right_core_physical.y); // weights_mcast_sender_noc_y
+                    writer_rt_args.push_back(top_left_core_physical.x);  // weights_mcast_sender_noc_x
+                    writer_rt_args.push_back(right_core_physical.y);     // weights_mcast_sender_noc_y
                     writer_rt_args.push_back(weights_mcast_sender_semaphore);
                     writer_rt_args.push_back(weights_mcast_receiver_semaphore);
-                    SetRuntimeArgs(
-                        program, writer_mcast_receiver_id, core,
-                        writer_rt_args
-                    );
-                    writer_ids.push_back(writer_mcast_receiver_id);
+                    SetRuntimeArgs(program, writer_mcast_receiver_id, core, writer_rt_args);
                 }
             } else {
-                CoreCoord top_core = {(std::size_t) core_x_i, 0};
+                CoreCoord top_core = {(std::size_t)core_x_i, 0};
                 auto top_core_physical = device->worker_core_from_logical_core(top_core);
                 TT_ASSERT(writer_mcast_noc == NOC::NOC_0);
                 if (core_y_i == 0) {
                     // sender
                     if (writer_mcast_noc == NOC::NOC_0) {
-                        writer_rt_args.push_back(top_core_physical.x); // weights_mcast_dest_noc_start_x
-                        writer_rt_args.push_back(top_left_core_plus_one_physical.y); // weights_mcast_dest_noc_start_y
-                        writer_rt_args.push_back(top_core_physical.x); // weights_mcast_dest_noc_end_x
-                        writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_end_y
+                        writer_rt_args.push_back(top_core_physical.x);                // weights_mcast_dest_noc_start_x
+                        writer_rt_args.push_back(top_left_core_plus_one_physical.y);  // weights_mcast_dest_noc_start_y
+                        writer_rt_args.push_back(top_core_physical.x);                // weights_mcast_dest_noc_end_x
+                        writer_rt_args.push_back(bottom_right_core_physical.y);       // weights_mcast_dest_noc_end_y
                     } else {
                         // TODO: ...
                         TT_ASSERT(false, "TODO: Writer on NOC 1 not supported yet!");
@@ -1180,28 +1249,20 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
                         // writer_rt_args.push_back(right_core_physical.y); // weights_mcast_dest_noc_end_y
                     }
-                    writer_rt_args.push_back(num_cores_y - 1); // weights_mcast_num_dests
-                    writer_rt_args.push_back(num_cores_y - 1); // weights_mcast_num_cores
+                    writer_rt_args.push_back(num_cores_y - 1);  // weights_mcast_num_dests
+                    writer_rt_args.push_back(num_cores_y - 1);  // weights_mcast_num_cores
                     writer_rt_args.push_back(weights_mcast_sender_semaphore);
                     writer_rt_args.push_back(weights_mcast_receiver_semaphore);
-                    SetRuntimeArgs(
-                        program, writer_mcast_sender_id, core,
-                        writer_rt_args
-                    );
-                    writer_ids.push_back(writer_mcast_sender_id);
+                    SetRuntimeArgs(program, writer_mcast_sender_id, core, writer_rt_args);
                 } else {
                     // receiver
-                    writer_rt_args.push_back(top_core_physical.x); // weights_mcast_sender_noc_x
-                    writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_sender_noc_y
+                    writer_rt_args.push_back(top_core_physical.x);       // weights_mcast_sender_noc_x
+                    writer_rt_args.push_back(top_left_core_physical.y);  // weights_mcast_sender_noc_y
                     writer_rt_args.push_back(weights_mcast_sender_semaphore);
                     writer_rt_args.push_back(weights_mcast_receiver_semaphore);
-                    SetRuntimeArgs(
-                        program, writer_mcast_receiver_id, core,
-                        writer_rt_args
-                    );
-                    writer_ids.push_back(writer_mcast_receiver_id);
+                    SetRuntimeArgs(program, writer_mcast_receiver_id, core, writer_rt_args);
                 }
             }
         } else {
@@ -1209,117 +1270,167 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(tt_met
             if (core_x_i == 0 and core_y_i == 0) {
                 // sender
                 if (writer_mcast_noc == NOC::NOC_0) {
-                    writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_dest_noc_start_x
-                    writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_dest_noc_start_y
-                    writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_end_x
-                    writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_end_y
+                    writer_rt_args.push_back(top_left_core_physical.x);      // weights_mcast_dest_noc_start_x
+                    writer_rt_args.push_back(top_left_core_physical.y);      // weights_mcast_dest_noc_start_y
+                    writer_rt_args.push_back(bottom_right_core_physical.x);  // weights_mcast_dest_noc_end_x
+                    writer_rt_args.push_back(bottom_right_core_physical.y);  // weights_mcast_dest_noc_end_y
                 } else {
-                    writer_rt_args.push_back(bottom_right_core_physical.x); // weights_mcast_dest_noc_start_x
-                    writer_rt_args.push_back(bottom_right_core_physical.y); // weights_mcast_dest_noc_start_y
-                    writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_dest_noc_end_x
-                    writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_dest_noc_end_y
+                    writer_rt_args.push_back(bottom_right_core_physical.x);  // weights_mcast_dest_noc_start_x
+                    writer_rt_args.push_back(bottom_right_core_physical.y);  // weights_mcast_dest_noc_start_y
+                    writer_rt_args.push_back(top_left_core_physical.x);      // weights_mcast_dest_noc_end_x
+                    writer_rt_args.push_back(top_left_core_physical.y);      // weights_mcast_dest_noc_end_y
                 }
-                writer_rt_args.push_back(total_active_num_cores - 1); // weights_mcast_num_dests
-                writer_rt_args.push_back(total_num_cores - 1); // weights_mcast_num_cores
+                writer_rt_args.push_back(total_active_num_cores - 1);  // weights_mcast_num_dests
+                writer_rt_args.push_back(total_num_cores - 1);         // weights_mcast_num_cores
                 writer_rt_args.push_back(weights_mcast_sender_semaphore);
                 writer_rt_args.push_back(weights_mcast_receiver_semaphore);
-                SetRuntimeArgs(
-                    program, writer_mcast_sender_id, core,
-                    writer_rt_args
-                );
-                writer_ids.push_back(writer_mcast_sender_id);
+                SetRuntimeArgs(program, writer_mcast_sender_id, core, writer_rt_args);
             } else {
                 // receiver
-                writer_rt_args.push_back(top_left_core_physical.x); // weights_mcast_sender_noc_x
-                writer_rt_args.push_back(top_left_core_physical.y); // weights_mcast_sender_noc_y
+                writer_rt_args.push_back(top_left_core_physical.x);  // weights_mcast_sender_noc_x
+                writer_rt_args.push_back(top_left_core_physical.y);  // weights_mcast_sender_noc_y
                 writer_rt_args.push_back(weights_mcast_sender_semaphore);
                 writer_rt_args.push_back(weights_mcast_receiver_semaphore);
                 // log_debug("Core: {}, WRITER RT ARGS: {}", core, writer_rt_args.size());
-                SetRuntimeArgs(
-                    program, writer_mcast_receiver_id, core,
-                    writer_rt_args
-                );
-                writer_ids.push_back(writer_mcast_receiver_id);
+                SetRuntimeArgs(program, writer_mcast_receiver_id, core, writer_rt_args);
            }
        }
-    } // for num_cores
-
-    auto override_runtime_arguments_callback = [
-        reader_kernel_ids=reader_ids,
-        writer_kernel_ids=writer_ids,
-        cb_sharded_act=cb_sharded_act,
-        cb_output=cb_output,
-        total_num_cores=total_num_cores,
-        num_cores_x=num_cores_x,
-        num_cores_y=num_cores_y,
-        has_bias=has_bias
-    ]
-    (
-        const void* operation,
-        Program& program,
-        const std::vector<Tensor>& input_tensors,
-        const std::vector<std::optional<const Tensor>>& optional_input_tensors,
-        const std::vector<Tensor>& output_tensors
-    ) {
-
-        // Reader config indices is an optional static sharded tensor, so no need to update address
-        TT_ASSERT(input_tensors.size() + optional_input_tensors.size() == 4);
-        TT_ASSERT(output_tensors.size() == 1);
-
-        auto src_buffer_a = input_tensors.at(0).buffer();
-        auto src_buffer_b = input_tensors.at(1).buffer();
-        auto src_a_is_sharded = input_tensors.at(0).memory_config().is_sharded();
-
-        auto dst_buffer = output_tensors.at(0).buffer();
-        bool out_sharded = output_tensors.at(0).memory_config().is_sharded();
-
-        for(uint32_t core_i = 0; core_i < total_num_cores; core_i++) {
-            uint32_t core_x_i = core_i % num_cores_x;
-            uint32_t core_y_i = core_i / num_cores_x;
-            CoreCoord core = {core_x_i, core_y_i};
-
-            if (!src_a_is_sharded) {
-                auto &runtime_args = GetRuntimeArgs(program, reader_kernel_ids[core_i], core);
-                runtime_args[0] = src_buffer_a->address();
+    }  // for num_cores
+
+    auto mcast_sender_cores_vec = grid_to_cores(mcast_sender_cores.start, mcast_sender_cores.end, true);
+    auto mcast_receiver_cores_vec = corerange_to_cores(mcast_receiver_cores, std::nullopt, true);
+    auto override_runtime_arguments_callback =
+        [reader_kernel_id = reader_id,
+         mcast_sender_cores = mcast_sender_cores_vec,
+         writer_mcast_sender_id = writer_mcast_sender_id,
+         mcast_receiver_cores = mcast_receiver_cores_vec,
+         writer_mcast_receiver_id = writer_mcast_receiver_id,
+         cb_sharded_act = cb_sharded_act,
+         cb_output = cb_output,
+         total_active_num_cores = total_active_num_cores,
+         num_cores_x = num_cores_x,
+         num_cores_y = num_cores_y,
+         has_bias = has_bias](
+            const void* operation,
+            Program& program,
+            const std::vector<Tensor>& input_tensors,
+            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+            const std::vector<Tensor>& output_tensors) {
+            // Reader config indices is an optional static sharded tensor, so no need to update address
+            TT_ASSERT(input_tensors.size() + optional_input_tensors.size() == 4);
+            TT_ASSERT(output_tensors.size() == 1);
+
+            auto src_buffer_a = input_tensors.at(0).buffer();
+            auto src_buffer_b = input_tensors.at(1).buffer();
+            bool src_a_is_sharded = input_tensors[0].is_sharded();
+
+            std::optional<Buffer*> src_buffer_c = std::nullopt;
+            if (has_bias) {
+                src_buffer_c = optional_input_tensors.at(0).value().buffer();
+                TT_ASSERT(src_buffer_c.value() != nullptr);
             }
-            {
-                auto &runtime_args = GetRuntimeArgs(program, writer_kernel_ids[core_i], core);
+            auto dst_buffer = output_tensors.at(0).buffer();
+            bool out_sharded = output_tensors[0].is_sharded();
+
+            auto& reader_kernel_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
+
+            auto& writer_sender_kernel_args_by_core = GetRuntimeArgs(program, writer_mcast_sender_id);
+            for (const auto& core : mcast_sender_cores) {
+                if (!src_a_is_sharded) {
+                    auto& runtime_args = reader_kernel_args_by_core[core.x][core.y];
+                    runtime_args[0] = src_buffer_a->address();
+                }
+                auto& runtime_args = writer_sender_kernel_args_by_core[core.x][core.y];
                 runtime_args[0] = dst_buffer->address();
                 runtime_args[1] = src_buffer_b->address();
                 if (has_bias) {
-                    auto src_buffer_c = optional_input_tensors.at(0).value().buffer();
-                    TT_ASSERT(src_buffer_c != nullptr);
-                    runtime_args[2] = src_buffer_c->address();
+                    runtime_args[2] = (*src_buffer_c)->address();
                 }
             }
-        }
-        if (src_a_is_sharded) {
-            UpdateDynamicCircularBufferAddress(program, cb_sharded_act, *src_buffer_a);
-        }
+            if (mcast_receiver_cores.size() > 0) {
+                auto& writer_receiver_kernel_args_by_core = GetRuntimeArgs(program, writer_mcast_receiver_id);
+                for (const auto& core : mcast_receiver_cores) {
+                    if (!src_a_is_sharded) {
+                        auto& runtime_args = reader_kernel_args_by_core[core.x][core.y];
+                        runtime_args[0] = src_buffer_a->address();
+                    }
+                    auto& runtime_args = writer_receiver_kernel_args_by_core[core.x][core.y];
+                    runtime_args[0] = dst_buffer->address();
+                    runtime_args[1] = src_buffer_b->address();
+                    if (has_bias) {
+                        runtime_args[2] = (*src_buffer_c)->address();
+                    }
+                }
+            }
-        if (out_sharded) {
-            UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
-        }
-    };
-    return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback};
+            if (src_a_is_sharded) {
+                UpdateDynamicCircularBufferAddress(program, cb_sharded_act, *src_buffer_a);
+            }
+
+            if (out_sharded) {
+                UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
+            }
+        };
+    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
 }

-operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_(const Tensor& a, const Tensor &b, const Shape& ashape, std::optional<const Tensor> bias, const std::optional<const Tensor> conv_reader_indices, vector<int> conv_params, uint32_t output_channels, bool untilize_out, bool has_bias, bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment, bool use_shallow_conv_variant, bool transpose_mcast, Tensor &output, DeviceComputeKernelConfig compute_kernel_config) {
+operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_(
+    const Tensor& a,
+    const Tensor& b,
+    const Shape& ashape,
+    std::optional<const Tensor> bias,
+    const std::optional<const Tensor> conv_reader_indices,
+    vector<int> conv_params,
+    uint32_t output_channels,
+    bool untilize_out,
+    bool has_bias,
+    bool fuse_relu,
+    const OptimizedConvParallelizationConfig& parallelization_config,
+    const OptimizedConvBlockConfig& block_config,
+    uint32_t extra_padding_for_32B_alignment,
+    bool use_shallow_conv_variant,
+    bool transpose_mcast,
+    Tensor& output,
+    DeviceComputeKernelConfig compute_kernel_config) {
     tt_metal::Program program = tt_metal::CreateProgram();
-    return multi_core_optimized_conv_sharded_v2_impl(program, a, b, ashape, bias, conv_reader_indices, conv_params, output_channels, untilize_out, has_bias, fuse_relu, parallelization_config, block_config, extra_padding_for_32B_alignment, use_shallow_conv_variant, transpose_mcast, output, compute_kernel_config);
+    return multi_core_optimized_conv_sharded_v2_impl(
+        program,
+        a,
+        b,
+        ashape,
+        bias,
+        conv_reader_indices,
+        conv_params,
+        output_channels,
+        untilize_out,
+        has_bias,
+        fuse_relu,
+        parallelization_config,
+        block_config,
+        extra_padding_for_32B_alignment,
+        use_shallow_conv_variant,
+        transpose_mcast,
+        output,
+        compute_kernel_config);
 }

-operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const Tensor& a, const Tensor &b, std::optional<const Tensor> bias,
+operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(
+    const Tensor& a,
+    const Tensor& b,
+    std::optional<const Tensor> bias,
     vector<int> conv_params,
     uint32_t output_channels,
-    bool untilize_out, bool fuse_relu, MathFidelity math_fidelity,
+    bool untilize_out,
+    bool fuse_relu,
+    MathFidelity math_fidelity,
     const OptimizedConvParallelizationConfig& parallelization_config,
-    const OptimizedConvBlockConfig& block_config, uint32_t extra_padding_for_32B_alignment,
+    const OptimizedConvBlockConfig& block_config,
+    uint32_t extra_padding_for_32B_alignment,
     DataType output_dtype,
     std::array<uint32_t, 4> input_tensor_shape,
     bool use_shallow_conv_variant,
@@ -1332,20 +1443,34 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T
     parallel_config.shard_scheme = a.memory_config().memory_layout;
     parallel_config.shard_orientation = a.shard_spec().value().orientation;
     // TODO: pass sliding window config to the function instead of conv params
-    uint32_t weight_size_h = (uint32_t) conv_params[0]; // filter_h
-    uint32_t weight_size_w = (uint32_t) conv_params[1]; // filter_W
-    uint32_t stride_h = (uint32_t) conv_params[2];
-    uint32_t stride_w = (uint32_t) conv_params[3];
-    uint32_t pad_h = (uint32_t) conv_params[4];
-    uint32_t pad_w = (uint32_t) conv_params[5];
-    SlidingWindowConfig sliding_window_config = SlidingWindowConfig(input_tensor_shape[0], input_tensor_shape[1], input_tensor_shape[2], weight_size_h, weight_size_w, stride_h, stride_w,
-        pad_h, pad_w, 1, 1, parallelization_config.num_cores_nhw, parallel_config.grid, true);
+    uint32_t weight_size_h = (uint32_t)conv_params[0];  // filter_h
+    uint32_t weight_size_w = (uint32_t)conv_params[1];  // filter_w
+    uint32_t stride_h = (uint32_t)conv_params[2];
+    uint32_t stride_w = (uint32_t)conv_params[3];
+    uint32_t pad_h = (uint32_t)conv_params[4];
+    uint32_t pad_w = (uint32_t)conv_params[5];
+    SlidingWindowConfig sliding_window_config = SlidingWindowConfig(
+        input_tensor_shape[0],
+        input_tensor_shape[1],
+        input_tensor_shape[2],
+        weight_size_h,
+        weight_size_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        1,
+        1,
+        parallelization_config.num_cores_nhw,
+        parallel_config.grid,
+        true);
     // create conv config tensors
     auto pad_metadata = sliding_window::generate_pad_metadata(sliding_window_config);
     auto op_trace_metadata = sliding_window::generate_op_trace_metadata(sliding_window_config);
     auto shard_boundaries = sliding_window::generate_shard_boundaries(sliding_window_config, op_trace_metadata);
-    auto conv_sharded_input_top_left_indices = sliding_window::generate_sliding_window_op_config(op_trace_metadata, shard_boundaries, true, true);
+    auto conv_sharded_input_top_left_indices =
+        sliding_window::generate_sliding_window_op_config(op_trace_metadata, shard_boundaries, true, true);
     // create sharded ttnn config tensors
     DataType indices_tt_dtype = DataType::UINT16;
     // For 2d convs, each core in a column or row share the same specs
@@ -1367,12 +1492,32 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T
     // }
     // }
     bool is_block_sharded = a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED;
-    auto conv_reader_indices_tensor = sliding_window::construct_on_host_config_tensor(conv_sharded_input_top_left_indices, sliding_window_config, parallel_config);
-    conv_reader_indices_tensor = sliding_window::move_config_tensor_to_device(conv_reader_indices_tensor, parallel_config, is_block_sharded, a.device());
+    auto conv_reader_indices_tensor = sliding_window::construct_on_host_config_tensor(
+        conv_sharded_input_top_left_indices, sliding_window_config, parallel_config);
+    conv_reader_indices_tensor = sliding_window::move_config_tensor_to_device(
+        conv_reader_indices_tensor, parallel_config, is_block_sharded, a.device());
     // add config tensor to program
     tt::tt_metal::detail::AddConfigBuffer(program, conv_reader_indices_tensor.device_buffer());
-    return multi_core_optimized_conv_sharded_v2_impl(program, a, b, Shape(input_tensor_shape), bias, conv_reader_indices_tensor, conv_params, output_channels, untilize_out, bias.has_value(), fuse_relu, parallelization_config, block_config, extra_padding_for_32B_alignment, use_shallow_conv_variant, parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, output, compute_kernel_config.value());
+    return multi_core_optimized_conv_sharded_v2_impl(
+        program,
+        a,
+        b,
+        Shape(input_tensor_shape),
+        bias,
+        conv_reader_indices_tensor,
+        conv_params,
+        output_channels,
+        untilize_out,
+        bias.has_value(),
+        fuse_relu,
+        parallelization_config,
+        block_config,
+        extra_padding_for_32B_alignment,
+        use_shallow_conv_variant,
+        parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
+        output,
+        compute_kernel_config.value());
 }

 }  // namespace tt_metal
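A note on the weights_mcast_dest_noc_* runtime args set in the hunks above: they describe the multicast destination rectangle in physical NOC coordinates, and the start/end corners are swapped between NOC::NOC_0 and NOC::NOC_1 because the two NOCs traverse the grid in opposite directions. Below is a minimal standalone sketch of that corner selection; the Rect type and mcast_rect helper are illustrative only, not tt_metal API.

    #include <cstdint>

    struct Rect {
        uint32_t start_x, start_y, end_x, end_y;
    };

    // NOC_0 multicasts from the top-left corner toward the bottom-right;
    // NOC_1 runs the other way, so the two corners are exchanged. This
    // mirrors the writer_rt_args.push_back ordering in the diff above.
    Rect mcast_rect(bool is_noc_0, uint32_t tl_x, uint32_t tl_y, uint32_t br_x, uint32_t br_y) {
        return is_noc_0 ? Rect{tl_x, tl_y, br_x, br_y} : Rect{br_x, br_y, tl_x, tl_y};
    }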
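The override_runtime_arguments_callback rewrite is the core of this refactor: instead of one GetRuntimeArgs(program, kernel, core) lookup per core over the full grid, it fetches each kernel's argument table once and walks only the mcast sender/receiver core lists, so no-op cores are never visited on a cache hit. A simplified, self-contained model of that pattern follows; the types and function name are stand-ins, not the tt_metal API.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct CoreCoord {
        std::size_t x, y;
    };

    // Per-kernel table of runtime args indexed [x][y], analogous to what
    // GetRuntimeArgs(program, kernel_id) returns in the diff above.
    using RuntimeArgsByCore = std::vector<std::vector<std::vector<uint32_t>>>;

    // On a cache hit only the buffer addresses change, so patch the args in
    // place for the cores that actually run the writer kernel. No-op cores
    // never appear in `sender_cores`, so they are skipped entirely.
    void update_writer_args(
        RuntimeArgsByCore& writer_args_by_core,
        const std::vector<CoreCoord>& sender_cores,
        uint32_t out_addr,      // runtime arg 0: output buffer address
        uint32_t weight_addr) { // runtime arg 1: weight buffer address
        for (const auto& core : sender_cores) {
            auto& args = writer_args_by_core[core.x][core.y];
            args[0] = out_addr;
            args[1] = weight_addr;
        }
    }

The design point is that the per-kernel table is fetched once per callback invocation rather than once per core, which is what makes skipping the no-op cores cheap.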
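For reference, the flat conv_params vector consumed above is laid out as {filter_h, filter_w, stride_h, stride_w, pad_h, pad_w}, exactly the indices the diff casts out one by one. A small sketch of that unpacking; the ConvParams struct and helper are illustrative only.

    #include <cstdint>
    #include <vector>

    struct ConvParams {
        uint32_t filter_h, filter_w, stride_h, stride_w, pad_h, pad_w;
    };

    // Mirrors the conv_params[0..5] indexing used when building the
    // SlidingWindowConfig in the hunk above.
    ConvParams unpack_conv_params(const std::vector<int>& conv_params) {
        return ConvParams{
            (uint32_t)conv_params[0],  // filter_h
            (uint32_t)conv_params[1],  // filter_w
            (uint32_t)conv_params[2],  // stride_h
            (uint32_t)conv_params[3],  // stride_w
            (uint32_t)conv_params[4],  // pad_h
            (uint32_t)conv_params[5],  // pad_w
        };
    }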