From cf69213196bc3cfba8d2063e9cdd66578bc39651 Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> Date: Thu, 12 Dec 2024 20:51:43 -0800 Subject: [PATCH] #0: Fix conv_transpose2d initting wrong compute_kernel_config variant (#15987) ### Ticket None ### Problem description Found that conv_transpose2d is initializing a compute_kernel_config for GS even on WH arch. ### What's changed Follow what conv2d does ### Checklist - [ ] [Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/12309784960) --- .../cpp/ttnn/operations/conv/conv2d/conv2d.cpp | 2 +- .../conv/conv2d/prepare_conv2d_weights.cpp | 18 ++++++++++++++++-- .../conv/conv_transpose2d/conv_transpose2d.cpp | 10 ++++++++-- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index d6d06ec490f..888bec60ee1 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -86,7 +86,7 @@ Result conv2d( (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; - DeviceComputeKernelConfig compute_config = compute_config_.value_or( init_device_compute_kernel_config( + DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( device->arch(), std::nullopt, MathFidelity::HiFi4, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 1009ed7a87b..0154a972d4a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -295,7 +295,14 @@ ttnn::Tensor prepare_conv_weights( const std::optional& compute_config_) { TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation."); Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - DeviceComputeKernelConfig compute_config = compute_config_.value_or(DeviceComputeKernelConfig()); + DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( + device->arch(), + std::nullopt, + MathFidelity::HiFi4, + true, + false, + false + )); const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; const uint32_t output_width = @@ -382,7 +389,14 @@ ttnn::Tensor prepare_conv_bias( ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - DeviceComputeKernelConfig compute_config = compute_config_.value_or(DeviceComputeKernelConfig()); + DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( + device->arch(), + std::nullopt, + MathFidelity::HiFi4, + true, + false, + false + )); auto opt_conv_op_block_config = get_opt_block_config( mm_conv, in_channels, diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp index 21af1f921fb..40c4db5045f 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp @@ -110,8 +110,14 @@ Result conv_transpose2d( const std::optional& compute_config_, const std::optional& memory_config ) { Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - DeviceComputeKernelConfig compute_config = compute_config_.value_or(DeviceComputeKernelConfig()); - + DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( + device->arch(), + std::nullopt, + MathFidelity::HiFi4, + true, + false, + false + )); //Inverse of sliding_window.get_output_shape() SlidingWindowConfig sliding_window_config = SlidingWindowConfig{