Conv2D with Input & Output in DRAM #13229

Open
wants to merge 7 commits into main
150 changes: 150 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -318,6 +318,156 @@ def run_conv_with_split(


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("stride", [1])
@pytest.mark.parametrize(
    "output_channels, input_channels, input_height, input_width, filter_height, filter_width, pad_h, pad_w, act_block_w_div",
    (
        (64, 32, 130, 130, 3, 3, 0, 0, 1),
        (64, 32, 128, 128, 3, 3, 1, 1, 1),
        (64, 32, 1024, 1024, 3, 3, 1, 1, 1),
    ),
)
@pytest.mark.parametrize(
    "has_bias",
    [True],
)
@pytest.mark.parametrize(
    "weights_dtype",
    [ttnn.bfloat16],
)
@pytest.mark.parametrize(
    "activations_dtype",
    [ttnn.bfloat16],
)
def test_conv_dram(
    device,
    use_program_cache,
    output_channels,
    input_channels,
    input_height,
    input_width,
    filter_height,
    filter_width,
    pad_h,
    pad_w,
    act_block_w_div,
    stride,
    has_bias,
    weights_dtype,
    activations_dtype,
):
    if is_grayskull():
        if input_channels >= 2048:
            pytest.skip("Skipping on grayskull due to insufficient L1")
        if input_channels >= 768 and input_height >= 10:
            pytest.skip("Skipping on grayskull due to insufficient L1")

    stride_h = stride
    stride_w = stride
    batch_size = 2
    fp32_accum = False
    packer_l1_acc = False
    deallocate_activation = False
    debug = False
    groups = 1

    torch.manual_seed(0)
    conv_input_shape = [batch_size, input_channels, input_height, input_width]
    conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
    conv_bias_shape = [1, 1, 1, output_channels]

    torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
    torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))

    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()

    tt_bias_tensor = None
    torch_bias_tensor = None
    if has_bias:
        torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() * 50
        tt_bias_tensor = ttnn.from_torch(
            torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
        )
        torch_bias_tensor = torch_bias_tensor.reshape(-1)
    ref = torch.nn.functional.conv2d(
        torch_input_tensor_nchw,
        torch_weight_tensor,
        bias=torch_bias_tensor,
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        groups=groups,
    )
    output_shape_nhwc = [
        ref.shape[0],
        ref.shape[2],
        ref.shape[3],
        ref.shape[1],
    ]

    reader_patterns_cache = {}
    tt_weight_tensor = ttnn.from_torch(
        torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
    )

    tt_input_tensor = ttnn.from_torch(torch_input_tensor, device=device, dtype=ttnn.bfloat16)

    conv_config = ttnn.Conv2dConfig(
        dtype=activations_dtype,
        weights_dtype=weights_dtype,
        math_fidelity=ttnn.MathFidelity.HiFi4,
        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        input_channels_alignment=32,
        deallocate_activation=deallocate_activation,
        fp32_dest_acc_enabled=fp32_accum,
        packer_l1_accum_enabled=packer_l1_acc,
        enable_act_double_buffer=False,
        enable_split_reader=False,
        enable_subblock_padding=False,
        reshard_if_not_optimal=True,
        act_block_w_div=act_block_w_div,
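        # Compute the output in 64-row slices that fit in L1: conv2d slices the
        # DRAM input per output slice and concatenates the partial outputs back
        # in DRAM (see the output_height_in_l1 handling in conv2d.cpp below).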
        output_height_in_l1=64,
        act_block_h_override=64,
    )
    [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
        input_tensor=tt_input_tensor,
        weight_tensor=tt_weight_tensor,
        in_channels=input_channels,
        out_channels=output_channels,
        device=device,
        bias_tensor=tt_bias_tensor,
        kernel_size=(filter_height, filter_width),
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        batch_size=batch_size,
        input_height=input_height,
        input_width=input_width,
        conv_config=conv_config,
        conv_op_cache=reader_patterns_cache,
        debug=debug,
        groups=groups,
    )

    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
    out = tt_output_tensor.cpu().to_torch()

    # out is row-major and NHWC; reshape, then permute NHWC -> NCHW to match ref
    out = out.reshape(batch_size, out_height, out_width, output_channels)
    out = torch.permute(out, (0, 3, 1, 2))
    reader_patterns_cache.clear()

    pcc = 0.94
    passing, pcc_msg = check_with_pcc_without_tensor_printout(out, ref, pcc=pcc)
    logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
    if not passing:
        logger.error(f"Fails with PCC {pcc_msg}")
    assert passing


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize(
104 changes: 103 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -5,12 +5,20 @@
#include "conv2d.hpp"
#include <sys/types.h>
#include <cstdint>
#include <optional>

#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
#include "tt_metal/detail/reports/memory_reporter.hpp"
#include "ttnn/operations/core/to_dtype/to_dtype_op.hpp"
#include "tt_metal/common/work_split.hpp"
#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/operations/data_movement/concat/concat.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/tensor/types.hpp"

using namespace tt;
namespace ttnn {
@@ -671,9 +679,103 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>>
    std::optional<const ttnn::Tensor> bias_tensor,
    std::optional<const Conv2dConfig> conv_config_) {

    uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
    uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[1] - 1) * (dilation[1] - 1)) + 2 * padding[1]) / stride[1]) + 1;
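    // Worked example (the 128x128 test case): k = 3, pad = 1, stride = 1,
    // dilation = 1 gives ((128 - 3 - 0 + 2) / 1) + 1 = 128, the usual "same"
    // convolution output size.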

    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
    if (conv_config.output_height_in_l1 > 0) {
        TT_FATAL((conv_config.output_height_in_l1 % 32) == 0, "Output height in L1 must be a multiple of 32, got {}", conv_config.output_height_in_l1);
        ttnn::Tensor input_tensor_on_device;
        if (!is_tensor_on_device_or_multidevice(input_tensor)) {
            input_tensor_on_device = ttnn::operations::core::to_device(input_tensor, device, std::nullopt);
        } else {
            input_tensor_on_device = input_tensor;
        }
        ttnn::Tensor weight_tensor_on_device;
        std::optional<ttnn::Tensor> bias_tensor_on_device;
        if (input_tensor_on_device.memory_config().is_dram()) {
            Tensor dram_output_tensor;
            bool first_run = true;
            for (uint32_t batch_index = 0; batch_index < batch_size; batch_index++) {
                for (uint32_t output_slice_height_start = 0; output_slice_height_start < output_height;
                     output_slice_height_start += conv_config.output_height_in_l1) {
                    uint32_t output_slice_height_end = std::min(output_height, output_slice_height_start + conv_config.output_height_in_l1);
                    uint32_t output_slice_height = output_slice_height_end - output_slice_height_start;

                    if (output_slice_height == 0) {
                        continue;
                    }

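                    // Map this output-row slice to the input rows it reads (a
                    // stride/dilation-aware halo), splitting out-of-range rows
                    // into explicit top/bottom zero padding. E.g. for the
                    // 128x128, k=3, pad=1, stride=1 case, slice [0, 64):
                    // start = 0 - 1 = -1 and end = 63 - 1 + 3 = 65, so
                    // pad_top = 1, pad_bottom = 0, and input rows [0, 65) are read.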
                    int input_slice_height_start = output_slice_height_start * stride[0] - padding[0];
                    int input_slice_height_end = (output_slice_height_end - 1) * stride[0] - padding[0] + (kernel_size[0] - 1) * (dilation[0] - 1) + kernel_size[0];
                    int pad_top = std::max(0, -input_slice_height_start);
                    int pad_bottom = std::max<int>(0, input_slice_height_end - (int)input_height);
                    input_slice_height_start = std::max(0, input_slice_height_start);
                    input_slice_height_end = std::min<int>(input_height, input_slice_height_end);
                    uint32_t input_slice_height = input_slice_height_end - input_slice_height_start;
                    log_debug(tt::LogOp, "Output Slice : {}->{}", output_slice_height_start, output_slice_height_end);
                    log_debug(tt::LogOp, "Input Slice : {}->{}", input_slice_height_start, input_slice_height_end);
                    log_debug(tt::LogOp, "Padding : {}->{}", pad_top, pad_bottom);

                    if (input_slice_height_start < input_slice_height_end) {
                        auto sliced_input_tensor = ttnn::slice(
                            input_tensor_on_device,
                            std::array<uint32_t, 4>{batch_index, (uint32_t)input_slice_height_start, 0, 0},  // start
                            std::array<uint32_t, 4>{batch_index, (uint32_t)(input_slice_height_end - 1), input_width - 1, in_channels - 1},  // end, inclusive
                            std::array<uint32_t, 4>{1, 1, 1, 1});  // step
                        log_debug(tt::LogOp, "Sliced input tensor shape: {}", sliced_input_tensor.get_shape());
                        if (pad_top > 0 || pad_bottom > 0) {
                            sliced_input_tensor = ttnn::pad(
                                DefaultQueueId,
                                sliced_input_tensor,
                                std::vector<std::pair<uint32_t, uint32_t>>{{0, 0}, {(uint32_t)pad_top, (uint32_t)pad_bottom}, {0, 0}, {0, 0}},
                                0, true, std::nullopt);
                        }
                        log_debug(tt::LogOp, "Padded sliced input tensor shape: {}", sliced_input_tensor.get_shape());
                        auto conv_config_l1 = conv_config;
                        conv_config_l1.output_height_in_l1 = 0;
                        conv_config_l1.reshard_if_not_optimal = true;
                        ttnn::Tensor sliced_output_tensor;
                        std::tie(sliced_output_tensor, std::ignore, std::ignore, weight_tensor_on_device, bias_tensor_on_device) = conv2d(
Contributor (review comment):
1. This function is already too long. You should not add this much new code into it. Please refactor and have new code in a new function.
2. Do not make this function recursive. Please have multiple levels instead.
                            sliced_input_tensor,
                            first_run ? weight_tensor : weight_tensor_on_device,
                            device,
                            in_channels,
                            out_channels,
                            1, input_slice_height + pad_top + pad_bottom, input_width,
                            kernel_size, stride, {0, padding[1]}, dilation,
                            groups,
                            first_run ? bias_tensor : (std::optional<const ttnn::Tensor>)(bias_tensor_on_device),
                            conv_config_l1);
                        sliced_output_tensor = ttnn::to_layout(
                            sliced_output_tensor, Layout::ROW_MAJOR, std::nullopt,
                            MemoryConfig{
                                .memory_layout = TensorMemoryLayout::INTERLEAVED,
                                .buffer_type = BufferType::L1,
                            },
                            device);

                        if (first_run) {
                            dram_output_tensor = sliced_output_tensor;
                        } else {
                            dram_output_tensor = ttnn::concat(
                                std::vector<ttnn::Tensor>{dram_output_tensor, sliced_output_tensor}, 2,
                                MemoryConfig{
                                    .memory_layout = TensorMemoryLayout::INTERLEAVED,
                                    .buffer_type = BufferType::DRAM,
                                });
                        }
                        log_debug(tt::LogOp, "Dram output tensor shape: {}", dram_output_tensor.get_shape());
                        first_run = false;
                    }
                }
            }
            return {dram_output_tensor, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
        }
    }
    auto [input_tensor_post_tm, parallel_config, tensor_manipulated] = shard_or_reshard_tensor_if_required(
        device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels);
    if (tensor_manipulated) {
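Not part of the diff: to make the review suggestion above concrete, here is a minimal standalone sketch (plain C++17, no ttnn dependency, all names illustrative) that mirrors the PR's slice/padding arithmetic with the planning pulled into its own helper, the kind of non-recursive decomposition the comment asks for.

// Toy slice planner: compile with `g++ -std=c++17 slice_plan.cpp && ./a.out`.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct SlicePlan {
    uint32_t out_start, out_end;  // output rows [out_start, out_end)
    int in_start, in_end;         // clamped input rows [in_start, in_end)
    int pad_top, pad_bottom;      // zero-padding rows added around the slice
};

// Maps one output-row slice to the input rows (plus padding) it needs,
// using the same arithmetic as the new DRAM path in conv2d.cpp.
static SlicePlan plan_slice(uint32_t out_start, uint32_t out_end,
                            uint32_t input_height, uint32_t kernel, uint32_t stride,
                            uint32_t pad, uint32_t dilation) {
    int in_start = (int)(out_start * stride) - (int)pad;
    int in_end = (int)((out_end - 1) * stride) - (int)pad + (int)((kernel - 1) * (dilation - 1)) + (int)kernel;
    return SlicePlan{out_start, out_end,
                     std::max(0, in_start), std::min<int>((int)input_height, in_end),
                     std::max(0, -in_start), std::max<int>(0, in_end - (int)input_height)};
}

int main() {
    // The 1024x1024 test case: k=3, pad=1, stride=1, dilation=1, 64 output rows per L1 slice.
    const uint32_t output_height = 1024, input_height = 1024, slice_rows = 64;
    for (uint32_t out_start = 0; out_start < output_height; out_start += slice_rows) {
        uint32_t out_end = std::min(output_height, out_start + slice_rows);
        SlicePlan p = plan_slice(out_start, out_end, input_height, 3, 1, 1, 1);
        std::printf("out [%u,%u) <- in [%d,%d) pad (%d,%d)\n",
                    p.out_start, p.out_end, p.in_start, p.in_end, p.pad_top, p.pad_bottom);
    }
    return 0;
}

A driver loop like main() above, calling a pure planning helper and then a factored-out single-slice convolution, would let conv2d() stay thin and dispatch instead of re-entering itself.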
7 changes: 5 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -47,6 +47,7 @@ struct Conv2dConfig {
    bool enable_act_double_buffer = false;
    bool enable_split_reader = false;
    bool enable_subblock_padding = false;
    uint32_t output_height_in_l1 = 0;
    static constexpr auto attribute_names = std::make_tuple(
        "math_fidelity",
        "dtype",
@@ -68,7 +69,8 @@
        "output_layout",
        "enable_act_double_buffer",
        "enable_split_reader",
        "enable_subblock_padding",
        "output_height_in_l1");
    const auto attribute_values() const {
        return std::make_tuple(
            std::cref(this->math_fidelity),
@@ -91,7 +93,8 @@
            std::cref(this->output_layout),
            std::cref(this->enable_act_double_buffer),
            std::cref(this->enable_split_reader),
            std::cref(this->enable_subblock_padding),
            std::cref(this->output_height_in_l1));
    }
};

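For reference (not in the diff), how the new field would be set from C++; semantics are taken from the conv2d.cpp changes above: the default of 0 keeps the existing all-in-L1 behavior, while a nonzero value must be a multiple of 32 and bounds how many output rows each slice materializes in L1.

// Sketch: opt in to the DRAM streaming path via the new config field.
Conv2dConfig conv_config;
conv_config.output_height_in_l1 = 64;  // 0 disables slicing; nonzero must be a multiple of 32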
7 changes: 5 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -186,7 +186,7 @@ void py_bind_conv2d(py::module& module) {

    auto py_conv_config = py::class_<Conv2dConfig>(module, "Conv2dConfig");
    py_conv_config.def(
        py::init<MathFidelity, DataType, DataType, bool, bool, bool, string, uint32_t, bool, bool, uint32_t, uint32_t, bool, bool, TensorMemoryLayout, std::optional<CoreRangeSet>, bool, Layout, bool, bool, bool, uint32_t>(),
        py::kw_only(),
        py::arg("math_fidelity") = MathFidelity::HiFi4,
        py::arg("dtype") = DataType::BFLOAT16,
@@ -208,7 +208,8 @@
        py::arg("output_layout") = Layout::TILE,
        py::arg("enable_act_double_buffer") = false,
        py::arg("enable_split_reader") = false,
        py::arg("enable_subblock_padding") = false,
        py::arg("output_height_in_l1") = 0
    );
py_conv_config.def_readwrite("math_fidelity", &Conv2dConfig::math_fidelity);
py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype);
Expand All @@ -231,6 +232,8 @@ void py_bind_conv2d(py::module& module) {
py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer);
py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);
py_conv_config.def_readwrite("output_height_in_l1", &Conv2dConfig::output_height_in_l1);


py::class_<OptimizedConvParallelizationConfig>(module, "OptimizedConvParallelizationConfig")
.def(
@@ -123,7 +123,7 @@ Tensor concat_impl(std::vector<Tensor> &input_tensors, const std::int64_t dim, c
"Current concat implementation requires aligned last dim when concatting on last dim");
}
}
Layout target_layout = Layout::TILE;
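    // Default changed from TILE, presumably so the row-major NHWC slices
    // produced by the new conv2d DRAM path concat without a layout conversion.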
    Layout target_layout = Layout::ROW_MAJOR;
    for (const auto &input_tensor : input_tensors) {
        if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
            const auto &input_shape = input_tensor.get_legacy_shape();