diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
index dfb7d5e04b9..56d9790840f 100644
--- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
+++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -1623,7 +1623,6 @@ def test_conv2d_localrun(device, input_spec):
         [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1],  # 127
         [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1],  # 220
         [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1],  # 294
-        [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],  # 1421
         [1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1],  # 1443
         [1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1],  # 1447
         [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1],  # 1458
diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
index 4d790730c16..0beea0f771e 100644
--- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py
+++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -530,7 +530,6 @@ def test_conv_features_multi_device(
 @pytest.mark.parametrize(
     "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, pad_h, pad_w, act_block_w_div",
     (
-        (2, 128, 128, 9, 9, 3, 3, 0, 0, 1),
         (2, 128, 256, 9, 9, 3, 3, 1, 1, 1),
         (2, 576, 576, 9, 9, 3, 3, 0, 0, 1),
         (2, 960, 960, 5, 5, 3, 3, 0, 0, 1),
@@ -538,12 +537,11 @@ def test_conv_features_multi_device(
         (2, 512, 2048, 17, 17, 3, 3, 1, 1, 1),
         (2, 768, 768, 17, 17, 3, 3, 0, 0, 1),
         (2, 1280, 2560, 15, 15, 3, 3, 1, 1, 2),
-        (2, 1280, 2560, 15, 15, 3, 3, 0, 0, 2),
         (2, 1280, 1280, 17, 17, 3, 3, 1, 1, 1),
+        [1, 3024, 1232, 14, 14, 1, 1, 0, 0, 1],
         (2, 768, 32, 9, 9, 3, 3, 1, 1, 1),
         (2, 64, 128, 9, 9, 3, 3, 1, 1, 1),
         (2, 32, 128, 9, 9, 3, 3, 1, 1, 1),
-        (1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
     ),
 )
 @pytest.mark.parametrize(
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index 7f0e355b594..6e5839ac82a 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -54,13 +54,13 @@ Result conv2d(
     const std::optional<const Conv2dConfig>& conv_config_,
     const std::optional<const DeviceComputeKernelConfig>& compute_config_,
     const std::optional<const MemoryConfig>& memory_config) {
-    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
+    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
+    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups, conv_config);
     const uint32_t output_height =
         ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
     const uint32_t output_width =
         ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
 
-    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
     const auto compute_grid_size = device->compute_with_storage_grid_size();
 
     bool auto_shard = false;
@@ -158,6 +158,7 @@ Result conv2d(
             opt_conv_op_block_config.act_block_w_ntiles,
             opt_conv_op_block_config.out_subblock_w_ntiles,
             parallel_config,
+            output_parallel_config,
             device,
             groups,
             opt_conv_op_block_config.act_block_h_ntiles,
@@ -194,6 +195,9 @@ Result conv2d(
     if (bypass_halo) {
         if (input_tensor_post_tm.layout() == Layout::TILE) {
+            // Reshape is used as a workaround to an issue in to_layout mentioned here :
+            // https://github.com/tenstorrent/tt-metal/issues/16330
+            input_tensor_post_tm = ttnn::reshape(input_tensor_post_tm, input_tensor_post_tm.get_padded_shape());
             input_tensor_post_tm =
                 ttnn::to_layout(input_tensor_post_tm, Layout::ROW_MAJOR, std::nullopt, std::nullopt, device);
         }
 
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
index ac94433d535..5913f8f8cdd 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -160,7 +160,7 @@ ParallelConfig determine_parallel_config(
     return pconfig;
 }
 
-static ParallelConfig determine_output_parallel_config(
+ParallelConfig determine_output_parallel_config(
     const ParallelConfig& input_parallel_config,
     const CoreCoord& compute_grid_size,
     uint32_t out_channels,
@@ -371,9 +371,12 @@ bool use_matmul_for_1x1_conv(
     const std::array<uint32_t, 2>& stride,
     const std::array<uint32_t, 2>& padding,
     const std::array<uint32_t, 2>& dilation,
-    uint32_t groups) {
+    uint32_t groups,
+    const Conv2dConfig& conv_config) {
+    bool is_width_sharded =
+        (conv_config.shard_layout.has_value() && conv_config.shard_layout.value() == TensorMemoryLayout::WIDTH_SHARDED);
     return kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 && padding[0] == 0 &&
-           padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1;
+           padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1 && (not is_width_sharded);
 }
 
 // Implements a heuristic for selecting shard layout based on how many tenix cores are available
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
index 9a5758872c2..480bcb29a91 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -95,7 +95,8 @@ bool use_matmul_for_1x1_conv(
     const std::array<uint32_t, 2>& stride,
     const std::array<uint32_t, 2>& padding,
     const std::array<uint32_t, 2>& dilation,
-    uint32_t groups);
+    uint32_t groups,
+    const Conv2dConfig& conv_config);
 
 sliding_window::ParallelConfig determine_parallel_config(
     const TensorMemoryLayout shard_layout,
@@ -109,6 +110,12 @@ sliding_window::ParallelConfig determine_parallel_config(
     bool enable_channels_padding,
     bool is_out_tiled = true);
 
+sliding_window::ParallelConfig determine_output_parallel_config(
+    const sliding_window::ParallelConfig& input_parallel_config,
+    const CoreCoord& compute_grid_size,
+    uint32_t out_channels,
+    bool is_mm_conv);
+
 uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig);
 
 uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig);
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
index bf8e12c6aa8..25070d665a3 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
@@ -368,7 +368,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
     uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w;
     // writer of conv op partially removes padding on the width
     // it removes the padding done for block width but it doesn't remove padding done for tiled width
-    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, input_num_cores * TILE_WIDTH);
+    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, output_num_cores * TILE_WIDTH);
     TT_FATAL(
         output_channels_padded_to_tile_width <= weight_matrix_width,
         "output_channels_padded_to_tile_width {} should be less than or equal to weight_matrix_width {}",
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index a3f39ce5c77..b659f7ea475 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -538,9 +538,7 @@ ttnn::Tensor conv_bias_layout_convert(
     validate_bias_tensor(bias_tensor_);
     if (!is_non_tile_mul_width) {
         auto bias_shape = bias_tensor_.get_shape();
-        TT_FATAL(
-            bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1,
-            "bias shape is not correct");
+        TT_FATAL(bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1, "bias shape is not correct");
         tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape(
             std::array<uint32_t, 4>({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)}));
         bias_tensor_ =
@@ -664,7 +662,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     T* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
@@ -705,9 +704,11 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t window_h = weights_shape[2];
     uint32_t window_w = weights_shape[3];
 
-    uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config);
-    uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH);
-    uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment);
+    uint32_t input_num_cores_channels = get_num_cores_channels_from_parallel_config(input_parallel_config);
+    uint32_t output_num_cores_channels = get_num_cores_channels_from_parallel_config(output_parallel_config);
+
+    uint32_t out_channels_padded = tt::round_up(out_channels, output_num_cores_channels * tt::constants::TILE_WIDTH);
+    uint32_t in_channels_padded = tt::round_up(in_channels, input_num_cores_channels * input_channels_alignment);
     uint32_t out_channel_padding = out_channels_padded - out_channels;
 
     tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(
@@ -733,12 +734,12 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
         ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);
 
     // for conv op, pad the weights to block shape
-    if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
+    if (input_parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
         weight_tensor_ = convert_conv_weight_tensor_to_special_padding_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
-    } else if (parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
+    } else if (input_parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
         weight_tensor_ = convert_conv_weight_tensor_to_tiled_layout_block_sharded(
-            weight_tensor_, num_cores_channels, weights_bias_dtype);
+            weight_tensor_, input_num_cores_channels, weights_bias_dtype);
     } else {
         weight_tensor_ = convert_conv_weight_tensor_to_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
@@ -765,14 +766,15 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
         bias_tensor_ = bias_tensor.value();
         bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_);
         if (!is_bias_tensor_is_on_device) {
+            TT_FATAL(bias_tensor_.shape()[3] == out_channels, "Bias must have the same length as output channels");
             bias_tensor_ = conv_bias_layout_convert(
                 bias_tensor_,
                 weights_bias_dtype,
                 weight_block_h_ntiles,
                 weight_block_w_ntiles,
-                parallel_config,
+                output_parallel_config,
                 device,
-                out_channels,
+                out_channels_padded,
                 is_non_tile_mul_width);
             bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
         }
@@ -806,7 +808,7 @@ ttnn::Tensor prepare_conv_weights(
     Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
     DeviceComputeKernelConfig compute_config = compute_config_.value_or(
         init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false));
-    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
+    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups, conv_config);
     const uint32_t output_height =
         ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
     const uint32_t output_width =
@@ -847,6 +849,9 @@ ttnn::Tensor prepare_conv_weights(
         shard_orientation,
         !use_non_tile_height);
 
+    ParallelConfig output_parallel_config = determine_output_parallel_config(
+        parallel_config, device->compute_with_storage_grid_size(), out_channels, mm_conv);
+
     bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels);
     std::optional<const ttnn::Tensor> bias_tensor = std::nullopt;
     ttnn::Tensor weight_tensor_on_device = weight_tensor;
@@ -859,6 +864,7 @@ ttnn::Tensor prepare_conv_weights(
         opt_conv_op_block_config.act_block_w_ntiles,
         opt_conv_op_block_config.out_subblock_w_ntiles,
         parallel_config,
+        output_parallel_config,
        device,
         groups,
         opt_conv_op_block_config.act_block_h_ntiles,
@@ -890,15 +896,17 @@ ttnn::Tensor prepare_conv_bias(
     TT_FATAL(
         !ttnn::is_tensor_on_device_or_multidevice(bias_tensor), "Error: bias tensor must be on host for preparation.");
 
-    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
+    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
+
+    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups, conv_config);
     const uint32_t output_height =
         ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
     const uint32_t output_width =
         ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
 
-    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
     DeviceComputeKernelConfig compute_config = compute_config_.value_or(
         init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false));
+
     auto opt_conv_op_block_config = get_opt_block_config(
         mm_conv,
         in_channels,
@@ -936,14 +944,19 @@ ttnn::Tensor prepare_conv_bias(
         shard_orientation,
         !use_non_tile_height);
 
+    ParallelConfig output_parallel_config = determine_output_parallel_config(
+        parallel_config, device->compute_with_storage_grid_size(), out_channels, mm_conv);
+
     bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels);
     ttnn::Tensor bias_tensor_ = bias_tensor;
+    TT_FATAL(bias_tensor_.shape()[3] == out_channels, "Bias must have the same length as output channels");
+
     bias_tensor_ = conv_bias_layout_convert(
         bias_tensor_,
         conv_config.weights_dtype,
         opt_conv_op_block_config.act_block_h_ntiles,
         weight_block_w_ntiles,
-        parallel_config,
+        output_parallel_config,
         device,
         out_channels,
         is_non_tile_mul_width);
@@ -1027,7 +1040,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
    uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     Device* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
@@ -1043,7 +1057,8 @@ prepare_conv_weights_biases_and_move_to_device(
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     MeshDevice* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
index 2c4b7f8eab1..e42655bb0e1 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
@@ -111,7 +111,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const sliding_window::ParallelConfig& parallel_config,
+    const sliding_window::ParallelConfig& input_parallel_config,
+    const sliding_window::ParallelConfig& output_parallel_config,
     T* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
index a2700b26e55..c5ed25af6e4 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
@@ -168,7 +168,12 @@ Result conv_transpose2d(
     log_debug(LogOp, "Padding : ({},{}) ({},{})", input_pad_top, input_pad_bottom, input_pad_left, input_pad_right);
 
     const bool mm_conv = use_matmul_for_1x1_conv(
-        kernel_size, {1, 1}, {input_pad_top + input_pad_bottom, input_pad_left + input_pad_right}, dilation, groups);
+        kernel_size,
+        {1, 1},
+        {input_pad_top + input_pad_bottom, input_pad_left + input_pad_right},
+        dilation,
+        groups,
+        conv_config);
 
     const auto compute_grid_size = device->compute_with_storage_grid_size();
 
@@ -268,7 +273,6 @@ Result conv_transpose2d(
         get_fp32_dest_acc_en(compute_config),
         conv_config.enable_split_reader);
 
-    // TODO: Flip the Weights
     bool weight_is_on_device = ttnn::is_tensor_on_device_or_multidevice(weight_tensor);
     ttnn::Tensor weight_tensor_on_device = weight_tensor;
     std::optional<const ttnn::Tensor> bias_tensor_on_device = bias_tensor;
@@ -282,6 +286,7 @@ Result conv_transpose2d(
             opt_conv_op_block_config.act_block_w_ntiles,
             opt_conv_op_block_config.out_subblock_w_ntiles,
             parallel_config,
+            output_parallel_config,
             device,
             groups,
             opt_conv_op_block_config.act_block_h_ntiles,
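The behavioral core of this patch is the extra conv_config parameter to use_matmul_for_1x1_conv: a 1x1, stride-1, unpadded, ungrouped convolution is no longer lowered to a matmul when the config explicitly requests width sharding, so cases like the 1x3024x1232x14x14 shape above go through the regular width-sharded conv path. Below is a standalone C++ sketch of that predicate, for illustration only. The Conv2dConfig and TensorMemoryLayout definitions here are simplified stand-ins (only the shard_layout field is modeled), not the ttnn types; the predicate body itself mirrors the changed code in conv2d_utils.cpp.

// Standalone sketch, not the tt-metal sources. Simplified stand-ins for the
// real ttnn types, just enough to exercise the updated predicate.
#include <array>
#include <cstdint>
#include <iostream>
#include <optional>

enum class TensorMemoryLayout { HEIGHT_SHARDED, WIDTH_SHARDED, BLOCK_SHARDED };

struct Conv2dConfig {
    std::optional<TensorMemoryLayout> shard_layout;  // unset means auto-sharding picks the layout
};

// Mirrors the new signature: a 1x1, stride-1, unpadded, undilated, ungrouped conv
// is treated as a matmul only when the config is NOT explicitly width-sharded.
bool use_matmul_for_1x1_conv(
    const std::array<uint32_t, 2>& kernel_size,
    const std::array<uint32_t, 2>& stride,
    const std::array<uint32_t, 2>& padding,
    const std::array<uint32_t, 2>& dilation,
    uint32_t groups,
    const Conv2dConfig& conv_config) {
    bool is_width_sharded =
        (conv_config.shard_layout.has_value() && conv_config.shard_layout.value() == TensorMemoryLayout::WIDTH_SHARDED);
    return kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 && padding[0] == 0 &&
           padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1 && (not is_width_sharded);
}

int main() {
    Conv2dConfig auto_sharded{};                                    // no explicit shard layout
    Conv2dConfig width_sharded{TensorMemoryLayout::WIDTH_SHARDED};  // e.g. the 1x3024x1232x14x14 case

    // Identical 1x1 / stride-1 / no-pad geometry; only the shard layout differs.
    std::cout << use_matmul_for_1x1_conv({1, 1}, {1, 1}, {0, 0}, {1, 1}, 1, auto_sharded) << "\n";   // prints 1
    std::cout << use_matmul_for_1x1_conv({1, 1}, {1, 1}, {0, 0}, {1, 1}, 1, width_sharded) << "\n";  // prints 0
    return 0;
}

Compiled and run, the sketch prints 1 for the auto-sharded config and 0 for the width-sharded one, matching the new guard: only the width-sharded case keeps the dedicated conv program, whose output-channel padding is now rounded to output_num_cores * TILE_WIDTH rather than input_num_cores * TILE_WIDTH.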