diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py index 82f5dd1c03d5..fe044e07665a 100644 --- a/models/demos/vgg/tt/ttnn_vgg.py +++ b/models/demos/vgg/tt/ttnn_vgg.py @@ -114,6 +114,7 @@ def ttnn_vgg16( tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT) tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias + tt_bias = ttnn.to_layout(ttnn.from_device(tt_bias), layout=ttnn.ROW_MAJOR_LAYOUT) # Call ttnn.conv conv_op_cache = {} [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( @@ -242,6 +243,7 @@ def ttnn_vgg11( tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT) tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias + tt_bias = ttnn.to_layout(ttnn.from_device(tt_bias), layout=ttnn.ROW_MAJOR_LAYOUT) # Call ttnn.conv conv_op_cache = {} diff --git a/tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4.py b/tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4.py index 2adea2af2228..c70a2a8d3966 100644 --- a/tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4.py +++ b/tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4.py @@ -68,6 +68,7 @@ def test_yolov4(device, reset_seeds, model_location_generator): result_2 = result_2[:, :255, :, :] result_3 = result_3[:, :255, :, :] - assert_with_pcc(result_1, ref1, 0.99) - assert_with_pcc(result_2, ref2, 0.99) - assert_with_pcc(result_3, ref3, 0.99) + pcc = 0.985 + assert_with_pcc(result_1, ref1, pcc) + assert_with_pcc(result_2, ref2, pcc) + assert_with_pcc(result_3, ref3, pcc) diff --git a/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py b/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py index 09cafdd0aca3..d57f81748b5f 100644 --- a/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py +++ b/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py @@ -186,12 +186,141 @@ def test_prepare_conv_weights( compute_config=compute_config, ) + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + torch_output_tensor = ttnn.to_torch(tt_output_tensor) + torch_output_tensor = torch_output_tensor[:, :, :, :output_channels] + torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape) + + pcc = 0.99 + passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc) + logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}") + assert passing + + +@skip_for_grayskull() +@skip_for_blackhole() +# @skip_for_wormhole_b0() +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + ( + # rn50 layer1 + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + ), +) +@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) +@pytest.mark.parametrize("has_bias", [True, False], ids=["has_bias", "no_bias"]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 2**15}], indirect=True) +def test_prepare_bias( + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + packer_l1_acc, + config_override, + has_bias, + device, +): + if device.core_grid.y == 7: + pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") + + if batch_size == 20 and ( + output_channels == 64 or (stride_h == 2 and (output_channels == 256 or output_channels == 128)) + ): + pytest.skip("Skipping test because it won't fit in L1!") + + inp_shape = (batch_size, input_channels, input_height, input_width) + conv_weight_shape = (output_channels, input_channels, filter_height, filter_width) + torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16) + torch_input_tensor = torch.randn(inp_shape, dtype=torch.bfloat16) + torch_bias_tensor = torch.randn((1, 1, 1, output_channels), dtype=torch.bfloat16) if has_bias else None + + torch_out_golden_tensor = torch.nn.functional.conv2d( + torch_input_tensor, + torch_weight_tensor, + bias=torch_bias_tensor.reshape(-1) if has_bias else None, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(1, 1), + groups=1, + ).permute(0, 2, 3, 1) + + tt_input_tensor = ttnn.from_torch(torch_input_tensor.transpose(-3, -2).transpose(-2, -1), ttnn.bfloat16) + tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16) + tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16) if has_bias else None + + conv_config = ttnn.Conv2dConfig( + dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat16, + input_channels_alignment=(16 if input_channels == 16 and input_height == 115 else 32), + enable_act_double_buffer=False, + enable_split_reader=False, + enable_subblock_padding=False, + ) + compute_config = ttnn.init_device_compute_kernel_config(device.arch(), packer_l1_acc=packer_l1_acc) + if config_override and "act_block_h" in config_override: + conv_config.act_block_h_override = config_override["act_block_h"] + + if config_override and "act_block_w_div" in config_override: + conv_config.act_block_w_div = config_override["act_block_w_div"] + + if config_override and "num_cores_nhw" in config_override: + if config_override["num_cores_nhw"] == 98: + conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (11, 7)), ttnn.CoreRange((0, 8), (1, 8))}) + conv_config.override_sharding_config = True + print("Setting num_cores_nhw to 98") + + conv_kwargs = { + "input_layout": ttnn.ROW_MAJOR_LAYOUT, + "in_channels": input_channels, + "out_channels": output_channels, + "batch_size": batch_size, + "input_height": input_height, + "input_width": input_width, + "kernel_size": (filter_height, filter_width), + "stride": (stride_h, stride_w), + "padding": (pad_h, pad_w), + "dilation": (1, 1), + "groups": 1, + "device": device, + "conv_config": conv_config, + } + + tt_input_tensor = ttnn.to_device(tt_input_tensor, device) + + tt_bias_tensor_formatted = ( + ttnn.prepare_conv_bias( + bias_tensor=tt_bias_tensor, input_memory_config=tt_input_tensor.memory_config(), **conv_kwargs + ) + if has_bias + else None + ) + + tt_bias_tensor_formatted = ttnn.to_device(tt_bias_tensor_formatted, device) if has_bias else None + (k := next(iter(conv_kwargs)), conv_kwargs.pop(k)) ##removing 1st element from dict + tt_output_tensor_on_device = ttnn.conv2d( + input_tensor=tt_input_tensor, + weight_tensor=tt_weight_tensor, + bias_tensor=tt_bias_tensor_formatted, + **conv_kwargs, + compute_config=compute_config, + ) + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) torch_output_tensor = ttnn.to_torch(tt_output_tensor) torch_output_tensor = torch_output_tensor[:, :, :, :output_channels] torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape) - # + pcc = 0.99 passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc) logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}") diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 888bec60ee1f..2b7e7ffec432 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -79,12 +79,7 @@ Result conv2d( ShardOrientation shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - auto num_cores_c = shard_orientation == ShardOrientation::COL_MAJOR ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x; - auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2; - bool is_non_tile_mul_width = - (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 && - (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && - conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; + bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels); DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( device->arch(), diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 0154a972d4a7..7641ad3b4a06 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -16,33 +16,16 @@ using sliding_window::ParallelConfig; namespace conv2d { -void validate_weight_and_bias_tensors( - const ttnn::Tensor& weight_tensor, std::optional& bias_tensor) { - TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE)); - TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR); - TT_ASSERT(weight_tensor.get_shape().rank() == 4); - // TODO: enable this assert - // TT_ASSERT(weight_tensor.get_shape() == weight_tensor.get_legacy_shape()); - if (bias_tensor.has_value()) { - TT_ASSERT(!ttnn::has_storage_type_of(bias_tensor.value(), ttnn::DEVICE_STORAGE_TYPE)); - TT_ASSERT(bias_tensor.value().get_shape().rank() == 4); - TT_ASSERT(bias_tensor.value().get_layout() == Layout::ROW_MAJOR); - // TODO: enable this assert - // TT_ASSERT(bias_tensor.value().get_shape() == bias_tensor.value().get_legacy_shape()); - } -} - - void validate_weight_tensor(const ttnn::Tensor& weight_tensor) { - TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE)); - TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR); - TT_ASSERT(weight_tensor.get_shape().rank() == 4); + TT_FATAL(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE), "conv weight should be placed on host"); + TT_FATAL(weight_tensor.get_layout() == Layout::ROW_MAJOR, "conv weight layout should be in row_major layout"); + TT_FATAL(weight_tensor.get_shape().rank() == 4, "conv weight should be 4D tensor"); } void validate_bias_tensor(const ttnn::Tensor& bias_tensor) { - TT_ASSERT(!ttnn::has_storage_type_of(bias_tensor, ttnn::DEVICE_STORAGE_TYPE)); - TT_ASSERT(bias_tensor.get_shape().rank() == 4); - TT_ASSERT(bias_tensor.get_layout() == Layout::ROW_MAJOR); + TT_FATAL(!ttnn::has_storage_type_of(bias_tensor, ttnn::DEVICE_STORAGE_TYPE), "conv bias should be placed on host"); + TT_FATAL(bias_tensor.get_shape().rank() == 4, "bias tensor should be 4D tensor"); + TT_FATAL(bias_tensor.get_layout() == Layout::ROW_MAJOR, "bias tensor layout should be in row_major layout"); } void validate_weights_format(const std::string& weights_format) { @@ -54,6 +37,54 @@ void validate_weights_format(const std::string& weights_format) { TT_ASSERT(weights_format == "OIHW", "Conv2d weights format must be \"OIHW\""); } +template +bool check_non_tile_mul_width( + T *device, + const Conv2dConfig& conv_config, + const uint32_t in_channels +){ + ShardOrientation shard_orientation = + conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; + auto num_cores_c = shard_orientation == ShardOrientation::COL_MAJOR ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x; + auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2; + bool is_non_tile_mul_width = + (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 && + (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && + conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; + return is_non_tile_mul_width; +} + +template +ttnn::Tensor conv_bias_layout_convert( + const ttnn::Tensor& bias_tensor, + DataType bias_dtype, + uint32_t weight_block_h_ntiles, + uint32_t weight_block_w_ntiles, + const ParallelConfig& parallel_config, + T * device, + uint32_t out_channels, + bool is_non_tile_mul_width) { + ttnn::Tensor bias_tensor_ = bias_tensor; + validate_bias_tensor(bias_tensor_); + if (!is_non_tile_mul_width) { + auto bias_shape = bias_tensor_.get_shape(); + TT_FATAL(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1, "bias shape is not correct"); + tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( + std::array({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)})); + bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0); + bias_tensor_ = ttnn::to_layout( + bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); + if (bias_tensor_.get_dtype() != bias_dtype) { + bias_tensor_ = ttnn::to_dtype(bias_tensor_, bias_dtype); + } + } else { + uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config); + bias_tensor_ = convert_conv_bias_tensor_to_tiled_layout_block_sharded( + bias_tensor_, num_cores_channels, bias_dtype); + } + return bias_tensor_; +} + template OptimizedConvBlockConfig get_opt_block_config( bool mm_conv, @@ -162,7 +193,7 @@ std::pair> prepare_conv_weights_biases const bool parameters_on_device, bool is_non_tile_mul_width) { - validate_weight_and_bias_tensors(weight_tensor, bias_tensor); + validate_weight_tensor(weight_tensor); ttnn::Tensor weight_tensor_; // tensor to return ttnn::Tensor bias_tensor_; @@ -252,23 +283,12 @@ std::pair> prepare_conv_weights_biases weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt); if (bias_tensor.has_value()) { - if (!is_non_tile_mul_width) { - bias_tensor_ = bias_tensor.value(); - auto bias_shape = bias_tensor_.get_shape(); - TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1); - tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( - std::array({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)})); - bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0); - bias_tensor_ = ttnn::to_layout( - bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); - if (bias_tensor_.get_dtype() != weights_bias_dtype) { - bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype); - } - } else { - bias_tensor_ = convert_conv_bias_tensor_to_tiled_layout_block_sharded( - bias_tensor.value(), num_cores_channels, weights_bias_dtype); + bias_tensor_ = bias_tensor.value(); + bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_); + if(!is_bias_tensor_is_on_device) { + bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, parallel_config, device, out_channels, is_non_tile_mul_width); + bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); } - bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); } return {weight_tensor_, bias_tensor.has_value() ? bias_tensor_ : std::optional()}; @@ -342,6 +362,7 @@ ttnn::Tensor prepare_conv_weights( shard_orientation, !use_non_tile_height); + bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels); std::optional bias_tensor = std::nullopt; ttnn::Tensor weight_tensor_on_device = weight_tensor; std::optional bias_tensor_on_device = bias_tensor; @@ -357,7 +378,8 @@ ttnn::Tensor prepare_conv_weights( groups, opt_conv_op_block_config.act_block_h_ntiles, input_width, - false); + false, + is_non_tile_mul_width); return weight_tensor_on_device; } @@ -415,20 +437,36 @@ ttnn::Tensor prepare_conv_bias( ); uint32_t weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles; - validate_bias_tensor(bias_tensor); + ShardOrientation shard_orientation = + conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - ttnn::Tensor bias_tensor_; - bias_tensor_ = bias_tensor; - auto bias_shape = bias_tensor_.get_shape(); - TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1); - tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( - std::array({1, 1, 32, tt::round_up(out_channels, weight_block_w_ntiles * 32)})); - bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); - bias_tensor_ = ttnn::to_layout( - bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); - if (bias_tensor_.get_dtype() != conv_config.weights_dtype) { - bias_tensor_ = ttnn::to_dtype(bias_tensor_, conv_config.weights_dtype); - } + bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && + (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR; + use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; + + ParallelConfig parallel_config = determine_parallel_config( + conv_config.shard_layout.value(), + batch_size, + in_channels, + output_height, + output_width, + out_channels, + device->compute_with_storage_grid_size(), + shard_orientation, + !use_non_tile_height); + + bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels); + ttnn::Tensor bias_tensor_ = bias_tensor; + bias_tensor_ = conv_bias_layout_convert( + bias_tensor_, + conv_config.weights_dtype, + opt_conv_op_block_config.act_block_h_ntiles, + weight_block_w_ntiles, + parallel_config, + device, + out_channels, + is_non_tile_mul_width + ); return bias_tensor_; } @@ -568,6 +606,38 @@ template ttnn::Tensor prepare_conv_bias( const std::optional& conv_config_, const std::optional& compute_config_); +template ttnn::Tensor conv_bias_layout_convert( + const ttnn::Tensor& bias_tensor, + DataType bias_dtype, + uint32_t weight_block_h_ntiles, + uint32_t weight_block_w_ntiles, + const sliding_window::ParallelConfig& parallel_config, + Device * device, + uint32_t out_channels, + bool is_non_tile_mul_width); + +template ttnn::Tensor conv_bias_layout_convert( + const ttnn::Tensor& bias_tensor, + DataType bias_dtype, + uint32_t weight_block_h_ntiles, + uint32_t weight_block_w_ntiles, + const sliding_window::ParallelConfig& parallel_config, + MeshDevice* device, + uint32_t out_channels, + bool is_non_tile_mul_width); + +template bool check_non_tile_mul_width( + Device *device, + const Conv2dConfig& conv_config, + const uint32_t in_channels +); + +template bool check_non_tile_mul_width( + MeshDevice *device, + const Conv2dConfig& conv_config, + const uint32_t in_channels +); + } // namespace conv2d } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp index 35b80dac8241..946376226a00 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -26,6 +26,18 @@ namespace ttnn { namespace operations::conv { namespace conv2d { + +template +ttnn::Tensor conv_bias_layout_convert( + const ttnn::Tensor& bias_tensor, + DataType bias_dtype, + uint32_t weight_block_h_ntiles, + uint32_t weight_block_w_ntiles, + const sliding_window::ParallelConfig& parallel_config, + T * device, + uint32_t out_channels, + bool is_non_tile_mul_width); + template ttnn::Tensor prepare_conv_weights( const ttnn::Tensor& weight_tensor, @@ -81,6 +93,12 @@ std::pair> prepare_conv_weights_biases const bool parameters_on_device=true, bool is_non_tile_mul_width=false); +template +bool check_non_tile_mul_width( + T* device, + const Conv2dConfig& conv_config, + const uint32_t in_channels); + } // namespace conv2d } // namespace operations::conv } // namespace ttnn