From 5280c3a09ffcba8dd2e339ab479935ddd9ded27d Mon Sep 17 00:00:00 2001
From: Borys Bradel
Date: Fri, 14 Jun 2024 16:46:14 +0000
Subject: [PATCH] #9059: Get matmul per core factor based on L1 usage

---
 tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp | 48 ++++++++++++++++++++---
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
index 005eb014aabd..b61232c87cce 100644
--- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
@@ -14,6 +14,7 @@
 #include "tt_dnn/op_library/work_split.hpp"
 #include "tt_metal/common/constants.hpp"
 #include "tt_metal/host_api.hpp"
+#include "tt_metal/hostdevcommon/common_values.hpp"
 #include "tt_metal/tools/profiler/op_profiler.hpp"
 #include "ttnn/types.hpp"
 
@@ -733,6 +734,44 @@ Tensor resnet_matmul(
 namespace operations {
 
 namespace primary {
+
+inline uint32_t get_estimated_size_of_cbs(uint32_t per_core_M, uint32_t per_core_N, uint32_t in0_block_w, uint32_t in0_single_tile_size, uint32_t in1_single_tile_size, uint32_t output_single_tile_size) {
+    // Circular buffer sizes (in tiles):
+    // src0 CB: per_core_M * in0_block_w * 2 (for double buffer)
+    // src1 CB: per_core_N * in0_block_w * 2 (for double buffer)
+    // out CB:  per_core_M * per_core_N
+    // Ignore the optional intermediate CB since it is not needed when creating a program config.
+    uint32_t in0_size = per_core_M * in0_block_w * 2 * in0_single_tile_size;
+    uint32_t in1_size = per_core_N * in0_block_w * 2 * in1_single_tile_size;
+    uint32_t out_size = per_core_M * per_core_N * output_single_tile_size;
+    return in0_size + in1_size + out_size;
+}
+
+
+inline uint32_t get_per_core_factor(
+    const Tensor& input_tensor_a,
+    const Tensor& input_tensor_b,
+    uint32_t in0_block_w) {
+    tt::tt_metal::Device* device = input_tensor_a.device();
+    const std::vector<uint32_t>& bank_ids =
+        device->bank_ids_from_logical_core(BufferType::L1, *device->compute_cores_.begin());
+    std::optional<uint64_t> lowest_address = allocator::lowest_occupied_l1_address(*device->allocator_, bank_ids[0]);
+    uint32_t max_l1_space = lowest_address.has_value() ? lowest_address.value() : device->l1_size_per_core();
+    max_l1_space = max_l1_space - L1_UNRESERVED_BASE;
+    tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(input_tensor_a.get_dtype());
+    tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(input_tensor_b.get_dtype());
+    uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format);  // use as estimate for output as well
+    uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
+    for (uint32_t per_core_factor = 16; per_core_factor > 1; per_core_factor /= 2) {
+        uint32_t size = get_estimated_size_of_cbs(
+            per_core_factor, per_core_factor, in0_block_w, in0_single_tile_size, in1_single_tile_size, in0_single_tile_size);
+        if (size < max_l1_space) {
+            return per_core_factor;
+        }
+    }
+    return 1;
+}
+
 inline MatmulProgramConfig create_simple_matmul_program_config(
     const Tensor& input_tensor_a,
     const Tensor& input_tensor_b,
@@ -757,8 +796,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
     uint32_t num_blocks_x, num_blocks_y;
 
     // out_subblock h/w doesn't matter
-    per_core_M = 16;
-    per_core_N = 16;
+    per_core_M = get_per_core_factor(input_tensor_a, input_tensor_b, in0_block_w);
+    per_core_N = per_core_M;
 
     // Calculate number of blocks along x and y; tensor dims are padded up to 512
     num_blocks_y = (Mt - 1) / per_core_M + 1;
@@ -787,14 +826,13 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
             std::nullopt /* compute_with_storage_grid_size */,
             compute_kernel_config);
     } else if (core_range.y > 0) {
-        uint32_t in0_block_w = Kt % 2 == 0 ? 2 : 1;
         return MatmulMultiCoreReuseMultiCastProgramConfig{
             .compute_with_storage_grid_size = {num_cores_x, num_cores_y},
             .in0_block_w = in0_block_w,
             .out_subblock_h = 4,
             .out_subblock_w = 2,
-            .per_core_M = 16,
-            .per_core_N = 16,
+            .per_core_M = per_core_M,
+            .per_core_N = per_core_N,
             .transpose_mcast = false,
             .fused_activation = std::nullopt,
             .fuse_batch = false,
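
Note (not part of the patch): a minimal standalone sketch of the per-core-factor search introduced above, using hypothetical tile sizes and L1 budgets. The patch itself derives max_l1_space from the lowest occupied L1 address (or the full per-core L1 size when nothing is allocated) minus L1_UNRESERVED_BASE; here the budget is passed in directly.

#include <cstdint>
#include <iostream>

// Mirrors get_estimated_size_of_cbs(): double-buffered src0/src1 CBs plus a
// single-buffered output CB; the optional intermediate CB is ignored.
uint32_t estimated_cb_size(uint32_t per_core_M, uint32_t per_core_N, uint32_t in0_block_w,
                           uint32_t in0_tile, uint32_t in1_tile, uint32_t out_tile) {
    uint32_t in0_size = per_core_M * in0_block_w * 2 * in0_tile;
    uint32_t in1_size = per_core_N * in0_block_w * 2 * in1_tile;
    uint32_t out_size = per_core_M * per_core_N * out_tile;
    return in0_size + in1_size + out_size;
}

// Mirrors the loop in get_per_core_factor(): halve the square per-core factor
// starting at 16 until the estimated CB footprint fits in the L1 budget.
uint32_t per_core_factor(uint32_t max_l1_space, uint32_t in0_block_w,
                         uint32_t in0_tile, uint32_t in1_tile) {
    for (uint32_t factor = 16; factor > 1; factor /= 2) {
        // As in the patch, the in0 tile size doubles as the output-tile estimate.
        if (estimated_cb_size(factor, factor, in0_block_w, in0_tile, in1_tile, in0_tile) < max_l1_space) {
            return factor;
        }
    }
    return 1;
}

int main() {
    // Hypothetical values: 2048-byte tiles (e.g. 32x32 bfloat16) and in0_block_w = 2.
    std::cout << per_core_factor(1024 * 1024, 2, 2048, 2048) << "\n";  // 768 KiB needed at 16 -> prints 16
    std::cout << per_core_factor(128 * 1024, 2, 2048, 2048) << "\n";   // first fits at 4 (96 KiB) -> prints 4
    return 0;
}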