diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py index dfb7d5e04b90..56d9790840fe 100644 --- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py @@ -1623,7 +1623,6 @@ def test_conv2d_localrun(device, input_spec): [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127 [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220 [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294 - [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421 [1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443 [1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447 [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458 diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index acd3453ecf53..5ecdfdf7c920 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -184,7 +184,7 @@ ParallelConfig determine_parallel_config( return pconfig; } -static ParallelConfig determine_output_parallel_config( +ParallelConfig determine_output_parallel_config( const ParallelConfig& input_parallel_config, const CoreCoord& compute_grid_size, uint32_t out_channels, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index 69ce604a6713..10ad4cd5c8f4 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -188,6 +188,12 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( // Converts convolution weights to grouped layout with padded zeros Tensor convert_conv_weight_tensor_to_grouped_layout(const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype); +sliding_window::ParallelConfig determine_output_parallel_config( + const sliding_window::ParallelConfig& input_parallel_config, + const CoreCoord& compute_grid_size, + uint32_t out_channels, + bool is_mm_conv); + std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config); } // namespace operations::conv diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index dde9bb6d60f1..f9d06bd24366 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -368,6 +368,9 @@ ttnn::Tensor prepare_conv_weights( shard_orientation, !use_non_tile_height); + ParallelConfig output_parallel_config = + determine_output_parallel_config(parallel_config, device->compute_with_storage_grid_size(), out_channels, mm_conv); + bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels); std::optional bias_tensor = std::nullopt; ttnn::Tensor weight_tensor_on_device = weight_tensor; @@ -380,7 +383,7 @@ ttnn::Tensor prepare_conv_weights( opt_conv_op_block_config.act_block_w_ntiles, opt_conv_op_block_config.out_subblock_w_ntiles, parallel_config, - parallel_config, + output_parallel_config, device, groups, opt_conv_op_block_config.act_block_h_ntiles,