From 4afe6329c313aeba0a3dc9ca72e2ec21ec1b4ea8 Mon Sep 17 00:00:00 2001
From: Sankar Manoj
Date: Thu, 12 Sep 2024 10:47:48 +0530
Subject: [PATCH] #0: Wrote I/O equations

---
 .../ttnn/operations/conv/conv2d/conv2d.cpp | 28 ++++++++++++-------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index f7aa0e8f165..54de724d1ef 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 
+#include "common/assert.hpp"
 #include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
 #include "tt_metal/detail/reports/memory_reporter.hpp"
 #include "ttnn/operations/core/to_dtype/to_dtype_op.hpp"
@@ -671,29 +672,36 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::
     std::optional<const ttnn::Tensor> bias_tensor,
     std::optional<const Conv2dConfig> conv_config_) {
+    uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
+    uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[1] - 1) * (dilation[1] - 1)) + 2 * padding[1]) / stride[1]) + 1;
+
     Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
     if(conv_config.output_height_in_l1 > 0) {
         TT_FATAL(conv_config.output_height_in_l1 % 32 == 0, "Output height in L1 must be a multiple of 32");
         if(input_tensor.memory_config().is_dram()) {
-            for(uint32_t input_height_slice = 0; input_height_slice < input_height; input_height_slice+=conv_config.output_height_in_l1)
+            for(uint32_t output_slice_height_start = 0; output_slice_height_start < output_height; output_slice_height_start+=conv_config.output_height_in_l1)
             {
-                auto conv_config_l1 = conv_config;
-                conv_config_l1.output_height_in_l1 = conv_config.output_height_in_l1;
+                uint32_t output_slice_height_end = std::min(output_height, output_slice_height_start + conv_config.output_height_in_l1);
+                uint32_t output_slice_height = output_slice_height_end - output_slice_height_start;
+
+                int input_slice_height_start = output_slice_height_start * stride[0] - padding[0];
+                int input_slice_height_end = (output_slice_height_end - 1) * stride[0] - padding[0] + (kernel_size[0] - 1) * (dilation[0] - 1) + kernel_size[0];
+                // auto conv_config_l1 = conv_config;
+                // conv_config_l1.output_height_in_l1 = conv_config.output_height_in_l1;
 
-                //Slice a part of input tensor along the input height dimension and move it to L1.
-                auto input_tensor_l1 = dram_slice_to_l1_sharded(input_tensor, input_height_slice, conv_config.output_height_in_l1);
+                // //Slice a part of input tensor along the input height dimension and move it to L1.
+                // auto input_tensor_l1 = dram_slice_to_l1_sharded(input_tensor, input_height_slice, conv_config.output_height_in_l1);
 
-                [output_tensor_l1, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device] =
-                    conv2d(input_tensor_l1, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_l1);
+                // [output_tensor_l1, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device] =
+                //     conv2d(input_tensor_l1, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_l1);
 
-                l1_to_dram_slice(output_tensor_l1, output_height_slice_start);
+                // l1_to_dram_slice(output_tensor_l1, output_slice_height_start);
             }
         }
     }
-    uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
-    uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
+
     auto [input_tensor_post_tm, parallel_config, tensor_manipulated] = shard_or_reshard_tensor_if_required(
         device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels);
     if (tensor_manipulated) {
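Note (not part of the patch): below is a minimal standalone sketch for sanity-checking the I/O arithmetic introduced above. The file name, the helper conv_output_dim, and the example sizes are made up for illustration; only the formulas mirror the lines added to conv2d(), and they assume input_slice_height_end is intended as an exclusive (half-open) bound.

// sanity_check_conv_slicing.cpp -- illustrative only, not part of conv2d.cpp.
#include <algorithm>
#include <cstdint>
#include <iostream>

// Output extent along one dimension. Algebraically the same as the expression
// added in conv2d(): effective kernel extent = kernel + (kernel - 1) * (dilation - 1).
static uint32_t conv_output_dim(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t padding, uint32_t dilation) {
    uint32_t effective_kernel = kernel + (kernel - 1) * (dilation - 1);
    return (input + 2 * padding - effective_kernel) / stride + 1;
}

int main() {
    // Example sizes: 224-row input, 3x3 kernel, stride 2, padding 1, dilation 1.
    const uint32_t input_height = 224, kernel = 3, stride = 2, padding = 1, dilation = 1;
    const uint32_t output_height = conv_output_dim(input_height, kernel, stride, padding, dilation);
    std::cout << "output_height = " << output_height << "\n";  // prints 112

    // Walk the output rows in L1-sized slices (output_height_in_l1 = 32 here) and
    // print the half-open window of input rows each slice reads, using the same
    // start/end formulas as the loop in the patch. A negative start or an end past
    // input_height falls into the zero-padding region.
    const uint32_t output_height_in_l1 = 32;
    for (uint32_t out_start = 0; out_start < output_height; out_start += output_height_in_l1) {
        const uint32_t out_end = std::min(output_height, out_start + output_height_in_l1);
        const int in_start = static_cast<int>(out_start * stride) - static_cast<int>(padding);
        const int in_end = static_cast<int>((out_end - 1) * stride) - static_cast<int>(padding)
                           + static_cast<int>((kernel - 1) * (dilation - 1) + kernel);
        std::cout << "output rows [" << out_start << ", " << out_end << ") read input rows ["
                  << in_start << ", " << in_end << ")\n";
    }
    return 0;
}

With these numbers the last slice is output rows [96, 112), which reads input rows [191, 224), i.e. the slices overlap at their boundaries by (effective kernel - stride) rows, which is the amount of halo data each L1-resident slice would need.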