From 8db5c453dbe0e46e318379dd70a5eb4cd0dc0a77 Mon Sep 17 00:00:00 2001
From: Sankar Manoj
Date: Fri, 20 Dec 2024 06:43:57 +0000
Subject: [PATCH] #0: Fix failing test case for width sharded non-32 multiple output width

---
 .../ttnn/operations/conv/conv2d/conv2d.cpp    |  1 +
 ...onv2d_op_width_sharded_program_factory.cpp |  2 +-
 .../conv/conv2d/prepare_conv2d_weights.cpp    | 26 ++++++++++++-------
 .../conv/conv2d/prepare_conv2d_weights.hpp    |  3 ++-
 .../conv_transpose2d/conv_transpose2d.cpp     |  1 +
 5 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index 690cfcbde9a3..cdae35cc10f4 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -155,6 +155,7 @@ Result conv2d(
             opt_conv_op_block_config.act_block_w_ntiles,
             opt_conv_op_block_config.out_subblock_w_ntiles,
             parallel_config,
+            output_parallel_config,
             device,
             groups,
             opt_conv_op_block_config.act_block_h_ntiles,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
index 62696da35ed0..382912c59e4b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
@@ -368,7 +368,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
     uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w;
     // writer of conv op partially removes padding on the width
     // it removes the padding done for block width but it doesn't remove padding done for tiled width
-    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, input_num_cores * TILE_WIDTH);
+    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, output_num_cores * TILE_WIDTH);
     TT_FATAL(
         output_channels_padded_to_tile_width <= weight_matrix_width,
         "output_channels_padded_to_tile_width {} should be less than or equal to weight_matrix_width {}",
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index 0ba0363a9e6e..dde9bb6d60f1 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -188,7 +188,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     T * device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
@@ -231,9 +232,11 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t window_h = weights_shape[2];
     uint32_t window_w = weights_shape[3];
 
-    uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config);
-    uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH);
-    uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment);
+    uint32_t input_num_cores_channels = get_num_cores_channels_from_parallel_config(input_parallel_config);
+    uint32_t output_num_cores_channels = get_num_cores_channels_from_parallel_config(output_parallel_config);
+
+    uint32_t out_channels_padded = tt::round_up(out_channels, output_num_cores_channels * tt::constants::TILE_WIDTH);
+    uint32_t in_channels_padded = tt::round_up(in_channels, input_num_cores_channels * input_channels_alignment);
     uint32_t out_channel_padding = out_channels_padded - out_channels;
 
     tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
@@ -258,12 +261,12 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);
 
     // for conv op, pad the weights to block shape
-    if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
+    if (input_parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
-    } else if(parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
+    } else if(input_parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded(
-            weight_tensor_, num_cores_channels, weights_bias_dtype);
+            weight_tensor_, input_num_cores_channels, weights_bias_dtype);
     } else {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
@@ -289,7 +292,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
         bias_tensor_ = bias_tensor.value();
         bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_);
         if(!is_bias_tensor_is_on_device) {
-            bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, parallel_config, device, out_channels, is_non_tile_mul_width);
+            bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, input_parallel_config, device, out_channels, is_non_tile_mul_width);
             bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
         }
     }
@@ -377,6 +380,7 @@ ttnn::Tensor prepare_conv_weights(
         opt_conv_op_block_config.act_block_w_ntiles,
         opt_conv_op_block_config.out_subblock_w_ntiles,
         parallel_config,
+        parallel_config,
         device,
         groups,
         opt_conv_op_block_config.act_block_h_ntiles,
@@ -550,7 +554,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     Device* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
@@ -565,7 +570,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     MeshDevice* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
index 221a9d230f52..d6ee0bde769f 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
@@ -77,7 +77,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const sliding_window::ParallelConfig& parallel_config,
+    const sliding_window::ParallelConfig& input_parallel_config,
+    const sliding_window::ParallelConfig& output_parallel_config,
     T * device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
index 32e30e6bf5aa..d28d864a0b1b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
@@ -279,6 +279,7 @@ Result conv_transpose2d(
             opt_conv_op_block_config.act_block_w_ntiles,
             opt_conv_op_block_config.out_subblock_w_ntiles,
             parallel_config,
+            output_parallel_config,
             device,
             groups,
             opt_conv_op_block_config.act_block_h_ntiles,
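
Note (not part of the patch): as I read the diff, the failing case is a width-sharded conv whose output-channel count is not a multiple of 32 and whose input and output shard grids use different numbers of cores. The weight matrix width is padded per output core in the weight-preparation path, so the program factory's rounding has to use the output core count as well; rounding with the input core count can produce a padded width larger than weight_matrix_width and trip the TT_FATAL shown above. The following standalone sketch only illustrates that arithmetic: the round_up helper is a stand-in for tt::round_up, and the channel/core numbers are made-up example values, not taken from the actual failing test.

#include <cstdint>
#include <iostream>

// Stand-in for tt::round_up: smallest multiple of `multiple` that is >= value.
static uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
    constexpr uint32_t TILE_WIDTH = 32;

    // Hypothetical example: output channels not a multiple of 32, and more
    // cores in the input shard grid than in the output shard grid.
    uint32_t output_channels = 40;
    uint32_t input_num_cores = 8;
    uint32_t output_num_cores = 2;

    // Weight matrix width padded per output core, mirroring the fixed
    // padding in prepare_conv2d_weights.cpp.
    uint32_t weight_matrix_width = round_up(output_channels, output_num_cores * TILE_WIDTH);  // 64

    // Old rounding (input core count) overshoots the weight matrix width, so
    // the TT_FATAL check in the width-sharded program factory fires.
    uint32_t padded_old = round_up(output_channels, input_num_cores * TILE_WIDTH);   // 256
    // New rounding (output core count) stays within it.
    uint32_t padded_new = round_up(output_channels, output_num_cores * TILE_WIDTH);  // 64

    std::cout << "old: " << padded_old << " > " << weight_matrix_width
              << ", new: " << padded_new << " <= " << weight_matrix_width << "\n";
    return 0;
}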