#0: DRAM Conv w/o padding
sankarmanoj-tt authored and mywoodstock committed Sep 27, 2024
1 parent 4afe632 commit 25092f2
Showing 6 changed files with 282 additions and 25 deletions.
37 changes: 37 additions & 0 deletions CMakeCache.txt
@@ -0,0 +1,37 @@
# This is the CMakeCache file.
# For build in directory: /localdev/smanoj/tt-metal
# It was generated by CMake: /usr/bin/cmake
# You can edit this file to change values found and used by cmake.
# If you do not want to change any of the values, simply exit the editor.
# If you do want to change a value, simply edit, save, and exit the editor.
# The syntax for the file is as follows:
# KEY:TYPE=VALUE
# KEY is the name of a variable in the cache.
# TYPE is a hint to GUIs for the type of VALUE, DO NOT EDIT TYPE!.
# VALUE is the current value for the KEY.

########################
# EXTERNAL cache entries
########################


########################
# INTERNAL cache entries
########################

//This is the directory where this CMakeCache.txt was created
CMAKE_CACHEFILE_DIR:INTERNAL=/localdev/smanoj/tt-metal
//Major version of cmake used to create the current loaded cache
CMAKE_CACHE_MAJOR_VERSION:INTERNAL=3
//Minor version of cmake used to create the current loaded cache
CMAKE_CACHE_MINOR_VERSION:INTERNAL=16
//Patch version of cmake used to create the current loaded cache
CMAKE_CACHE_PATCH_VERSION:INTERNAL=3
//Path to CMake executable.
CMAKE_COMMAND:INTERNAL=/usr/bin/cmake
//Path to cpack program executable.
CMAKE_CPACK_COMMAND:INTERNAL=/usr/bin/cpack
//Path to ctest program executable.
CMAKE_CTEST_COMMAND:INTERNAL=/usr/bin/ctest
//Path to CMake installation.
CMAKE_ROOT:INTERNAL=/usr/share/cmake-3.16
1 change: 1 addition & 0 deletions CMakeFiles/cmake.check_cache
@@ -0,0 +1 @@
# This file is generated by cmake for dependency checking of the CMakeCache.txt file
145 changes: 145 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -318,6 +318,151 @@ def run_conv_with_split(


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("stride", [1])
@pytest.mark.parametrize(
    "output_channels, input_channels, input_height, input_width, filter_height, filter_width, pad_h, pad_w, act_block_w_div",
    ((64, 32, 130, 130, 3, 3, 0, 0, 1),),
)
@pytest.mark.parametrize(
    "has_bias",
    [True],
)
@pytest.mark.parametrize(
    "weights_dtype",
    [ttnn.bfloat16],
)
@pytest.mark.parametrize(
    "activations_dtype",
    [ttnn.bfloat16],
)
def test_conv_dram(
    device,
    use_program_cache,
    output_channels,
    input_channels,
    input_height,
    input_width,
    filter_height,
    filter_width,
    pad_h,
    pad_w,
    act_block_w_div,
    stride,
    has_bias,
    weights_dtype,
    activations_dtype,
):
    if is_grayskull():
        if input_channels >= 2048:
            pytest.skip("Skipping on grayskull due to insufficient L1")
        if input_channels >= 768 and input_height >= 10:
            pytest.skip("Skipping on grayskull due to insufficient L1")

    stride_h = stride
    stride_w = stride
    batch_size = 2
    fp32_accum = False
    packer_l1_acc = False
    deallocate_activation = False
    debug = False
    groups = 1

    torch.manual_seed(0)
    conv_input_shape = [batch_size, input_channels, input_height, input_width]
    conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
    conv_bias_shape = [1, 1, 1, output_channels]

    torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
    torch_input_tensor_nchw = torch_input_tensor_nchw.broadcast_to(conv_input_shape).float()
    torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))

    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()

    tt_bias_tensor = None
    torch_bias_tensor = None
    if has_bias:
        torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() * 50
        tt_bias_tensor = ttnn.from_torch(
            torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
        )
        torch_bias_tensor = torch_bias_tensor.reshape(-1)
    ref = torch.nn.functional.conv2d(
        torch_input_tensor_nchw,
        torch_weight_tensor,
        bias=torch_bias_tensor,
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        groups=groups,
    )
    output_shape_nhwc = [
        ref.shape[0],
        ref.shape[2],
        ref.shape[3],
        ref.shape[1],
    ]

    reader_patterns_cache = {}
    tt_weight_tensor = ttnn.from_torch(
        torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
    )

    tt_input_tensor = ttnn.from_torch(torch_input_tensor, device=device, dtype=ttnn.bfloat16)

    conv_config = ttnn.Conv2dConfig(
        dtype=activations_dtype,
        weights_dtype=weights_dtype,
        math_fidelity=ttnn.MathFidelity.HiFi4,
        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        input_channels_alignment=32,
        deallocate_activation=deallocate_activation,
        fp32_dest_acc_enabled=fp32_accum,
        packer_l1_accum_enabled=packer_l1_acc,
        enable_act_double_buffer=False,
        enable_split_reader=False,
        enable_subblock_padding=False,
        reshard_if_not_optimal=True,
        act_block_w_div=act_block_w_div,
        output_height_in_l1=64,
    )
    [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
        input_tensor=tt_input_tensor,
        weight_tensor=tt_weight_tensor,
        in_channels=input_channels,
        out_channels=output_channels,
        device=device,
        bias_tensor=tt_bias_tensor,
        kernel_size=(filter_height, filter_width),
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        batch_size=batch_size,
        input_height=input_height,
        input_width=input_width,
        conv_config=conv_config,
        conv_op_cache=reader_patterns_cache,
        debug=debug,
        groups=groups,
    )

    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
    out = tt_output_tensor.cpu().to_torch()

    # out is in row major layout and NHWC shape
    # NHWC to NCHW
    # ref = torch.permute(ref, (0, 2, 3, 1))
    out = out.reshape(batch_size, out_height, out_width, output_channels)

    out = torch.permute(out, (0, 3, 1, 2))
    reader_patterns_cache.clear()

    pcc = 0.94
    passing, pcc_msg = check_with_pcc_without_tensor_printout(out, ref, pcc=pcc)
    logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
    if not passing:
        logger.error(f"Fails with PCC {pcc_msg}")
    assert passing


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize(
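
A quick worked example (editorial sketch, not part of the commit) of the height-slice arithmetic the test_conv_dram case above exercises, mirroring the slice formulas in the conv2d.cpp change below. Parameters follow the test: 130x130 input, 3x3 kernel, stride 1, pad 0, dilation 1, output_height_in_l1=64.

# Sketch only: variable names are illustrative, not ttnn API.
input_height, kernel_h, stride_h, pad_h, dilation_h = 130, 3, 1, 0, 1
output_height_in_l1 = 64
output_height = (input_height - kernel_h - (kernel_h - 1) * (dilation_h - 1) + 2 * pad_h) // stride_h + 1  # 128
for out_start in range(0, output_height, output_height_in_l1):
    out_end = min(output_height, out_start + output_height_in_l1)
    in_start = out_start * stride_h - pad_h
    in_end = (out_end - 1) * stride_h - pad_h + (kernel_h - 1) * (dilation_h - 1) + kernel_h
    print(f"output rows [{out_start}, {out_end}) read input rows [{in_start}, {in_end})")
# output rows [0, 64)   read input rows [0, 66)
# output rows [64, 128) read input rows [64, 130)
# With pad_h = 0, no slice needs extra top/bottom padding, hence "DRAM Conv w/o padding".
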
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_slice.py
@@ -128,7 +128,7 @@ def slice_test(
)

tt_output_tensor = ttnn.slice(tt_input_tensor, output_tensor_start, output_tensor_end, memory_config=out_mem_config)

print(tt_input_tensor.shape, output_tensor_start, output_tensor_end, tt_output_tensor.shape)
a_pt = ttnn.to_torch(tt_output_tensor)

# Pytorch reference
120 changes: 97 additions & 23 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -5,13 +5,20 @@
#include "conv2d.hpp"
#include <sys/types.h>
#include <cstdint>
#include <optional>

#include "common/assert.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
#include "tt_metal/detail/reports/memory_reporter.hpp"
#include "ttnn/operations/core/to_dtype/to_dtype_op.hpp"
#include "tt_metal/common/work_split.hpp"
#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/operations/data_movement/concat/concat.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/tensor/types.hpp"

using namespace tt;
namespace ttnn {
@@ -676,32 +683,99 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[1] - 1) * (dilation[1] - 1)) + 2 * padding[1]) / stride[1]) + 1;

Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
if(conv_config.output_height_in_l1 > 0)
{
TT_FATAL(conv_config.output_height_in_l1%32,"Input height in L1 must be a multiple of 32");
if(input_tensor.memory_config().is_dram())
if(conv_config.output_height_in_l1 > 0) {
TT_FATAL((conv_config.output_height_in_l1 % 32) == 0, "Output height in L1 must be a multiple of 32, got {}", conv_config.output_height_in_l1);
ttnn::Tensor input_tensor_on_device;
if( !is_tensor_on_device_or_multidevice(input_tensor)) {
input_tensor_on_device = ttnn::operations::core::to_device(input_tensor, device, std::nullopt);
}
else
{
for(uint32_t output_slice_height_start = 0; output_slice_height_start < output_height; output_slice_height_start+=conv_config.output_height_in_l1)
{
uint32_t output_slice_height_end = std::min(output_height, output_slice_height_start + conv_config.output_height_in_l1);
uint32_t output_slice_height = output_slice_height_end - output_slice_height_start;

int input_slice_height_start = output_slice_height_start * stride[0] - padding[0];
int input_slice_height_end = (output_slice_height_end - 1) * stride[0] - padding[0] + (kernel_size[0] - 1) * (dilation[0] - 1) + kernel_size[0];
// auto conv_config_l1 = conv_config;
// conv_config_l1.output_height_in_l1 = conv_config.output_height_in_l1;

// //Slice a part of input tensor along the input height dimension and move it to L1.
// auto input_tensor_l1 = dram_slice_to_l1_sharded(input_tensor, input_height_slice, conv_config.output_height_in_l1);

// [output_tensor_l1, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device] =
// conv2d(input_tensor_l1, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_l1);

// l1_to_dram_slice(output_tensor_l1, output_slice_height_start);
input_tensor_on_device = input_tensor;
}
ttnn::Tensor weight_tensor_on_device;
std::optional<ttnn::Tensor> bias_tensor_on_device;
if(input_tensor_on_device.memory_config().is_dram()) {
Tensor dram_output_tensor;
bool first_run = true;
for(int batch_index = 0; batch_index < batch_size; batch_index++) {
for(uint32_t output_slice_height_start = 0; output_slice_height_start < output_height; output_slice_height_start+=conv_config.output_height_in_l1) {
uint32_t output_slice_height_end = std::min(output_height, output_slice_height_start + conv_config.output_height_in_l1);
uint32_t output_slice_height = output_slice_height_end - output_slice_height_start;

if(output_slice_height == 0) {
continue;
}

int input_slice_height_start = output_slice_height_start * stride[0] - padding[0];
int input_slice_height_end = (output_slice_height_end - 1) * stride[0] - padding[0] + (kernel_size[0] - 1) * (dilation[0] - 1) + kernel_size[0];
int pad_top = std::max(0, -input_slice_height_start);
int pad_bottom = std::max<int>(0, input_slice_height_end - input_height);
input_slice_height_start = std::max(0, input_slice_height_start);
input_slice_height_end = std::min<int>(input_height, input_slice_height_end);
uint32_t input_slice_height = input_slice_height_end - input_slice_height_start;
log_debug(tt::LogOp, "Output Slice : {}->{}", output_slice_height_start, output_slice_height_end);
log_debug(tt::LogOp, "Input Slice : {}->{}", input_slice_height_start, input_slice_height_end);
log_debug(tt::LogOp, "Padding : {}->{}", pad_top, pad_bottom);

if(input_slice_height_start < input_slice_height_end) {
auto sliced_input_tensor = ttnn::slice(input_tensor,
std::array<uint32_t, 4>{batch_index, input_slice_height_start, 0, 0}, //Start
std::array<uint32_t, 4>{batch_index, input_slice_height_end - 1, input_width - 1,in_channels - 1}, //End - Inclusive
std::array<uint32_t, 4>{1, 1, 1, 1} //Step
);
log_debug(tt::LogOp, "Sliced input tensor shape: {}", sliced_input_tensor.get_shape());
if(pad_top>0)
{
auto pad_top_tensor = ttnn::pad(
sliced_input_tensor,
tt::tt_metal::Array4D({1, input_slice_height + pad_top, input_width, in_channels}),
tt::tt_metal::Array4D({0, 0, 0, 0}),
0);
sliced_input_tensor = pad_top_tensor;
}
log_debug(tt::LogOp, "Padded sliced input tensor shape: {}", sliced_input_tensor.get_shape());
auto conv_config_l1 = conv_config;
conv_config_l1.output_height_in_l1 = 0;
conv_config_l1.reshard_if_not_optimal = true;
ttnn::Tensor sliced_output_tensor;
std::tie(sliced_output_tensor, std::ignore, std::ignore, weight_tensor_on_device, bias_tensor_on_device) = conv2d(
sliced_input_tensor,
first_run ? weight_tensor : weight_tensor_on_device,
device,
in_channels,
out_channels,
1, input_slice_height + pad_top, input_width,
kernel_size, stride, {0,padding[1]}, dilation,
groups,
first_run ? bias_tensor : (std::optional<const ttnn::Tensor>)(bias_tensor_on_device),
conv_config_l1
);
sliced_output_tensor = ttnn::to_layout(sliced_output_tensor, Layout::ROW_MAJOR, std::nullopt,
MemoryConfig {
.memory_layout=TensorMemoryLayout::INTERLEAVED,
.buffer_type=BufferType::L1,
}, device);

if(first_run) {
dram_output_tensor = sliced_output_tensor;
}
else {
dram_output_tensor = ttnn::concat(
std::vector<ttnn::Tensor>{dram_output_tensor, sliced_output_tensor}, 2,
MemoryConfig{
.memory_layout=TensorMemoryLayout::INTERLEAVED,
.buffer_type=BufferType::DRAM,
});
}
log_debug(tt::LogOp, "Dram output tensor shape: {}", dram_output_tensor.get_shape());
first_run = false;
}
}
}
return {dram_output_tensor, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
}
}

auto [input_tensor_post_tm, parallel_config, tensor_manipulated] = shard_or_reshard_tensor_if_required(
device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels);
if (tensor_manipulated) {
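
Taken together, the new branch in conv2d.cpp handles a DRAM-resident input by slicing it along the output-height axis: for each slice it computes the input row range it needs, pads the top of the slice when padding[0] pushes that range above row 0, runs the regular conv2d path on the slice (output_height_in_l1 is reset to 0 so the recursive call does not re-enter this branch), converts the slice output to row-major, and concatenates the pieces back together in DRAM along the height dimension. Below is a minimal host-side sketch of that scheme using torch as a stand-in for the device ops; it is an illustration of the slicing idea under the test's no-padding assumptions, not the ttnn API.

import torch
import torch.nn.functional as F

def conv2d_by_height_slices(x_nchw, weight, output_height_in_l1, stride_h=1):
    # Assumes pad_h == 0 and dilation == 1, matching this first "w/o padding" version.
    kernel_h = weight.shape[2]
    input_height = x_nchw.shape[2]
    output_height = (input_height - kernel_h) // stride_h + 1
    out_slices = []
    for out_start in range(0, output_height, output_height_in_l1):
        out_end = min(output_height, out_start + output_height_in_l1)
        in_start = out_start * stride_h                    # same row mapping as the ttnn::slice call above
        in_end = (out_end - 1) * stride_h + kernel_h
        x_slice = x_nchw[:, :, in_start:in_end, :]         # stands in for ttnn::slice
        out_slices.append(F.conv2d(x_slice, weight, stride=(stride_h, 1)))  # stands in for the per-slice L1 conv
    return torch.cat(out_slices, dim=2)                    # stands in for ttnn::concat along the height dim

x = torch.randn(2, 32, 130, 130)
w = torch.randn(64, 32, 3, 3)
assert torch.allclose(conv2d_by_height_slices(x, w, output_height_in_l1=64), F.conv2d(x, w), rtol=1e-4, atol=1e-4)
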
@@ -123,7 +123,7 @@ Tensor concat_impl(std::vector<Tensor> &input_tensors, const std::int64_t dim, c
"Current concat implementation requires aligned last dim when concatting on last dim");
}
}
Layout target_layout = Layout::TILE;
Layout target_layout = Layout::ROW_MAJOR;
for (const auto &input_tensor : input_tensors) {
if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
const auto &input_shape = input_tensor.get_legacy_shape();
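
The one-line change above switches concat_impl's fallback target layout from TILE to ROW_MAJOR, which matches how the sliced outputs arrive at ttnn::concat in the DRAM conv path: each slice is converted to ROW_MAJOR before being appended along the height dimension. A small usage sketch (illustrative shapes, not taken from the commit) of concatenating two row-major ttnn tensors along height:

import torch
import ttnn

device = ttnn.open_device(device_id=0)
a = ttnn.from_torch(torch.randn(1, 1, 64, 64), layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
b = ttnn.from_torch(torch.randn(1, 1, 64, 64), layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
out = ttnn.concat([a, b], dim=2)  # row-major concat along the height dimension
assert out.shape[2] == 128
ttnn.close_device(device)
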
