diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index dbdd7268c968..4e2d8a1f8f6a 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -188,7 +188,9 @@ MemoryConfig create_sharded_memory_config_from_parallel_config(
     uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2];
     uint32_t nhw_padded = nhw_shape;
-    nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
+    if (shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) {
+        nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
+    }
     uint32_t nhw_shard = nhw_padded / num_cores_nhw;
     TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels);
     uint32_t channel_shard = channels / num_cores_channels;
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
index 4c629dd879f1..5b6afc2d719b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
@@ -575,7 +575,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
     if (false) {
         compute_defines["PACKER_L1_ACC"] = "1";
     }
-    uint32_t output_rows_h = output.shard_spec().value().shape[0];
+    uint32_t num_output_tiles = per_core_out_matrix_height_ntiles * per_core_out_matrix_width_ntiles;
     uint32_t use_non_tile_height = false;
     compute_kernel_args = {
         act_block_w_ntiles,  // in0_block_w
@@ -603,7 +603,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
         bias_ntiles_per_core,

-        output_rows_h,
+        num_output_tiles,
         use_non_tile_height,

         total_num_cores,  // in0_nblocks_w_tilize. Repeat tilize after all cores have done one round of MCAST.
@@ -680,7 +680,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
     uint32_t out_tile_size = tt_metal::detail::TileSize(out_df);
     uint32_t interm0_single_tile_size = tt_metal::detail::TileSize(interm0_df);

-    uint32_t num_output_tiles = per_core_out_matrix_height_ntiles * per_core_out_matrix_width_ntiles;

     // Share buffer if same data format
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp
index 5f22f248601e..2a07817f5d8a 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp
@@ -55,7 +55,7 @@ inline void reblock_and_untilize(
     uint32_t interm_cb_id,
     uint32_t out_cb_id) {
     constexpr bool is_non_tile_height_ = is_non_tile_height;
-    uint32_t TILE_SIZE = is_non_tile_height_ ? 32 : out_subblock_w;
+    uint32_t TILE_SIZE = is_non_tile_height_ ? 32 : out_block_w;
     uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col);
     cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks);
     uint32_t within_block_index = 0;
@@ -411,7 +411,7 @@ void MAIN {
     pack_untilize_dst_init_short(out_cb_id);
     copy_tile_to_dst_init_short();
     uint32_t curr_tile_output_rows_h = 0;
-    uint32_t TILE_SIZE = is_non_tile_height ? 32 : out_subblock_w;
+    uint32_t TILE_SIZE = is_non_tile_height ? 32 : out_block_w;
     TILE_SIZE = TILE_SIZE * out_subblock_h;
     for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) {
         curr_tile_output_rows_h = output_rows_h < TILE_SIZE ? output_rows_h : TILE_SIZE;
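
For context on the first hunk: after this change, the flattened N*H*W extent is only rounded up to a multiple of num_cores_nhw * tile_size when the shard scheme is not WIDTH_SHARDED, so width-sharded convolutions keep their exact NHW height. Below is a minimal standalone sketch of that padding rule; the TensorMemoryLayout enum, the round_up semantics, and the main() harness are illustrative assumptions, not the real TTNN definitions.

    // Standalone sketch of the padding rule changed in
    // create_sharded_memory_config_from_parallel_config (not the real TTNN code).
    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-in for TTNN's layout enum.
    enum class TensorMemoryLayout { HEIGHT_SHARDED, WIDTH_SHARDED, BLOCK_SHARDED };

    // Assumed behavior of round_up: round n up to the nearest multiple.
    static uint32_t round_up(uint32_t n, uint32_t multiple) {
        return ((n + multiple - 1) / multiple) * multiple;
    }

    // After the diff: NHW is padded to a multiple of num_cores_nhw * tile_size
    // only when the tensor is NOT width sharded.
    static uint32_t padded_nhw(
        TensorMemoryLayout shard_scheme, uint32_t nhw_shape, uint32_t num_cores_nhw, uint32_t tile_size) {
        uint32_t nhw_padded = nhw_shape;
        if (shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) {
            nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
        }
        return nhw_padded;
    }

    int main() {
        constexpr uint32_t tile_size = 32;
        // Hypothetical example: N*H*W = 1000, 8 cores along NHW.
        // Height sharded: padded to a multiple of 8 * 32 = 256 -> 1024.
        std::printf("height sharded: %u\n",
            (unsigned)padded_nhw(TensorMemoryLayout::HEIGHT_SHARDED, 1000, 8, tile_size));
        // Width sharded: left unpadded -> 1000.
        std::printf("width sharded:  %u\n",
            (unsigned)padded_nhw(TensorMemoryLayout::WIDTH_SHARDED, 1000, 1, tile_size));
        return 0;
    }

Under these assumed values the sketch prints 1024 for height sharding and 1000 for width sharding, matching the intent of the hunk: width-sharded shard shapes are derived from the unpadded NHW extent.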