From e1524a33db22e1a74bda645652f8f38a7ef3fda1 Mon Sep 17 00:00:00 2001
From: ruthreshk
Date: Fri, 23 Feb 2024 15:07:40 +0000
Subject: [PATCH] #5644: Add backward support for repeat

---
 docs/source/ttnn/ttnn/dependencies/tt_lib.rst |  2 +
 .../backward_ops/test_backward_repeat.py      | 34 ++++++++++++++++
 .../op_library/backward/backward_ops.cpp      | 40 +++++++++++++++++++
 .../op_library/backward/backward_ops.hpp      |  3 ++
 .../tt_lib_bindings_tensor_backward_ops.cpp   | 16 ++++++++
 5 files changed, 95 insertions(+)
 create mode 100644 tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_repeat.py

diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index 5c81255ff90a..4eb2a5e7095c 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -1058,6 +1058,8 @@ Backward Operations
 
 .. autofunction:: tt_lib.tensor.multigammaln_bw
 
+.. autofunction:: tt_lib.tensor.repeat_bw
+
 Loss Functions
 ==============
 
diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_repeat.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_repeat.py
new file mode 100644
index 000000000000..86ab542f586c
--- /dev/null
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_repeat.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+import tt_lib
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_pcc
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize("sizes", [[12, 1, 1, 1], [6, 1, 1, 1], [1, 24, 1, 1], [1, 3, 1, 1]])
+def test_bw_repeat(input_shapes, sizes, device):
+    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
+
+    pyt_y = in_data.repeat(sizes)
+
+    grad_data, grad_tensor = data_gen_pt_tt(pyt_y.shape, device, True)
+
+    tt_output_tensor_on_device = tt_lib.tensor.repeat_bw(grad_tensor, input_tensor, sizes)
+
+    in_data.retain_grad()
+
+    pyt_y.backward(gradient=grad_data)
+
+    golden_tensor = [in_data.grad]
+    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert status
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index deb4320eed52..b73e05448f7f 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -5,6 +5,8 @@
 #include "tt_dnn/op_library/composite/composite_ops.hpp"
 #include "tt_dnn/op_library/backward/backward_ops.hpp"
 #include "tt_dnn/op_library/reduce/reduce_op.hpp"
+#include "tt_dnn/op_library/reshape/reshape_op.hpp"
+#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp"
 #include "tt_dnn/op_library/embeddings/embeddings_op.hpp"
 #include "tt_numpy/functions.hpp"
 #include "tt_eager/tensor/tensor_utils.hpp"
@@ -1783,6 +1785,44 @@ std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, con
     return operation::decorate_as_composite(__func__, _multigammaln_bw)(grad, input, output_mem_config);
 }
 
+// Repeat Backward
+std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    auto shape_wh = input.get_legacy_shape();
+    TT_ASSERT( shape_wh[0] == 1 && "input shape[0] should be 1");
+    // input.get_legacy_shape()[0]
+    // If the repeat shape contains a 0, the gradient w.r.t. the input is all zeros
+    if (shape[0] == 0 || shape[1] == 0 || shape[2] == 0 || shape[3] == 0) {
+        Tensor zero_tensor = zeros_like(input, output_mem_config);
+        grad_tensor.emplace_back(zero_tensor);
+        return grad_tensor;
+    }
+    else if (shape[0] > 1) {
+        std::vector<int64_t> dim = {0};
+        TT_ASSERT( shape[1] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[1], [2], [3] should be 1");
+        Shape required = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
+        Tensor result = tt::operations::primary::moreh_sum(grad, zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config), dim, output_mem_config);
+        grad_tensor.emplace_back(result);
+        return grad_tensor;
+    }
+    else if (shape[1] > 1)
+    {
+        std::vector<int64_t> dim = {1};
+        TT_ASSERT( shape[0] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[0], [2], [3] should be 1");
+        Shape required = {shape_wh[0], 1, shape_wh[2], shape_wh[3]};
+        Tensor result = tt::operations::primary::moreh_sum(grad, zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config), dim, output_mem_config);
+        grad_tensor.emplace_back(result);
+        return grad_tensor;
+    }
+    return grad_tensor;
+
+}
+std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config)
+{
+    return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
+}
+
+
 }//namespace tt_metal
 
 }//namespace tt
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
index fa7bf64f6210..7297a2afa6fd 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -255,6 +255,9 @@
 std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config);
+
 } //namespace tt_metal
 
 } //namespace tt
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
index 7bee0b4dc61e..fd3bd86d9085 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -504,6 +504,22 @@ namespace tt::tt_metal::detail{
             "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");
 
+    m_tensor.def("repeat_bw", &tt::tt_metal::repeat_bw,
+            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("shape"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            Performs backward operations for repeat of ``input`` tensor with given ``grad``: the gradient ``grad`` (shaped like the repeated output) is reduced back to the shape of ``input``. The rank of ``shape`` must match the rank of ``input``.
+            Current limitation: repetition is supported only along the N or C dimension; the corresponding dimension of ``input`` must be 1, and all other entries of ``shape`` must be 1.
+
+            Output tensor will have BFLOAT16 data type.
+
+            .. csv-table::
+                :header: "Argument", "Description", "Data type", "Valid range", "Required"
+
+                "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+                "input", "Input tensor for which the gradient is computed", "Tensor", "Tensor of shape [1, Z, Y, X]", "Yes"
+                "shape", "Repeat counts used in the forward repeat", "Shape", "The number of times the input is repeated along each dimension", "Yes"
+                "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+        )doc");
+
     m_tensor.def("unary_sub_bw", &tt::tt_metal::unary_sub_bw,
             py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
             Performs backward operations for subtraction of ``input`` tensors with given ``grad``.
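
For reference, the following is a minimal PyTorch-only sketch (no TT device or tt_lib required) of the reduction that repeat_bw is expected to perform under the limitation stated in the docstring: repetition only along the N or C dimension, with the matching input dimension equal to 1 and all other repeat counts equal to 1. The helper name repeat_grad_reference is illustrative and is not part of tt_lib.

import torch


def repeat_grad_reference(grad, input_shape, sizes):
    # Illustrative helper, not part of tt_lib: computes the gradient of
    # torch.Tensor.repeat for the N/C-only repeat patterns covered by the patch.
    if 0 in sizes:
        # A repeat count of 0 produces an empty output, so the input gradient is all zeros.
        return torch.zeros(input_shape, dtype=grad.dtype)
    if sizes[0] > 1:
        # Repetition along N: fold the repeated copies back by summing over the repeat axis.
        return grad.reshape(sizes[0], *input_shape).sum(dim=0)
    if sizes[1] > 1:
        # Repetition along C: fold the repeated copies back by summing over the repeat axis.
        return grad.reshape(input_shape[0], sizes[1], *input_shape[1:]).sum(dim=1)
    # No repetition: the gradient passes through unchanged.
    return grad.reshape(input_shape)


# Example mirroring the unit test: input [1, 1, 32, 32] repeated with sizes [12, 1, 1, 1].
x = torch.randn(1, 1, 32, 32, requires_grad=True)
y = x.repeat(12, 1, 1, 1)
g = torch.randn_like(y)
y.backward(gradient=g)
assert torch.allclose(x.grad, repeat_grad_reference(g, x.shape, [12, 1, 1, 1]))

On device, the patch computes the same reduction with tt::operations::primary::moreh_sum over dim 0 or dim 1, which is why the non-repeated dimensions of ``shape`` are required to be 1.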