
Commit

Revert "#16012: Revert conv2d changes because of perf regressions, pc…
Browse files Browse the repository at this point in the history
…c regressions, and increase in runtime (#16019)"

This reverts commit 44ea114.
Pavle Josipovic committed Dec 14, 2024
1 parent a3801c4 commit be9426c
Showing 12 changed files with 248 additions and 273 deletions.
2 changes: 1 addition & 1 deletion models/demos/yolov4/tests/test_perf_yolo.py
@@ -23,7 +23,7 @@


def get_expected_times():
return (40, 16.2)
return (40, 16.9)


@pytest.mark.models_performance_bare_metal
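For context, the loosened second bound (16.2 → 16.9) is consistent with the "increase in runtime" named in the reverted commit's title. A hedged sketch of how a two-tuple from get_expected_times() is typically consumed in a perf test; the harness below is illustrative, not the actual test_perf_yolo.py code.

# Hedged sketch: assumes the tuple is (max compile time, max inference time)
# in seconds and that the test asserts measured values stay under the bounds.
def get_expected_times():
    return (40, 16.9)  # the bound this commit relaxes from 16.2

def assert_within_bounds(measured_compile, measured_inference):
    expected_compile, expected_inference = get_expected_times()
    assert measured_compile < expected_compile, (
        f"compile took {measured_compile:.2f}s, bound is {expected_compile}s"
    )
    assert measured_inference < expected_inference, (
        f"inference took {measured_inference:.2f}s, bound is {expected_inference}s"
    )

# A run measuring 16.5s fails against the old 16.2 bound but passes 16.9.
assert_within_bounds(measured_compile=35.0, measured_inference=16.5)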
31 changes: 0 additions & 31 deletions tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -1623,48 +1623,17 @@ def test_conv2d_localrun(device, input_spec):
[1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127
[1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220
[1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294
[1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1407
[1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1408
[1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1417
[1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1418
[1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1420
[1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421
[1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1423
[1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1425
[1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1427
[1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1431
[1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1442
[1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443
[1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447
[1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1451
[1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1452
[1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1453
[1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1455
[1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 1460
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1461
[1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1464
[1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1471
[1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1472
[1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1473
[1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1474
[1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1475
[1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1479
[1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1480
[1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1481
[1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1491
[1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1495
[1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1496
[1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1497
[1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1499
[1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1501
[1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1502
[1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1503
[1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1506
[1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1507
[1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1508
[1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1509
[1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1510
[1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1522
[1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1530
[1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1540
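Each row deleted above encodes one conv2d case as a flat parameter vector. The decoding below is a hedged sketch: the field order is inferred from the values in this diff (for example, the kernel == stride, zero-pad rows like # 1458 match ViT-style patch embeddings), not taken from conv2d_short_sweep.py itself.

# Hedged sketch: decode one removed row into named conv2d parameters.
from dataclasses import dataclass

@dataclass
class Conv2dSpec:
    batch: int
    out_channels: int
    in_channels: int
    input_height: int
    input_width: int
    kernel_h: int
    kernel_w: int
    stride_h: int
    stride_w: int
    pad_h: int
    pad_w: int
    groups: int
    has_bias: bool
    dilation: int

# Row # 1458 above: a 32x32 patch-embedding style conv (kernel == stride, no pad).
spec = Conv2dSpec(*[1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1])
out_h = (spec.input_height + 2 * spec.pad_h
         - spec.dilation * (spec.kernel_h - 1) - 1) // spec.stride_h + 1
print(spec.out_channels, out_h)  # 1024 7, i.e. a 7x7 grid of 32x32 patches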
82 changes: 73 additions & 9 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -2738,17 +2738,15 @@ def test_shallow_conv_with_tiled_input(device):
pad = (1, 1)

torch_kernel = torch.randn(kernel_shape, dtype=torch.bfloat16)
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)

tt_kernel = ttnn.from_torch(torch_kernel)
tt_input = ttnn.to_device(ttnn.from_torch(torch_input), device)

tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, device=device)
tt_input = ttnn.permute(tt_input, (0, 2, 3, 1))

tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels))
tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)

[tt_out, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
[tt_out, [out_height, out_width], [_, _]] = ttnn.conv2d(
input_tensor=tt_input,
weight_tensor=tt_kernel,
in_channels=in_channels,
@@ -2763,9 +2761,6 @@ def test_shallow_conv_with_tiled_input(device):
input_height=img_h,
input_width=img_w,
groups=1,
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
),
memory_config=ttnn.DRAM_MEMORY_CONFIG,
return_output_dim=True,
return_weights_and_bias=True,
@@ -2788,3 +2783,72 @@ def test_shallow_conv_with_tiled_input(device):
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99)
logger.info(f"PCC = {pcc_msg}. Threshold = 0.99")
assert passing


# Tests running conv2d which maps to matmul w/o sharding the input tensor.
# Output tensor is in DRAM.
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("tiled_input", [True, False])
@pytest.mark.parametrize("input_on_device", [True, False])
def test_dram_input_mm_conv(device, tiled_input, input_on_device):
batch_size = 1
out_channels, in_channels = 256, 1024
img_h, img_w = 128, 128
input_shape = (batch_size, in_channels, img_h, img_w)

# Params which map conv2d to matmul op.
kernel_h, kernel_w = 1, 1
stride = (1, 1)
dilation = (1, 1)
pad = (0, 0)

kernel_shape = (out_channels, in_channels, kernel_h, kernel_w)
torch_kernel = torch.randn(kernel_shape, dtype=torch.bfloat16)
tt_kernel = ttnn.from_torch(torch_kernel)

torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
if input_on_device:
tt_input = ttnn.from_torch(torch_input, device=device)
tt_input = ttnn.permute(tt_input, (0, 2, 3, 1))
tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels))
else:
torch_input_nhwc = torch.permute(torch_input, (0, 2, 3, 1))
tt_input = ttnn.from_torch(torch_input_nhwc)

if tiled_input:
tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)

tt_out = ttnn.conv2d(
input_tensor=tt_input,
weight_tensor=tt_kernel,
in_channels=in_channels,
out_channels=out_channels,
device=device,
kernel_size=(kernel_h, kernel_w),
stride=stride,
padding=pad,
dilation=dilation,
batch_size=batch_size,
input_height=img_h,
input_width=img_w,
)

assert tt_out.memory_config().memory_layout == ttnn.TensorMemoryLayout.INTERLEAVED

tt_output_tensor = ttnn.from_device(tt_out)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

# torch_output_tensor is in row major layout and NHWC shape
# NHWC to NCHW
torch_output_tensor = torch_output_tensor.reshape(batch_size, img_h, img_w, torch_output_tensor.shape[-1])
torch_output_tensor = torch_output_tensor[:, :, :, :out_channels]

torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input, torch_kernel, bias=None, stride=stride, padding=pad, dilation=dilation, groups=1
)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99)
logger.info(f"PCC = {pcc_msg}. Threshold = 0.99")
assert passing
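
The new test covers the case where conv2d degenerates to a matmul: with a 1x1 kernel, stride 1, and no padding, each output pixel is a plain channel-mixing product. A torch-only check of that equivalence, independent of ttnn; the shapes here are illustrative.

# Hedged sketch: a 1x1 / stride-1 / pad-0 conv2d equals a matmul over
# flattened pixels, the property the test above relies on.
import torch

n, c_in, c_out, h, w = 1, 64, 32, 8, 8
x = torch.randn(n, c_in, h, w)
k = torch.randn(c_out, c_in, 1, 1)

conv_out = torch.nn.functional.conv2d(x, k)  # (n, c_out, h, w)

# Flatten NCHW -> (n*h*w, c_in), multiply by (c_in, c_out), restore NCHW;
# the same permute/reshape dance the test performs with ttnn ops.
x_flat = x.permute(0, 2, 3, 1).reshape(n * h * w, c_in)
mm_out = (x_flat @ k.reshape(c_out, c_in).T).reshape(n, h, w, c_out).permute(0, 3, 1, 2)

assert torch.allclose(conv_out, mm_out, atol=1e-4)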
99 changes: 50 additions & 49 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -2,24 +2,25 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "conv2d.hpp"
#include "conv2d_utils.hpp"
#include "prepare_conv2d_weights.hpp"
#include <sys/types.h>
#include <cstdint>
#include <optional>
#include <utility>

#include "common/constants.hpp"
#include "common/math.hpp"
#include "impl/buffers/buffer_constants.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"

#include "tt_metal/impl/buffers/buffer_constants.hpp"

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/conv/conv2d/conv2d.hpp"
#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp"
#include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/sliding_window/halo/halo.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"

using namespace tt;
namespace ttnn {
@@ -60,6 +61,7 @@ Result conv2d(
Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
const auto compute_grid_size = device->compute_with_storage_grid_size();

bool auto_shard = false;
if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) {
// In this case we deduce the shard layout.
adjust_conv_op_config_for_auto_shard_if_necessary(
@@ -74,7 +76,9 @@
compute_grid_size,
conv_config,
input_tensor.layout(),
ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) : std::nullopt);
ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config())
: std::nullopt);
auto_shard = true;
}

ShardOrientation shard_orientation =
@@ -86,16 +90,21 @@
(conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) &&
conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0;

DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config(
device->arch(),
std::nullopt,
MathFidelity::HiFi4,
true,
false,
false
));
auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required(
device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv, is_non_tile_mul_width);
DeviceComputeKernelConfig compute_config = compute_config_.value_or(
init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false));
auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] =
shard_or_reshard_tensor_if_required(
device,
input_tensor,
conv_config,
batch_size,
output_height,
output_width,
in_channels,
out_channels,
mm_conv,
auto_shard,
is_non_tile_mul_width);
if (tensor_manipulated) {
if (conv_config.deallocate_activation) {
ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op
@@ -255,36 +264,28 @@ Result conv2d(
return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
} else {
// run conv as matmul
uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config);
auto matmul_program_config = determine_matmul_op_config_from_conv_op_config(
opt_conv_op_parallel_config,
opt_conv_op_block_config,
parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED,
conv_config.activation,
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
num_cores_c);
Tensor matmul_input = input_tensor_post_tm;
if (stride[0] > 1) {
// run downsample
matmul_input = ttnn::operations::downsample::downsample(
input_tensor_post_tm, {batch_size, input_height, input_width, stride[0], stride[1]});
if (conv_config.deallocate_activation) {
ttnn::operations::core::deallocate(input_tensor_post_tm);
}
std::optional<ttnn::operations::matmul::MatmulProgramConfig> program_config = std::nullopt;
std::optional<MemoryConfig> mm_output_memory_config = std::nullopt;
if (input_tensor_post_tm.is_sharded()) {
uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config);
program_config = determine_matmul_op_config_from_conv_op_config(
opt_conv_op_parallel_config,
opt_conv_op_block_config,
parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED,
conv_config.activation,
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
num_cores_c);
mm_output_memory_config = conv_out_memory_config;
}
auto matmul_output = ttnn::operations::matmul::matmul(
matmul_input,
Tensor matmul_output = ttnn::linear(
input_tensor_post_tm,
weight_tensor_on_device,
bias_tensor_on_device,
ttnn::operations::matmul::Matmul{
matmul_program_config,
/*bcast_batch=*/std::nullopt,
conv_out_memory_config,
conv_config.dtype,
compute_config});
if (conv_config.deallocate_activation) {
ttnn::operations::core::deallocate(matmul_input);
}
false,
false,
mm_output_memory_config,
std::nullopt,
program_config);

if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) {
matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt);
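Two effects of this hunk are worth spelling out: the new auto_shard flag records that the op deduced the shard layout itself (it is set only when the caller neither sharded the input nor set conv_config.shard_layout), and the conv-as-matmul branch now routes through ttnn::linear, building a sharded program config only when the input is actually sharded, which is what lets DRAM-interleaved inputs stay interleaved. A hedged Python-side sketch of the two call shapes; the device setup is illustrative and the Conv2dConfig spelling is an assumption, not confirmed by this diff.

# Hedged sketch: the two configuration paths the C++ change distinguishes.
# Requires a Tenstorrent device; tensor setup mirrors test_dram_input_mm_conv.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# A 1x1 conv that maps to matmul, input pre-flattened to (1, 1, N*H*W, C).
torch_kernel = torch.randn((256, 1024, 1, 1), dtype=torch.bfloat16)
tt_kernel = ttnn.from_torch(torch_kernel)
torch_input = torch.randn((1, 1, 128 * 128, 1024), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, device=device)

common = dict(
    input_tensor=tt_input, weight_tensor=tt_kernel,
    in_channels=1024, out_channels=256, device=device,
    kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), dilation=(1, 1),
    batch_size=1, input_height=128, input_width=128,
)

# Path 1: no shard_layout and an unsharded input -> the op deduces a layout
# (the new auto_shard flag in this hunk is set on this path).
out_auto = ttnn.conv2d(**common)

# Path 2: the caller pins the layout explicitly -> auto_shard stays false.
# Conv2dConfig field spelling is an assumption for illustration.
conv_config = ttnn.Conv2dConfig()
conv_config.shard_layout = ttnn.TensorMemoryLayout.HEIGHT_SHARDED
out_pinned = ttnn.conv2d(**common, conv_config=conv_config)

ttnn.close_device(device)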
15 changes: 2 additions & 13 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -4,22 +4,11 @@

#pragma once
#include <optional>
#include <unordered_set>

#include "conv2d_utils.hpp"
#include "ttnn/core.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/matmul/device/matmul_op.hpp"
#include "ttnn/types.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "tt_metal/impl/dispatch/command_queue.hpp"
#include "tt_metal/common/math.hpp"
#include "ttnn/operations/data_movement/pad/pad.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"
#include "ttnn/operations/sliding_window/halo/halo.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp"

namespace ttnn {
