diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml
index 6f570b74363..c62f450e2af 100644
--- a/.github/workflows/ttnn-run-sweeps.yaml
+++ b/.github/workflows/ttnn-run-sweeps.yaml
@@ -329,6 +329,7 @@ on:
           - data_movement.repeat_interleave.repeat_interleave
           - data_movement.nonzero.nonzero
           - data_movement.backward.concat_bw.concat_bw
+          - conv_transpose2d.short.conv_transpose2d_short_sweep
           - conv2d.full.conv2d_misc
           - conv2d.full.conv2d_sharding
           - conv2d.full.conv2d_sliding_window
diff --git a/tests/sweep_framework/sweep_utils/conv_transpose2d_common.py b/tests/sweep_framework/sweep_utils/conv_transpose2d_common.py
new file mode 100644
index 00000000000..8e3c0e58af2
--- /dev/null
+++ b/tests/sweep_framework/sweep_utils/conv_transpose2d_common.py
@@ -0,0 +1,260 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple, List
+import itertools
+import random
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from models.utility_functions import torch_random
+
+# Override the default timeout in seconds for hang detection.
+TIMEOUT = 30
+
+
+def get_input_specs(
+    batch_list: List[int],
+    acts_list: List[int],
+    kernel_list: List[int],
+    stride_list: List[int],
+    padding_list: List[int],
+    dilation_list: List[int],
+) -> Tuple[int, int, int, int, int, int, int, int, int, int]:
+    for batch_size, activation, kernel, stride, padding, dilation in itertools.product(
+        batch_list, acts_list, kernel_list, stride_list, padding_list, dilation_list
+    ):
+        yield (batch_size, activation, activation, kernel, kernel, stride, stride, padding, padding, dilation)
+
+
+def mesh_device_fixture():
+    num_devices = ttnn.GetNumPCIeDevices()
+    # As of now take device id as 0.
+    device_id = 0
+    assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
+    device = ttnn.CreateDevice(device_id=device_id, l1_small_size=32768)
+    ttnn.SetDefaultDevice(device)
+
+    device_name = "Unknown"
+    if ttnn.device.is_grayskull(device):
+        device_name = "grayskull"
+    elif ttnn.device.is_wormhole_b0(device):
+        device_name = "wormhole_b0"
+    yield device, device_name
+
+    ttnn.close_device(device)
+
+
+def run_full(
+    input_specs,
+    input_channels,
+    output_channels,
+    transpose_mcast,
+    output_layout,
+    has_bias,
+    enable_act_double_buffer,
+    enable_split_reader,
+    enable_subblock_padding,
+    activations_dtype,
+    weights_dtype,
+    math_fidelity,
+    fp32_accum,
+    packer_l1_acc,
+    groups,
+    override_sharding_config,
+    core_grid,
+    use_shallow_conv_variant,
+    deallocate_activation,
+    enable_auto_formatting,
+    device,
+    padded_input_channels=None,
+) -> list:
+    [
+        batch_size,
+        input_height,
+        input_width,
+        kernel_height,
+        kernel_width,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        out_pad_h,
+        out_pad_w,
+    ] = input_specs
+    conv_input_shape = [batch_size, input_channels, input_height, input_width]
+    conv_weight_shape = [input_channels, output_channels // groups, kernel_height, kernel_width]
+    conv_bias_shape = [1, 1, 1, output_channels]
+    torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
+
+    torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
+    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()
+
+    torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
+    torch_out_golden_tensor = torch.nn.functional.conv_transpose2d(
+        torch_input_tensor_nchw,
+        torch_weight_tensor,
+        bias=torch_bias_tensor.reshape(-1) if has_bias else None,
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        output_padding=(out_pad_h, out_pad_w),
+        dilation=(dilation_h, dilation_w),
+        groups=groups,
+    )
+
+    tt_weight_tensor = ttnn.from_torch(
+        torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
+    )
+    tt_bias_tensor = None
+    if has_bias:
+        tt_bias_tensor = ttnn.from_torch(
+            torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
+        )
+
+    tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
+
+    conv_config = ttnn.Conv2dConfig(
+        dtype=activations_dtype,
+        weights_dtype=weights_dtype,
+        math_fidelity=math_fidelity,
+        shard_layout=None,
+        deallocate_activation=deallocate_activation,
+        fp32_dest_acc_enabled=fp32_accum,
+        packer_l1_accum_enabled=packer_l1_acc,
+        override_sharding_config=override_sharding_config,
+        output_layout=output_layout,
+        enable_act_double_buffer=enable_act_double_buffer,
+        enable_split_reader=enable_split_reader,
+        enable_subblock_padding=enable_subblock_padding,
+    )
+
+    if override_sharding_config:
+        if len(core_grid) == 2:
+            conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange(core_grid[0], core_grid[1])})
+        elif len(core_grid) == 4:
+            conv_config.core_grid = ttnn.CoreRangeSet(
+                {ttnn.CoreRange(core_grid[0], core_grid[1]), ttnn.CoreRange(core_grid[2], core_grid[3])}
+            )
+    start_time = start_measuring_time()
+    [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv_transpose2d(
+        input_tensor=tt_input_tensor,
+        weight_tensor=tt_weight_tensor,
+        in_channels=input_channels,
+        out_channels=output_channels,
+        device=device,
+        bias_tensor=tt_bias_tensor,
+        kernel_size=(kernel_height, kernel_width),
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        output_padding=(out_pad_h, out_pad_w),
+        dilation=(dilation_h, dilation_w),
+        batch_size=batch_size,
+        input_height=input_height,
+        input_width=input_width,
+        conv_config=conv_config,
+        groups=groups,
+    )
+
+    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
+    torch_output_tensor = ttnn.to_torch(tt_output_tensor)
+    e2e_perf = stop_measuring_time(start_time)
+
+    # torch_output_tensor is in row major layout and NHWC shape
+    # NHWC to NCHW
+    torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1])
+    torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
+
+    torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))
+
+    return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]
+
+
+def run_short(
+    input_specs,
+    device,
+) -> list:
+    [
+        batch_size,
+        input_channels,
+        input_height,
+        input_width,
+        output_channels,
+        kernel_height,
+        kernel_width,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        out_pad_h,
+        out_pad_w,
+    ] = input_specs
+    print(input_specs)
+    groups = 1
+    has_bias = True
+
+    conv_input_shape = [batch_size, input_channels, input_height, input_width]
+    conv_weight_shape = [input_channels, output_channels // groups, kernel_height, kernel_width]
+    conv_bias_shape = [1, 1, 1, output_channels]
+    torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
+
+    torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
+    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()
+
+    torch_bias_tensor = None
+    if has_bias:
+        torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float()
+    torch_out_golden_tensor = torch.nn.functional.conv_transpose2d(
+        torch_input_tensor_nchw,
+        torch_weight_tensor,
+        bias=torch_bias_tensor.reshape(-1) if has_bias else None,
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        dilation=(dilation_h, dilation_w),
+        output_padding=(out_pad_h, out_pad_w),
+        groups=groups,
+    )
+
+    tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
+    tt_bias_tensor = None
+    if has_bias:
+        tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)
+
+    tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
+
+    start_time = start_measuring_time()
+    [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv_transpose2d(
+        input_tensor=tt_input_tensor,
+        weight_tensor=tt_weight_tensor,
+        in_channels=input_channels,
+        out_channels=output_channels,
+        device=device,
+        bias_tensor=tt_bias_tensor,
+        kernel_size=(kernel_height, kernel_width),
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        output_padding=(out_pad_h, out_pad_w),
+        dilation=(dilation_h, dilation_w),
+        batch_size=batch_size,
+        input_height=input_height,
+        input_width=input_width,
+        groups=groups,
+    )
+
+    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
+    torch_output_tensor = ttnn.to_torch(tt_output_tensor)
+    e2e_perf = stop_measuring_time(start_time)
+
+    # torch_output_tensor is in row major layout and NHWC shape
+    # NHWC to NCHW
+    torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1])
+    torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
+
+    torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))
+
+    return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]
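A note on `get_input_specs` above: it takes the cartesian product of its argument lists and yields square activation and kernel specs, one 10-tuple per combination. A minimal sketch of how a full-suite file might consume it; the list values here are illustrative, not taken from this patch.

```python
from tests.sweep_framework.sweep_utils.conv_transpose2d_common import get_input_specs

# Each yielded tuple is:
# (batch, act_h, act_w, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation)
specs = list(get_input_specs([1, 2], [32], [3], [1, 2], [0], [1]))
print(len(specs))  # 4 combinations (2 batch sizes x 2 strides)
print(specs[0])    # (1, 32, 32, 3, 3, 1, 1, 0, 0, 1)
```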
diff --git a/tests/sweep_framework/sweeps/conv_transpose2d/short/conv_transpose2d_short_sweep.py b/tests/sweep_framework/sweeps/conv_transpose2d/short/conv_transpose2d_short_sweep.py
new file mode 100644
index 00000000000..53b7c41ac2d
--- /dev/null
+++ b/tests/sweep_framework/sweeps/conv_transpose2d/short/conv_transpose2d_short_sweep.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple, List
+import os
+import itertools
+import random
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from models.utility_functions import torch_random
+from tests.sweep_framework.sweep_utils.conv_transpose2d_common import run_short, mesh_device_fixture
+
+parameters = {
+    "short_sweep_suite": {
+        "input_specs": [
+            # Contains following params
+            # [batch_size, input_channels, input_height, input_width, output_channels, kernel_height, kernel_width, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, out_pad_h, out_pad_w]
+            # [20, 16, 50, 100, 33, 3, 3, 2, 2, 0, 0, 1, 1, 0, 0], Batch size too big
+            [1, 16, 50, 100, 33, 3, 3, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 1024, 14, 14, 512, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 128, 112, 112, 64, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 128, 64, 64, 64, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 16, 14, 14, 1, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 256, 32, 32, 128, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 256, 56, 56, 128, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 4, 7, 7, 16, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 512, 16, 16, 256, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            # [1, 512, 28, 28, 256, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+            [1, 64, 128, 128, 32, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
+        ]
+    },
+}
+
+
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
+    return False, None
+
+
+def run(
+    input_specs,
+    *,
+    device,
+) -> list:
+    return run_short(
+        input_specs,
+        device,
+    )
+
+
+import pytest
+
+
+@pytest.mark.parametrize("input_spec", parameters["short_sweep_suite"]["input_specs"])
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_conv_transpose2d_localrun(device, input_spec):
+    run_short(
+        input_spec,
+        device,
+    )
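As a sanity check on the spec layout, the first entry above, [1, 16, 50, 100, 33, 3, 3, 2, 2, 0, 0, 1, 1, 0, 0], should produce a golden output of shape (1, 33, 101, 201), following H_out = (H_in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1, the same formula the op uses in conv_transpose2d.cpp further down. A standalone PyTorch check, independent of the sweep framework:

```python
import torch

# First short-sweep entry: batch 1, 16 -> 33 channels, 50x100 input,
# 3x3 kernel, stride 2, no padding, dilation 1, no output padding.
x = torch.randn(1, 16, 50, 100)
w = torch.randn(16, 33, 3, 3)  # conv_transpose2d weights: (in_channels, out_channels, kH, kW)
y = torch.nn.functional.conv_transpose2d(x, w, stride=2, padding=0, dilation=1, output_padding=0)
print(y.shape)  # torch.Size([1, 33, 101, 201])
```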
diff --git a/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py b/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py
index ee03cfb770e..699caa49e54 100644
--- a/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py
+++ b/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py
@@ -165,6 +165,8 @@ def run_conv_transpose2d(
     assert passing


+@skip_for_blackhole()
+@skip_for_grayskull()
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 64 * 1024}], indirect=True)
 @pytest.mark.parametrize(
     "batch_size, input_height, input_width, input_channels, output_channels, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, out_pad_h, out_pad_w, config, shard_layout",
@@ -173,26 +175,11 @@ def run_conv_transpose2d(
         (1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1, 0, 0, None, ttnn.TensorMemoryLayout.BLOCK_SHARDED),
         (1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1, 0, 0, None, ttnn.TensorMemoryLayout.BLOCK_SHARDED),
         (1, 256, 256, 32, 32, 3, 3, 1, 1, 1, 1, 0, 0, {"act_block_h": 64}, ttnn.TensorMemoryLayout.HEIGHT_SHARDED),
+        (1, 256, 256, 32, 32, 1, 1, 1, 1, 0, 0, 0, 0, {"act_block_h": 64}, ttnn.TensorMemoryLayout.HEIGHT_SHARDED),
         # Stride = 2
         (1, 8, 8, 32, 64, 3, 3, 2, 2, 1, 1, 1, 1, None, ttnn.TensorMemoryLayout.WIDTH_SHARDED),
         (1, 128, 128, 32, 64, 3, 3, 2, 2, 1, 1, 1, 1, {"act_block_h": 64}, ttnn.TensorMemoryLayout.HEIGHT_SHARDED),
-        (
-            1,
-            16,
-            16,
-            256,
-            256,
-            3,
-            3,
-            2,
-            2,
-            1,
-            1,
-            1,
-            1,
-            None,
-            ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-        ),  # Fails with error : act_block_w_datums == round_up(conv_act_size_c * filter_w, TILE_WIDTH)
+        (1, 16, 16, 256, 256, 3, 3, 2, 2, 1, 1, 1, 1, None, ttnn.TensorMemoryLayout.BLOCK_SHARDED),  #
         # (1, 16, 16, 32, 32, 3, 3, 2, 2, 1, 1, 0, 0, None, ttnn.TensorMemoryLayout.HEIGHT_SHARDED), # Issue with reading block sharded tensor
         # Vanilla Unet
         # Filter Size = 2 not supported in Block sharded
@@ -232,6 +219,8 @@ def test_simple_conv_t2d(
     config,
     shard_layout,
 ):
+    if device.core_grid.y != 8:
+        pytest.skip("Needs 8x8 Grid")
     run_conv_transpose2d(
         device,
         math_fidelity=ttnn.MathFidelity.HiFi4,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index ec94b4ce6e2..cfcb40d5e20 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -838,7 +838,7 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
     }
 }

-static void adjust_conv_op_config_for_auto_shard(
+void adjust_conv_op_config_for_auto_shard(
     bool is_mm_conv,
     uint32_t batch_size,
     uint32_t in_channels,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
index 2a5f22d8e15..58495de6fc1 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -126,6 +126,14 @@ uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelCo

 uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig);

+ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config(
+    OptimizedConvParallelizationConfig conv_parallelization_config,
+    OptimizedConvBlockConfig conv_blocking_config,
+    bool height_sharded,
+    const string& activation,
+    bool transpose_mcast,
+    uint32_t grid_size_along_c);
+
 MemoryConfig create_sharded_memory_config_from_parallel_config(const ttnn::Shape& tensor_shape, sliding_window::ParallelConfig& parallel_config, uint32_t tile_size);

 OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config(
@@ -157,6 +165,19 @@ std::tuple& bias_tensor);

 // Converts convolution weights to tilized 2d matrix layout.
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
index d904d1d7cb1..adcc7e39963 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
@@ -103,7 +103,8 @@ Result conv_transpose2d(
     std::array dilation,
     uint32_t groups,
     std::optional bias_tensor,
-    const std::optional& conv_config_) {
+    const std::optional& conv_config_,
+    const std::optional& memory_config) {
     conv2d::Conv2dConfig conv_config = conv_config_.value_or(conv2d::Conv2dConfig());

     //Inverse of sliding_window.get_output_shape()
@@ -118,6 +119,12 @@ Result conv_transpose2d(
         .is_transpose = true
     };

+
+    // ConvTranspose2d is implemented on top of the Conv2d micro-op with flipped weights.
+    // The input tensor is first passed to the halo op, which pads the input.
+    // When stride > 1, the halo op also inserts interleaved zeros into the input tensor.
+    // The Conv2d micro-op is then called with stride = 1 and padding = 0.
+    // SlidingWindowConfig has an is_transpose flag, set to true to indicate that the Conv2d and Halo micro-ops are being invoked for ConvTranspose2d.
     uint32_t output_height = (input_height - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1;
     uint32_t output_width = (input_width - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1;
@@ -143,6 +150,25 @@ Result conv_transpose2d(

     log_debug(LogOp, "Padding : ({},{}) ({},{})", input_pad_top, input_pad_bottom, input_pad_left, input_pad_right);

+    const bool mm_conv = conv2d::use_matmul_for_1x1_conv(kernel_size, {1, 1}, {input_pad_top + input_pad_bottom, input_pad_left + input_pad_right}, dilation, groups);
+
+    const auto compute_grid_size = device->compute_with_storage_grid_size();
+
+    if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) {
+        // In this case we deduce the shard layout.
+        conv2d::adjust_conv_op_config_for_auto_shard(
+            mm_conv,
+            batch_size,
+            in_channels,
+            out_channels,
+            output_height,
+            output_width,
+            weight_tensor.get_shape()[3],
+            full_input_width,
+            compute_grid_size,
+            conv_config,
+            input_tensor.layout());
+    }
     DeviceComputeKernelConfig compute_kernel_config;
     switch (device->arch()) {
@@ -171,8 +197,6 @@ Result conv_transpose2d(
         TT_THROW("Invalid Device Arch, Got {}",device->arch());
     }

-    const bool mm_conv = conv2d::use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
-
     //Call Halo Transpose
     auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = conv2d::shard_or_reshard_tensor_if_required(
         device,
@@ -262,6 +286,39 @@ Result conv_transpose2d(
             opt_conv_op_block_config.act_block_h_ntiles,
             input_width);
     }
+    if(mm_conv) {
+        // run conv as matmul
+        uint32_t num_cores_c = conv2d::get_num_cores_channels_from_parallel_config(parallel_config);
+        auto matmul_program_config = conv2d::determine_matmul_op_config_from_conv_op_config(
+            opt_conv_op_parallel_config,
+            opt_conv_op_block_config,
+            parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED,
+            conv_config.activation,
+            parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
+            num_cores_c);
+        Tensor matmul_input = ttnn::to_layout(
+            input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device
+        );
+        auto matmul_output = ttnn::operations::matmul::matmul(
+            matmul_input,
+            weight_tensor_on_device,
+            bias_tensor_on_device,
+            ttnn::operations::matmul::Matmul{
+                matmul_program_config,
+                /*bcast_batch=*/std::nullopt,
+                conv_out_memory_config,
+                conv_config.dtype,
+                compute_kernel_config});
+        if (conv_config.deallocate_activation) {
+            ttnn::operations::core::deallocate(matmul_input);
+        }
+
+        if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) {
+            matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt);
+        }
+
+        return {matmul_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
+    }
     // call conv micro op
     auto conv_output = optimized_conv_new(
         halo_output,
@@ -283,6 +340,9 @@ Result conv_transpose2d(
         conv_config.enable_act_double_buffer,
         conv_config.enable_split_reader,
         conv_config.enable_subblock_padding);
+    if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) {
+        conv_output = ttnn::to_memory_config(conv_output, memory_config.value(), std::nullopt);
+    }
     return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
 }

@@ -303,8 +363,9 @@ Result ConvTranpose2dOperation::invoke(
     std::array dilation,
     uint32_t groups,
     std::optional bias_tensor,
-    const std::optional& conv_config_){
-    return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_));
+    const std::optional& conv_config_,
+    const std::optional& memory_config) {
+    return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(memory_config));
 }

 Result ConvTranpose2dOperation::invoke(
@@ -324,8 +385,9 @@ Result ConvTranpose2dOperation::invoke(
     std::array dilation,
     uint32_t groups,
     std::optional bias_tensor,
-    const std::optional& conv_config_){
-    return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_));
+    const std::optional& conv_config_,
+    const std::optional& memory_config) {
+    return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(memory_config));
 }

 }
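To make the comment added to conv_transpose2d.cpp concrete: the "interleave zeros, then run a stride-1 convolution on the padded input with flipped weights" decomposition can be reproduced in plain PyTorch. This is only a sketch of the math the halo + Conv2d path relies on, not of the ttnn kernels themselves; shapes and values are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8)      # (N, C_in, H, W)
w = torch.randn(4, 6, 3, 3)      # conv_transpose2d weights: (C_in, C_out, kH, kW)
stride, pad, k = 2, 1, 3

# Reference result straight from PyTorch.
ref = F.conv_transpose2d(x, w, stride=stride, padding=pad)

# 1) Interleave (stride - 1) zeros between neighbouring input pixels.
n, c, h, wd = x.shape
up = torch.zeros(n, c, (h - 1) * stride + 1, (wd - 1) * stride + 1)
up[:, :, ::stride, ::stride] = x

# 2) Flip the kernel spatially, swap its channel axes, and run a plain
#    stride-1 convolution with (k - 1 - pad) padding on every side.
w_conv = w.transpose(0, 1).flip(-2, -1)          # -> (C_out, C_in, kH, kW)
out = F.conv2d(up, w_conv, stride=1, padding=k - 1 - pad)

print(torch.allclose(ref, out, atol=1e-4))       # True
```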
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp
index 587fcf9711d..437ea85f230 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp
@@ -32,7 +32,8 @@ struct ConvTranpose2dOperation{
         std::array dilation,
         uint32_t groups,
         std::optional bias_tensor = std::nullopt,
-        const std::optional& conv_config_ = std::nullopt);
+        const std::optional& conv_config_ = std::nullopt,
+        const std::optional& memory_config = std::nullopt);

     static Result invoke(
         uint8_t queue_id,
@@ -51,7 +52,8 @@ struct ConvTranpose2dOperation{
         std::array dilation,
         uint32_t groups,
         std::optional bias_tensor = std::nullopt,
-        const std::optional& conv_config_ = std::nullopt);
+        const std::optional& conv_config_ = std::nullopt,
+        const std::optional& memory_config = std::nullopt);
 };

 }  // namespace conv_transpose2d
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp
index 8128ea7423b..6cbb2f8555b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp
@@ -104,8 +104,9 @@ void py_bind_conv_transpose2d(py::module& module) {
                 uint32_t groups,
                 std::optional bias_tensor,
                 std::optional conv_config,
+                const std::optional memory_config,
                 const uint8_t& queue_id) -> Result {
-                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config);
+                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, memory_config);
             },
             py::kw_only(),
             py::arg("input_tensor"),
@@ -124,6 +125,7 @@ void py_bind_conv_transpose2d(py::module& module) {
             py::arg("groups"),
             py::arg("bias_tensor") = std::nullopt,
             py::arg("conv_config") = std::nullopt,
+            py::arg("memory_config") = std::nullopt,
             py::arg("queue_id") = 0},

         ttnn::pybind_overload_t{
@@ -143,8 +145,9 @@ void py_bind_conv_transpose2d(py::module& module) {
                 uint32_t groups,
                 std::optional bias_tensor,
                 std::optional conv_config,
+                const std::optional memory_config,
                 const uint8_t& queue_id) -> Result {
-                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config);
+                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, memory_config);
             },
             py::kw_only(),
             py::arg("input_tensor"),
@@ -163,6 +166,7 @@ void py_bind_conv_transpose2d(py::module& module) {
             py::arg("groups"),
             py::arg("bias_tensor") = std::nullopt,
             py::arg("conv_config") = std::nullopt,
+            py::arg("memory_config") = std::nullopt,
             py::arg("queue_id") = 0}
     );
 }
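Finally, a hedged sketch of how the new optional memory_config argument could be used from Python once these bindings land. It mirrors the run_short call in the sweep utilities; the shapes, the ttnn.DRAM_MEMORY_CONFIG choice, and the open_device parameters are illustrative assumptions rather than values taken from this patch.

```python
import torch
import ttnn

# Shapes mirror one of the short-sweep entries; values are illustrative.
batch, in_ch, out_ch, h, w = 1, 16, 33, 50, 100
torch_input = torch.randn(batch, h, w, in_ch, dtype=torch.bfloat16)   # NHWC
torch_weight = torch.randn(in_ch, out_ch, 3, 3, dtype=torch.bfloat16)

device = ttnn.open_device(device_id=0, l1_small_size=16384)
tt_input = ttnn.from_torch(torch_input, ttnn.bfloat16)
tt_weight = ttnn.from_torch(torch_weight, ttnn.bfloat16)

output, out_h, out_w, prepared_weight, prepared_bias = ttnn.conv_transpose2d(
    input_tensor=tt_input,
    weight_tensor=tt_weight,
    in_channels=in_ch,
    out_channels=out_ch,
    device=device,
    kernel_size=(3, 3),
    stride=(2, 2),
    padding=(0, 0),
    output_padding=(0, 0),
    dilation=(1, 1),
    batch_size=batch,
    input_height=h,
    input_width=w,
    groups=1,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,  # new optional argument added by this change
)
ttnn.close_device(device)
```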