In case Conv2d maps to the matmul device op and conv2d is on the
auto-shard codepath (the input is not sharded and no sharding was
specified in Conv2dConfig), skip sharding the input and run matmul
with the input in an interleaved tensor. In this case the output
will be in an interleaved tensor as well.

This reduces L1 memory pressure and significantly improves the pass
rate on ttnn torch traces.
Pavle Josipovic committed Dec 10, 2024
1 parent 60373c8 commit 0631201
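
For illustration, a minimal sketch of the behavior this commit enables, mirroring the test added to test_new_conv2d.py below (shapes are taken from that test; the device id and the open/close calls are assumptions for a standalone run):

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # assumed device id

batch, in_ch, out_ch, img_h, img_w = 1, 1024, 256, 128, 128

# 1x1 kernel, stride 1, no padding: this conv2d maps to a plain matmul.
torch_input = torch.randn((batch, in_ch, img_h, img_w), dtype=torch.bfloat16)
torch_kernel = torch.randn((out_ch, in_ch, 1, 1), dtype=torch.bfloat16)

# NCHW -> NHWC, flattened to (1, 1, N*H*W, C), as the conv2d tests do.
tt_input = ttnn.from_torch(
    torch_input.permute(0, 2, 3, 1).reshape(1, 1, batch * img_h * img_w, in_ch)
)
tt_kernel = ttnn.from_torch(torch_kernel)

# The input is not sharded and no shard_layout is set in Conv2dConfig, so the
# auto-shard codepath now skips sharding and runs the matmul on the
# interleaved input.
tt_out = ttnn.conv2d(
    input_tensor=tt_input,
    weight_tensor=tt_kernel,
    in_channels=in_ch,
    out_channels=out_ch,
    device=device,
    kernel_size=(1, 1),
    stride=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    batch_size=batch,
    input_height=img_h,
    input_width=img_w,
)

# With this commit the output stays interleaved as well.
assert tt_out.memory_config().memory_layout == ttnn.TensorMemoryLayout.INTERLEAVED

ttnn.close_device(device)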
Showing 11 changed files with 248 additions and 273 deletions.
31 changes: 0 additions & 31 deletions tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -1606,48 +1606,17 @@ def test_conv2d_localrun(device, input_spec):
[1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127
[1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220
[1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294
[1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1407
[1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1408
[1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1417
[1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1418
[1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1420
[1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421
[1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1423
[1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1425
[1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1427
[1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1431
[1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1442
[1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443
[1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447
[1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1451
[1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1452
[1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1453
[1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1455
[1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 1460
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1461
[1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1464
[1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1471
[1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1472
[1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1473
[1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1474
[1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1475
[1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1479
[1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1480
[1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1481
[1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1491
[1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1495
[1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1496
[1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1497
[1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1499
[1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1501
[1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1502
[1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1503
[1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1506
[1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1507
[1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1508
[1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1509
[1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1510
[1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1522
[1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1530
[1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1540
82 changes: 73 additions & 9 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -2738,17 +2738,15 @@ def test_shallow_conv_with_tiled_input(device):
pad = (1, 1)

torch_kernel = torch.randn(kernel_shape, dtype=torch.bfloat16)
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)

tt_kernel = ttnn.from_torch(torch_kernel)
tt_input = ttnn.to_device(ttnn.from_torch(torch_input), device)

tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, device=device)
tt_input = ttnn.permute(tt_input, (0, 2, 3, 1))

tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels))
tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)

[tt_out, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
[tt_out, [out_height, out_width], [_, _]] = ttnn.conv2d(
input_tensor=tt_input,
weight_tensor=tt_kernel,
in_channels=in_channels,
@@ -2763,9 +2761,6 @@
input_height=img_h,
input_width=img_w,
groups=1,
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
),
memory_config=ttnn.DRAM_MEMORY_CONFIG,
return_output_dim=True,
return_weights_and_bias=True,
@@ -2788,3 +2783,72 @@
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99)
logger.info(f"PCC = {pcc_msg}. Threshold = 0.99")
assert passing


# Tests running conv2d which maps to matmul w/o sharding the input tensor.
# Output tensor is in DRAM.
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("tiled_input", [True, False])
@pytest.mark.parametrize("input_on_device", [True, False])
def test_dram_input_mm_conv(device, tiled_input, input_on_device):
batch_size = 1
out_channels, in_channels = 256, 1024
img_h, img_w = 128, 128
input_shape = (batch_size, in_channels, img_h, img_w)

# Params which map conv2d to matmul op.
kernel_h, kernel_w = 1, 1
stride = (1, 1)
dilation = (1, 1)
pad = (0, 0)

kernel_shape = (out_channels, in_channels, kernel_h, kernel_w)
torch_kernel = torch.randn(kernel_shape, dtype=torch.bfloat16)
tt_kernel = ttnn.from_torch(torch_kernel)

torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
if input_on_device:
tt_input = ttnn.from_torch(torch_input, device=device)
tt_input = ttnn.permute(tt_input, (0, 2, 3, 1))
tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels))
else:
torch_input_nhwc = torch.permute(torch_input, (0, 2, 3, 1))
tt_input = ttnn.from_torch(torch_input_nhwc)

if tiled_input:
tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)

tt_out = ttnn.conv2d(
input_tensor=tt_input,
weight_tensor=tt_kernel,
in_channels=in_channels,
out_channels=out_channels,
device=device,
kernel_size=(kernel_h, kernel_w),
stride=stride,
padding=pad,
dilation=dilation,
batch_size=batch_size,
input_height=img_h,
input_width=img_w,
)

assert tt_out.memory_config().memory_layout == ttnn.TensorMemoryLayout.INTERLEAVED

tt_output_tensor = ttnn.from_device(tt_out)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

# torch_output_tensor is in row major layout and NHWC shape
# NHWC to NCHW
torch_output_tensor = torch_output_tensor.reshape(batch_size, img_h, img_w, torch_output_tensor.shape[-1])
torch_output_tensor = torch_output_tensor[:, :, :, :out_channels]

torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input, torch_kernel, bias=None, stride=stride, padding=pad, dilation=dilation, groups=1
)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99)
logger.info(f"PCC = {pcc_msg}. Threshold = 0.99")
assert passing
99 changes: 50 additions & 49 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -2,24 +2,25 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "conv2d.hpp"
#include "conv2d_utils.hpp"
#include "prepare_conv2d_weights.hpp"
#include <sys/types.h>
#include <cstdint>
#include <optional>
#include <utility>

#include "common/constants.hpp"
#include "common/math.hpp"
#include "impl/buffers/buffer_constants.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"

#include "tt_metal/impl/buffers/buffer_constants.hpp"

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/conv/conv2d/conv2d.hpp"
#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp"
#include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/sliding_window/halo/halo.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"

using namespace tt;
namespace ttnn {
@@ -60,6 +61,7 @@ Result conv2d(
Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
const auto compute_grid_size = device->compute_with_storage_grid_size();

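// Set when the shard layout is deduced automatically (input not sharded and
// no shard_layout given in Conv2dConfig); shard_or_reshard_tensor_if_required
// uses it, together with mm_conv, to leave the input interleaved for the
// matmul path.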
bool auto_shard = false;
if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) {
// In this case we deduce the shard layout.
adjust_conv_op_config_for_auto_shard_if_necessary(
@@ -74,7 +76,9 @@
compute_grid_size,
conv_config,
input_tensor.layout(),
ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) : std::nullopt);
ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config())
: std::nullopt);
auto_shard = true;
}

ShardOrientation shard_orientation =
@@ -86,16 +90,21 @@
(conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) &&
conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0;

DeviceComputeKernelConfig compute_config = compute_config_.value_or( init_device_compute_kernel_config(
device->arch(),
std::nullopt,
MathFidelity::HiFi4,
true,
false,
false
));
auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required(
device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv, is_non_tile_mul_width);
DeviceComputeKernelConfig compute_config = compute_config_.value_or(
init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false));
auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] =
shard_or_reshard_tensor_if_required(
device,
input_tensor,
conv_config,
batch_size,
output_height,
output_width,
in_channels,
out_channels,
mm_conv,
auto_shard,
is_non_tile_mul_width);
if (tensor_manipulated) {
if (conv_config.deallocate_activation) {
ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op
@@ -255,36 +264,28 @@ Result conv2d(
return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
} else {
// run conv as matmul
uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config);
auto matmul_program_config = determine_matmul_op_config_from_conv_op_config(
opt_conv_op_parallel_config,
opt_conv_op_block_config,
parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED,
conv_config.activation,
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
num_cores_c);
Tensor matmul_input = input_tensor_post_tm;
if (stride[0] > 1) {
// run downsample
matmul_input = ttnn::operations::downsample::downsample(
input_tensor_post_tm, {batch_size, input_height, input_width, stride[0], stride[1]});
if (conv_config.deallocate_activation) {
ttnn::operations::core::deallocate(input_tensor_post_tm);
}
std::optional<ttnn::operations::matmul::MatmulProgramConfig> program_config = std::nullopt;
std::optional<MemoryConfig> mm_output_memory_config = std::nullopt;
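// With an interleaved (auto-shard) input both stay unset, ttnn::linear falls
// back to its defaults, and the output remains interleaved. Only a sharded
// input gets the conv-derived matmul program config and a sharded output
// memory config.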
if (input_tensor_post_tm.is_sharded()) {
uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config);
program_config = determine_matmul_op_config_from_conv_op_config(
opt_conv_op_parallel_config,
opt_conv_op_block_config,
parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED,
conv_config.activation,
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
num_cores_c);
mm_output_memory_config = conv_out_memory_config;
}
auto matmul_output = ttnn::operations::matmul::matmul(
matmul_input,
Tensor matmul_output = ttnn::linear(
input_tensor_post_tm,
weight_tensor_on_device,
bias_tensor_on_device,
ttnn::operations::matmul::Matmul{
matmul_program_config,
/*bcast_batch=*/std::nullopt,
conv_out_memory_config,
conv_config.dtype,
compute_config});
if (conv_config.deallocate_activation) {
ttnn::operations::core::deallocate(matmul_input);
}
false,
false,
mm_output_memory_config,
std::nullopt,
program_config);

if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) {
matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt);
15 changes: 2 additions & 13 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -4,22 +4,11 @@

#pragma once
#include <optional>
#include <unordered_set>

#include "conv2d_utils.hpp"
#include "ttnn/core.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/matmul/device/matmul_op.hpp"
#include "ttnn/types.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "tt_metal/impl/dispatch/command_queue.hpp"
#include "tt_metal/common/math.hpp"
#include "ttnn/operations/data_movement/pad/pad.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"
#include "ttnn/operations/sliding_window/halo/halo.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp"

namespace ttnn {

(Diffs for the remaining 7 changed files are not shown.)
