diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index f289840c0a2c..39baa8da7904 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -267,8 +267,6 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config(
         }
     }
     auto grid_size = parallel_config.grid.bounding_box().grid_size();
-    act_block_h_ntiles = act_block_h_override > 0 ? act_block_h_override / tt::constants::TILE_HEIGHT
-                                                  : conv_op_parallel_config.per_core_out_matrix_height / tt::constants::TILE_HEIGHT;
     uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1
                                 : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? grid_size.y
                                                                                                    : grid_size.x;
@@ -276,9 +274,11 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config(
     uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED
                                ? round_up(padded_in_channels * window_w, 32)
                                : round_up((padded_in_channels / act_c_num_blocks) * window_h * window_w, tt::constants::TILE_WIDTH);
+    if(parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) {
+        act_block_w = (padded_in_channels * window_h * window_w)/(parallel_config.grid.num_cores() * act_block_w_div);
+    }
     TT_ASSERT(act_block_w % 32 == 0);
     uint32_t act_block_w_ntiles = act_block_w / 32;
-    TT_ASSERT(conv_op_parallel_config.per_core_out_matrix_height % tt::constants::TILE_HEIGHT == 0);
     //TT_FATAL(conv_op_parallel_config.per_core_out_matrix_width % TILE_WIDTH == 0);
     uint32_t out_block_h_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT);
     uint32_t weight_block_w_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH);
@@ -755,7 +755,7 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
     TT_ASSERT(conv_blocking_config.act_block_w_ntiles % grid_size_along_c == 0);
     ttnn::operations::matmul::MatmulMultiCoreReuseMultiCastProgramConfig matmul_config = {
         .compute_with_storage_grid_size = conv_parallelization_config.grid_size,
-        .in0_block_w = conv_blocking_config.act_block_w_ntiles,
+        .in0_block_w = conv_blocking_config.act_block_w_ntiles / grid_size_along_c,
         .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles,
         .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles,
         .out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT),
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp
index c3a553da5fba..b85111845228 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp
@@ -103,9 +103,9 @@ void MAIN {
     constexpr uint32_t out_subblock_num_tiles = get_compile_time_arg_val(13); // out_subblock_h * out_subblock_w;
     constexpr bool tilize_in0 = get_compile_time_arg_val(14);
     constexpr bool untilize_out = get_compile_time_arg_val(15);
-    constexpr uint32_t out_cb_id = get_compile_time_arg_val(19);
-    uint32_t output_rows_h = get_compile_time_arg_val(17);
-    constexpr bool is_non_tile_height = get_compile_time_arg_val(18);
+    constexpr uint32_t out_cb_id = get_compile_time_arg_val(17);
+    uint32_t output_rows_h = get_compile_time_arg_val(18);
+    constexpr bool is_non_tile_height = get_compile_time_arg_val(19);
 
 #ifdef WIDTH_SHARDED
     constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(20);
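
Note: the new WIDTH_SHARDED branch above splits the full matmul inner dimension (padded_in_channels * window_h * window_w) across all cores and then by act_block_w_div, and the matmul config correspondingly divides act_block_w_ntiles by grid_size_along_c. The following is a minimal standalone sketch of that arithmetic; all concrete values (channel count, window size, core count, act_block_w_div) are illustrative assumptions, not taken from the PR.

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
    // Assumed example values, not from the PR itself.
    const uint32_t padded_in_channels = 512;   // padded conv input channels
    const uint32_t window_h = 3, window_w = 3; // conv kernel window
    const uint32_t num_cores = 8;              // stands in for parallel_config.grid.num_cores()
    const uint32_t act_block_w_div = 2;        // user-supplied divisor from the conv config

    // Width-sharded path: divide the full inner dim (C * Kh * Kw) across
    // cores, then subdivide by act_block_w_div, mirroring the new branch.
    const uint32_t act_block_w =
        (padded_in_channels * window_h * window_w) / (num_cores * act_block_w_div);
    assert(act_block_w % 32 == 0);  // same tile-width alignment the TT_ASSERT enforces

    // 512 * 3 * 3 = 4608; 4608 / (8 * 2) = 288 = 9 tiles of width 32.
    std::cout << "act_block_w = " << act_block_w
              << " (" << act_block_w / 32 << " tiles)\n";
    return 0;
}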