diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py index c41f6be9092..f012c05782b 100644 --- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py @@ -1606,48 +1606,17 @@ def test_conv2d_localrun(device, input_spec): [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127 [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220 [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294 - [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1407 - [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1408 - [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1417 - [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1418 - [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1420 [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421 - [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1423 - [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1425 - [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1427 - [1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1431 - [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1442 [1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443 [1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447 - [1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1451 - [1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1452 - [1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1453 - [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1455 [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458 [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 1460 [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1461 [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1464 [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1471 [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1472 - [1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1473 - [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1474 - [1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1475 - [1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1479 - [1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1480 - [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1481 - [1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1491 [1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1495 [1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1496 - [1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1497 - [1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1499 - [1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1501 - [1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1502 - [1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1503 - [1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1506 - [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1507 - [1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1508 - [1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1509 - [1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1510 [1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1522 [1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1530 [1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1540 diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index d41c5deae4f..4d790730c16 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -2738,17 +2738,15 @@ def test_shallow_conv_with_tiled_input(device): pad = (1, 1) torch_kernel = torch.randn(kernel_shape, dtype=torch.bfloat16) - torch_input = torch.randn(input_shape, dtype=torch.bfloat16) - tt_kernel = ttnn.from_torch(torch_kernel) - tt_input = ttnn.to_device(ttnn.from_torch(torch_input), device) - tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT) + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + tt_input = ttnn.from_torch(torch_input, device=device) tt_input = ttnn.permute(tt_input, (0, 2, 3, 1)) - tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels)) + tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT) - [tt_out, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( + [tt_out, [out_height, out_width], [_, _]] = ttnn.conv2d( input_tensor=tt_input, weight_tensor=tt_kernel, in_channels=in_channels, @@ -2763,9 +2761,6 @@ def test_shallow_conv_with_tiled_input(device): input_height=img_h, input_width=img_w, groups=1, - compute_config=ttnn.init_device_compute_kernel_config( - device.arch(), - ), memory_config=ttnn.DRAM_MEMORY_CONFIG, return_output_dim=True, return_weights_and_bias=True, @@ -2788,3 +2783,72 @@ def test_shallow_conv_with_tiled_input(device): passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99) logger.info(f"PCC = {pcc_msg}. Threshold = 0.99") assert passing + + +# Tests running conv2d which maps to matmul w/o sharding the input tensor. +# Output tensor is in DRAM. +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("tiled_input", [True, False]) +@pytest.mark.parametrize("input_on_device", [True, False]) +def test_dram_input_mm_conv(device, tiled_input, input_on_device): + batch_size = 1 + out_channels, in_channels = 256, 1024 + img_h, img_w = 128, 128 + input_shape = (batch_size, in_channels, img_h, img_w) + + # Params which map conv2d to matmul op. + kernel_h, kernel_w = 1, 1 + stride = (1, 1) + dilation = (1, 1) + pad = (0, 0) + + kernel_shape = (out_channels, in_channels, kernel_h, kernel_w) + torch_kernel = torch.randn(kernel_shape, dtype=torch.bfloat16) + tt_kernel = ttnn.from_torch(torch_kernel) + + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + if input_on_device: + tt_input = ttnn.from_torch(torch_input, device=device) + tt_input = ttnn.permute(tt_input, (0, 2, 3, 1)) + tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels)) + else: + torch_input_nhwc = torch.permute(torch_input, (0, 2, 3, 1)) + tt_input = ttnn.from_torch(torch_input_nhwc) + + if tiled_input: + tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT) + + tt_out = ttnn.conv2d( + input_tensor=tt_input, + weight_tensor=tt_kernel, + in_channels=in_channels, + out_channels=out_channels, + device=device, + kernel_size=(kernel_h, kernel_w), + stride=stride, + padding=pad, + dilation=dilation, + batch_size=batch_size, + input_height=img_h, + input_width=img_w, + ) + + assert tt_out.memory_config().memory_layout == ttnn.TensorMemoryLayout.INTERLEAVED + + tt_output_tensor = ttnn.from_device(tt_out) + torch_output_tensor = ttnn.to_torch(tt_output_tensor) + + # torch_output_tensor is in row major layout and NHWC shape + # NHWC to NCHW + torch_output_tensor = torch_output_tensor.reshape(batch_size, img_h, img_w, torch_output_tensor.shape[-1]) + torch_output_tensor = torch_output_tensor[:, :, :, :out_channels] + + torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2)) + + torch_out_golden_tensor = torch.nn.functional.conv2d( + torch_input, torch_kernel, bias=None, stride=stride, padding=pad, dilation=dilation, groups=1 + ) + + passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99) + logger.info(f"PCC = {pcc_msg}. Threshold = 0.99") + assert passing diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index d6d06ec490f..bd754354e43 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -2,24 +2,25 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "conv2d.hpp" -#include "conv2d_utils.hpp" -#include "prepare_conv2d_weights.hpp" -#include -#include #include #include #include "common/constants.hpp" #include "common/math.hpp" -#include "impl/buffers/buffer_constants.hpp" -#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" -#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" -#include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/pool/downsample/device/downsample_op.hpp" -#include "ttnn/operations/sliding_window/sliding_window.hpp" + +#include "tt_metal/impl/buffers/buffer_constants.hpp" + #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/operations/conv/conv2d/conv2d.hpp" +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" +#include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp" +#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" +#include "ttnn/operations/matmul/matmul.hpp" +#include "ttnn/operations/sliding_window/halo/halo.hpp" +#include "ttnn/operations/sliding_window/sliding_window.hpp" using namespace tt; namespace ttnn { @@ -60,6 +61,7 @@ Result conv2d( Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); const auto compute_grid_size = device->compute_with_storage_grid_size(); + bool auto_shard = false; if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { // In this case we deduce the shard layout. adjust_conv_op_config_for_auto_shard_if_necessary( @@ -74,7 +76,9 @@ Result conv2d( compute_grid_size, conv_config, input_tensor.layout(), - ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) : std::nullopt); + ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) + : std::nullopt); + auto_shard = true; } ShardOrientation shard_orientation = @@ -86,16 +90,21 @@ Result conv2d( (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; - DeviceComputeKernelConfig compute_config = compute_config_.value_or( init_device_compute_kernel_config( - device->arch(), - std::nullopt, - MathFidelity::HiFi4, - true, - false, - false - )); - auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( - device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv, is_non_tile_mul_width); + DeviceComputeKernelConfig compute_config = compute_config_.value_or( + init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false)); + auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = + shard_or_reshard_tensor_if_required( + device, + input_tensor, + conv_config, + batch_size, + output_height, + output_width, + in_channels, + out_channels, + mm_conv, + auto_shard, + is_non_tile_mul_width); if (tensor_manipulated) { if (conv_config.deallocate_activation) { ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op @@ -255,36 +264,28 @@ Result conv2d( return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device}; } else { // run conv as matmul - uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); - auto matmul_program_config = determine_matmul_op_config_from_conv_op_config( - opt_conv_op_parallel_config, - opt_conv_op_block_config, - parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, - conv_config.activation, - parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, - num_cores_c); - Tensor matmul_input = input_tensor_post_tm; - if (stride[0] > 1) { - // run downsample - matmul_input = ttnn::operations::downsample::downsample( - input_tensor_post_tm, {batch_size, input_height, input_width, stride[0], stride[1]}); - if (conv_config.deallocate_activation) { - ttnn::operations::core::deallocate(input_tensor_post_tm); - } + std::optional program_config = std::nullopt; + std::optional mm_output_memory_config = std::nullopt; + if (input_tensor_post_tm.is_sharded()) { + uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); + program_config = determine_matmul_op_config_from_conv_op_config( + opt_conv_op_parallel_config, + opt_conv_op_block_config, + parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, + conv_config.activation, + parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, + num_cores_c); + mm_output_memory_config = conv_out_memory_config; } - auto matmul_output = ttnn::operations::matmul::matmul( - matmul_input, + Tensor matmul_output = ttnn::linear( + input_tensor_post_tm, weight_tensor_on_device, bias_tensor_on_device, - ttnn::operations::matmul::Matmul{ - matmul_program_config, - /*bcast_batch=*/std::nullopt, - conv_out_memory_config, - conv_config.dtype, - compute_config}); - if (conv_config.deallocate_activation) { - ttnn::operations::core::deallocate(matmul_input); - } + false, + false, + mm_output_memory_config, + std::nullopt, + program_config); if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) { matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index e8310c0dbdc..e39a1f2257b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -4,22 +4,11 @@ #pragma once #include -#include -#include "conv2d_utils.hpp" -#include "ttnn/core.hpp" -#include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/matmul/matmul.hpp" -#include "ttnn/operations/matmul/device/matmul_op.hpp" #include "ttnn/types.hpp" -#include "ttnn/tensor/tensor_utils.hpp" -#include "tt_metal/impl/dispatch/command_queue.hpp" -#include "tt_metal/common/math.hpp" -#include "ttnn/operations/data_movement/pad/pad.hpp" -#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" #include "ttnn/tensor/tensor.hpp" -#include "ttnn/operations/sliding_window/sliding_window.hpp" -#include "ttnn/operations/sliding_window/halo/halo.hpp" +#include "ttnn/decorators.hpp" +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" namespace ttnn { diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index bf215230584..122d1900610 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -8,14 +8,14 @@ #include "conv2d_utils.hpp" #include "common/constants.hpp" +#include "common/logger.hpp" #include "impl/buffers/buffer_constants.hpp" -#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/pool/downsample/device/downsample_op.hpp" -#include "tt_metal/detail/reports/memory_reporter.hpp" #include "tt_metal/common/work_split.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" +#include "ttnn/operations/data_movement/pad/pad.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" +#include "ttnn/tensor/enum_types.hpp" #include "ttnn/tensor/tensor.hpp" #include "tt_metal/common/core_coord.hpp" @@ -293,7 +293,7 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o }; } -std::pair determine_largest_subblock_size( +static std::pair determine_largest_subblock_size( uint32_t block_height, uint32_t block_width, bool fp32_accum, bool split_reader_enabled) { constexpr std::array, 20> subblocks = {{ {2, 4}, {4, 2}, {1, 8}, {8, 1}, {1, 7}, {7, 1}, {2, 3}, {3, 2}, {1, 6}, {6, 1}, @@ -468,7 +468,7 @@ static TensorMemoryLayout select_shard_spec( } template -std::tuple get_conv_padded_input_shape_and_mem_config( +static std::tuple get_conv_padded_input_shape_and_mem_config( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -641,6 +641,7 @@ std::tuple shard_or_re uint32_t in_channels, uint32_t out_channels, bool is_mm_conv, + bool auto_shard, bool is_non_tile_mul_width) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); @@ -695,6 +696,9 @@ std::tuple shard_or_re } } + // In case we are in auto sharded codepath and convolution maps to matmul + // Skip sharding of the input tensor and run the matmul out of interleaved tensor. + bool auto_shard_mm = auto_shard && is_mm_conv; if (input_tensor_on_device) { if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { @@ -702,24 +706,17 @@ std::tuple shard_or_re input_tensor = ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); } - auto resharded_input_tensor = ttnn::to_memory_config( - input_tensor, input_tensor_sharded_memory_config, std::nullopt); - if (conv_config.deallocate_activation) { - input_tensor.deallocate(); - resharded_input_tensor = ttnn::operations::core::reallocate(resharded_input_tensor, resharded_input_tensor.memory_config()); + if (!auto_shard_mm) { + auto resharded_input_tensor = ttnn::to_memory_config( + input_tensor, input_tensor_sharded_memory_config, std::nullopt); + if (conv_config.deallocate_activation) { + input_tensor.deallocate(); + resharded_input_tensor = ttnn::operations::core::reallocate(resharded_input_tensor, resharded_input_tensor.memory_config()); + } + input_tensor = resharded_input_tensor; } - input_tensor = resharded_input_tensor; } else { - if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && - parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { - // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout - input_tensor = ttnn::to_device(input_tensor, device, std::nullopt); - input_tensor = - ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); - input_tensor = ttnn::to_memory_config(input_tensor, input_tensor_sharded_memory_config, std::nullopt); - } else { - input_tensor = ttnn::to_device(input_tensor, device, input_tensor_sharded_memory_config); - } + input_tensor = ttnn::to_device(input_tensor, device, (auto_shard_mm ? ttnn::DRAM_MEMORY_CONFIG : input_tensor_sharded_memory_config)); } } return {input_tensor, parallel_config, output_parallel_config, needs_shard_or_reshard, use_non_tile_height}; @@ -848,31 +845,8 @@ void adjust_conv_op_config_for_auto_shard_if_necessary( } } -template std::tuple get_conv_padded_input_shape_and_mem_config( - Device* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width); - -template std::tuple get_conv_padded_input_shape_and_mem_config( - MeshDevice * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width); - -template std::tuple shard_or_reshard_tensor_if_required( +template std::tuple +shard_or_reshard_tensor_if_required( Device* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -882,10 +856,12 @@ template std::tuple sh uint32_t in_channels, uint32_t out_channels, bool is_mm_conv, + bool auto_shard, bool is_non_tile_mul_width); -template std::tuple shard_or_reshard_tensor_if_required( - MeshDevice * device, +template std::tuple +shard_or_reshard_tensor_if_required( + MeshDevice* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, @@ -894,6 +870,7 @@ template std::tuple sh uint32_t in_channels, uint32_t out_channel, bool is_mm_conv, + bool auto_shard, bool is_non_tile_mul_width); } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index 349e3837329..50c14f26750 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -4,22 +4,12 @@ #pragma once #include -#include -#include "ttnn/core.hpp" -#include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/matmul/matmul.hpp" #include "ttnn/operations/matmul/device/matmul_op.hpp" #include "ttnn/types.hpp" -#include "ttnn/tensor/tensor_utils.hpp" -#include "tt_metal/impl/dispatch/command_queue.hpp" -#include "tt_metal/common/math.hpp" -#include "ttnn/operations/data_movement/pad/pad.hpp" #include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/operations/sliding_window/sliding_window.hpp" -#include "ttnn/operations/sliding_window/halo/halo.hpp" -#include "tt_metal/common/core_coord.hpp" namespace ttnn { @@ -41,7 +31,7 @@ struct Conv2dConfig { // Ignored when shard_layout == HEIGHT_SHARDED or BLOCK_SHARDED bool reshard_if_not_optimal = false; // if true, override_sharding_config should not be set to true bool override_sharding_config = false; // if true, reshard_if_not_optimal should not be set to true - std::optional shard_layout; + std::optional shard_layout = std::nullopt; std::optional core_grid = std::nullopt; // used only if override_sharding_config is true bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false Layout output_layout = Layout::TILE; @@ -127,8 +117,6 @@ MemoryConfig create_sharded_memory_config_from_parallel_config(const ttnn::Shape OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); -std::pair determine_largest_subblock_size(uint32_t block_height, uint32_t block_width, bool fp32_accum); - ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config( OptimizedConvParallelizationConfig conv_parallelization_config, OptimizedConvBlockConfig conv_blocking_config, @@ -149,19 +137,6 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( bool fp32_accum, bool split_reader_enabled); -template -std::tuple get_conv_padded_input_shape_and_mem_config( - T * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width=false); - OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); @@ -180,7 +155,8 @@ void adjust_conv_op_config_for_auto_shard_if_necessary( std::optional input_memory_config); template -std::tuple shard_or_reshard_tensor_if_required( +std::tuple +shard_or_reshard_tensor_if_required( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -190,6 +166,7 @@ std::tuple -OptimizedConvBlockConfig get_opt_block_config( - bool mm_conv, - uint32_t in_channels, - uint32_t out_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t batch_size, - uint32_t input_width, - std::array kernel_size, - std::array stride, - T *device, - Layout input_tensor_layout, - Conv2dConfig& conv_config); - } // namespace operations::conv } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 1009ed7a87b..b15e3027c5e 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -2,11 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "prepare_conv2d_weights.hpp" -#include "conv2d_utils.hpp" +#include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp" + +#include "tt_metal/common/work_split.hpp" + +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" -#include -#include +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/data_movement/pad/pad.hpp" +#include "ttnn/operations/data_movement/reshape_view/reshape.hpp" +#include "ttnn/operations/sliding_window/sliding_window.hpp" using namespace tt; namespace ttnn { @@ -55,7 +60,7 @@ void validate_weights_format(const std::string& weights_format) { } template -OptimizedConvBlockConfig get_opt_block_config( +static OptimizedConvBlockConfig get_opt_block_config( bool mm_conv, uint32_t in_channels, uint32_t out_channels, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp index 35b80dac824..ad38e55e0ac 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -6,21 +6,13 @@ #pragma once #include -#include "conv2d_utils.hpp" -#include "ttnn/core.hpp" -#include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/matmul/matmul.hpp" -#include "ttnn/operations/matmul/device/matmul_op.hpp" #include "ttnn/types.hpp" -#include "ttnn/tensor/tensor_utils.hpp" -#include "tt_metal/impl/dispatch/command_queue.hpp" -#include "tt_metal/common/math.hpp" -#include "ttnn/operations/data_movement/pad/pad.hpp" -#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" #include "ttnn/tensor/tensor.hpp" -#include "ttnn/operations/sliding_window/sliding_window.hpp" -#include "ttnn/operations/sliding_window/halo/halo.hpp" -#include "tt_metal/common/work_split.hpp" +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" + +namespace ttnn::operations::sliding_window { + struct ParallelConfig; +} namespace ttnn { diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp index 21af1f921fb..3673d273344 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp @@ -2,14 +2,15 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "conv_transpose2d.hpp" +#include +#include + +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/matmul/matmul.hpp" +#include "ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp" #include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" #include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp" -#include "../conv2d/conv2d_utils.hpp" -#include -#include -#include -#include "common/bfloat16.hpp" +#include "ttnn/operations/sliding_window/halo/halo.hpp" using namespace tt; namespace ttnn { @@ -160,6 +161,7 @@ Result conv_transpose2d( const auto compute_grid_size = device->compute_with_storage_grid_size(); + bool auto_shard = false; if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { // In this case we deduce the shard layout. adjust_conv_op_config_for_auto_shard_if_necessary( @@ -174,29 +176,30 @@ Result conv_transpose2d( compute_grid_size, conv_config, input_tensor.layout(), - ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) : std::nullopt); + ttnn::is_tensor_on_device_or_multidevice(input_tensor) + ? std::make_optional(input_tensor.memory_config()) + : std::nullopt); + auto_shard = true; } //Call Halo Transpose - auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( - device, - input_tensor, - conv_config, - batch_size, - output_height, - output_width, - in_channels, - out_channels, - mm_conv + auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = + shard_or_reshard_tensor_if_required( + device, + input_tensor, + conv_config, + batch_size, + output_height, + output_width, + in_channels, + out_channels, + mm_conv, + auto_shard ); uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; - sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); - sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid; - sliding_window_config.snap_to_tile = !use_non_tile_height; - if (tensor_manipulated) { if (conv_config.deallocate_activation) { ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op @@ -206,24 +209,31 @@ Result conv_transpose2d( conv_config.deallocate_activation = true; } - auto halo_output = ttnn::halo( - DefaultQueueId, - input_tensor_post_tm, - sliding_window_config, - 0, - false, - parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, - 0, - input_tensor_post_tm.memory_config()); - - if(conv_config.deallocate_activation) { - input_tensor_post_tm.deallocate(); - log_debug(tt::LogOp, "Deallocate Input Tensor"); - } - if (conv_config.reallocate_halo_output) { - auto move_output = ttnn::operations::core::reallocate(halo_output, halo_output.memory_config()); - halo_output = move_output; - log_debug(tt::LogOp, "Reallocate Halo Output"); + Tensor halo_output; + if (!mm_conv) { + sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); + sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid; + sliding_window_config.snap_to_tile = !use_non_tile_height; + + halo_output = ttnn::halo( + DefaultQueueId, + input_tensor_post_tm, + sliding_window_config, + 0, + false, + parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, + 0, + input_tensor_post_tm.memory_config()); + + if(conv_config.deallocate_activation) { + input_tensor_post_tm.deallocate(); + log_debug(tt::LogOp, "Deallocate Input Tensor"); + } + if (conv_config.reallocate_halo_output) { + auto move_output = ttnn::operations::core::reallocate(halo_output, halo_output.memory_config()); + halo_output = move_output; + log_debug(tt::LogOp, "Reallocate Halo Output"); + } } //Call Conv2d u_op with Stride = 1, Padding = 0. @@ -242,13 +252,14 @@ Result conv_transpose2d( uint32_t in_channels_padded = tt::round_up( in_channels, - get_num_cores_channels_from_parallel_config(parallel_config) * - conv_config.input_channels_alignment); + get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); + uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) * + conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; auto opt_conv_op_block_config = determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, in_channels_padded, - (input_tensor_post_tm.shard_spec().value().shape[0] * get_num_cores_nhw_from_parallel_config(parallel_config)) / tt::constants::TILE_HEIGHT, + nhw_out_padded_ntile, conv_config.act_block_h_override, conv_config.act_block_w_div, kernel_size[0], @@ -261,7 +272,6 @@ Result conv_transpose2d( ttnn::Tensor weight_tensor_on_device = weight_tensor; std::optional bias_tensor_on_device = bias_tensor; if (!weight_is_on_device) { - // prepare weights in desired layout and move to device tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device( transform_weights_for_conv_transpose2d(weight_tensor), @@ -276,29 +286,33 @@ Result conv_transpose2d( opt_conv_op_block_config.act_block_h_ntiles, input_width); } - if(mm_conv) { - // run conv as matmul - uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); - auto matmul_program_config = determine_matmul_op_config_from_conv_op_config( - opt_conv_op_parallel_config, - opt_conv_op_block_config, - parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, - conv_config.activation, - parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, - num_cores_c); + if (mm_conv) { Tensor matmul_input = ttnn::to_layout( - halo_output, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device - ); - auto matmul_output = ttnn::operations::matmul::matmul( + input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device); + std::optional program_config = std::nullopt; + std::optional mm_output_memory_config = std::nullopt; + + if (matmul_input.is_sharded()) { + uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); + program_config = determine_matmul_op_config_from_conv_op_config( + opt_conv_op_parallel_config, + opt_conv_op_block_config, + parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, + conv_config.activation, + parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, + num_cores_c); + mm_output_memory_config = conv_out_memory_config; + } + Tensor matmul_output = ttnn::linear( matmul_input, weight_tensor_on_device, bias_tensor_on_device, - ttnn::operations::matmul::Matmul{ - matmul_program_config, - /*bcast_batch=*/std::nullopt, - conv_out_memory_config, - conv_config.dtype, - compute_config}); + false, + false, + mm_output_memory_config, + std::nullopt, + program_config); + if (conv_config.deallocate_activation) { ttnn::operations::core::deallocate(matmul_input); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp index fc23a6f52d6..3fbc4f21813 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp @@ -5,6 +5,7 @@ #pragma once #include #include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" +#include "ttnn/decorators.hpp" namespace ttnn { diff --git a/ttnn/cpp/ttnn/operations/pool/generic/generic_pools.cpp b/ttnn/cpp/ttnn/operations/pool/generic/generic_pools.cpp index 3efb657c506..7a4f9f9ad7b 100644 --- a/ttnn/cpp/ttnn/operations/pool/generic/generic_pools.cpp +++ b/ttnn/cpp/ttnn/operations/pool/generic/generic_pools.cpp @@ -6,10 +6,11 @@ #include "impl/buffers/buffer_constants.hpp" #include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/sliding_window/halo/halo.hpp" #include "ttnn/operations/sliding_window/sliding_window.hpp" #include "tt_metal/common/bfloat16.hpp" #include "tt_metal/common/math.hpp" -#include "ttnn/common/constants.hpp" #include