Skip to content

Commit

Permalink
#0: Wrote I/O equations
Browse files — browse the repository at this point in the history
  • Branch information:
sankarmanoj-tt authored and mywoodstock committed Sep 27, 2024
1 parent fce451f commit 4afe632
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sys/types.h>
#include <cstdint>

#include "common/assert.hpp"
#include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
#include "tt_metal/detail/reports/memory_reporter.hpp"
#include "ttnn/operations/core/to_dtype/to_dtype_op.hpp"
Expand Down Expand Up @@ -671,29 +672,36 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
std::optional<const ttnn::Tensor> bias_tensor,
std::optional<const Conv2dConfig> conv_config_) {

uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;

Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
if(conv_config.output_height_in_l1 > 0)
{
TT_FATAL(conv_config.output_height_in_l1%32,"Input height in L1 must be a multiple of 32");
if(input_tensor.memory_config().is_dram())
{
for(uint32_t input_height_slice = 0; input_height_slice < input_height; input_height_slice+=conv_config.output_height_in_l1)
for(uint32_t output_slice_height_start = 0; output_slice_height_start < output_height; output_slice_height_start+=conv_config.output_height_in_l1)
{
auto conv_config_l1 = conv_config;
conv_config_l1.output_height_in_l1 = conv_config.output_height_in_l1;
uint32_t output_slice_height_end = std::min(output_height, output_slice_height_start + conv_config.output_height_in_l1);
uint32_t output_slice_height = output_slice_height_end - output_slice_height_start;

int input_slice_height_start = output_slice_height_start * stride[0] - padding[0];
int input_slice_height_end = (output_slice_height_end - 1) * stride[0] - padding[0] + (kernel_size[0] - 1) * (dilation[0] - 1) + kernel_size[0];
// auto conv_config_l1 = conv_config;
// conv_config_l1.output_height_in_l1 = conv_config.output_height_in_l1;

//Slice a part of input tensor along the input height dimension and move it to L1.
auto input_tensor_l1 = dram_slice_to_l1_sharded(input_tensor, input_height_slice, conv_config.output_height_in_l1);
// //Slice a part of input tensor along the input height dimension and move it to L1.
// auto input_tensor_l1 = dram_slice_to_l1_sharded(input_tensor, input_height_slice, conv_config.output_height_in_l1);

[output_tensor_l1, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device] =
conv2d(input_tensor_l1, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_l1);
// [output_tensor_l1, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device] =
// conv2d(input_tensor_l1, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_l1);

l1_to_dram_slice(output_tensor_l1, output_height_slice_start);
// l1_to_dram_slice(output_tensor_l1, output_slice_height_start);
}
}
}
uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;

auto [input_tensor_post_tm, parallel_config, tensor_manipulated] = shard_or_reshard_tensor_if_required(
device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels);
if (tensor_manipulated) {
Expand Down

0 comments on commit 4afe632

Please sign in to comment.