From 195d0fc9569dc46789650355047404a370bb8f6a Mon Sep 17 00:00:00 2001
From: Bharane AB
Date: Mon, 11 Mar 2024 11:30:56 +0000
Subject: [PATCH] #6123: Add support for backward mvlgamma

---
 docs/source/ttnn/dependencies/tt_lib.rst      |  2 +
 .../backward_ops/test_backward_mvlgamma.py    | 39 +++++++++++++++++++
 .../op_library/backward/backward_ops.cpp      | 21 ++++++++++
 .../op_library/backward/backward_ops.hpp      |  1 +
 .../tt_lib_bindings_tensor_backward_ops.cpp   | 20 ++++++++++
 .../tt_lib_bindings_tensor_composite_ops.cpp  |  2 +-
 6 files changed, 84 insertions(+), 1 deletion(-)
 create mode 100644 tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py

diff --git a/docs/source/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/dependencies/tt_lib.rst
index 928299511648..f08f1344fe68 100644
--- a/docs/source/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/dependencies/tt_lib.rst
@@ -1052,6 +1052,8 @@ Backward Operations
 
 .. autofunction:: tt_lib.tensor.unary_remainder_bw
 
+.. autofunction:: tt_lib.tensor.multigammaln_bw
+
 Loss Functions
 ==============
 
diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py
new file mode 100644
index 000000000000..b8cdc619d175
--- /dev/null
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+import tt_lib
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_multigammaln(input_shapes, device):
+    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
+
+    in_data = torch.Tensor(size=input_shapes).uniform_(3, 10)
+    in_data.requires_grad = True
+    input_tensor = (
+        tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
+    )
+
+    pyt_y = torch.mvlgamma(in_data, 4)
+
+    tt_output_tensor_on_device = tt_lib.tensor.multigammaln_bw(grad_tensor, input_tensor)
+
+    in_data.retain_grad()
+
+    pyt_y.backward(gradient=grad_data)
+
+    golden_tensor = [in_data.grad]
+    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+
+    assert comp_pass
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index 7be866b749d6..15c98a4626f0 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1637,6 +1637,27 @@ std::vector<Tensor> unary_remainder_bw(const Tensor& grad, const Tensor& input,
     return operation::decorate_as_composite(__func__, _unary_remainder_bw)(grad, input, scalar, output_mem_config);
 }
 
+std::vector<Tensor> _multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    Tensor digamma_result = mul(grad, digamma(input, output_mem_config), std::nullopt, output_mem_config);
+    Tensor digamma_result_2 = mul(grad, digamma(add_unary(-0.5, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+
+    Tensor grad_result = add(digamma_result, digamma_result_2, std::nullopt, output_mem_config);
+
+    digamma_result = mul(grad, digamma(add_unary(-1.0, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);
+
+    digamma_result = mul(grad, digamma(add_unary(-1.5, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);
+
+    grad_tensor.emplace_back(grad_result);
+    return grad_tensor;
+}
+std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
+{
+    return operation::decorate_as_composite(__func__, _multigammaln_bw)(grad, input, output_mem_config);
+}
+
 }//namespace tt_metal
 
 }//namespace tt
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
index ae08edad3343..b9c75abdf716 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -249,6 +249,7 @@ std::vector<Tensor> unary_fmod_bw(const Tensor& grad, const Tensor& input, float
 std::vector<Tensor> unary_remainder_bw(const Tensor& grad, const Tensor& input, float eps=0.0f, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
+std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
 } //namespace tt_metal
 
 } //namespace tt
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
index 24f3b8cd5465..52123bc64da4 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -1911,5 +1911,25 @@ namespace tt::tt_metal::detail{
             "scalar", "scalar value", "float", "float", "Yes"
             "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
     )doc");
+
+    m_tensor.def("multigammaln_bw", &tt::tt_metal::multigammaln_bw,
+            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+        Performs backward operations for multigammaln of ``input`` tensors with given ``grad``; the dimension ``p`` is fixed at 4.
+
+        mvlgamma is referred to as multigammaln.
+
+        Input values must be greater than 2.5f.
+
+        Input tensors must have BFLOAT16 data type.
+
+        Output tensors will have BFLOAT16 data type.
+
+        .. csv-table::
+            :header: "Argument", "Description", "Data type", "Valid range", "Required"
+
+            "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "input", "Tensor mvlgamma is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+    )doc");
 }
 }
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index 1e6b175d8798..7b373ac7c025 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -185,7 +185,7 @@ namespace tt::tt_metal::detail{
     );
     detail::bind_unary_op(m_tensor, "digamma", &digamma, R"doc(Computes the logarithmic derivative of the gamma function on input tensor ``{0}`` for the input range 1 to inf.)doc");
     detail::bind_unary_op(m_tensor, "lgamma", &lgamma, R"doc(Computes the natural logarithm of the absolute value of the gamma function on the ``{0}`` tensor for inputs greater than 0.)doc");
-    detail::bind_unary_op(m_tensor, "multigammaln", &multigammaln, R"doc(Computes the multivariate log-gamma function with dimension 4 element-wise on the input tensor ``{0}`` for inputs greater than 1.5f.)doc");
+    detail::bind_unary_op(m_tensor, "multigammaln", &multigammaln, R"doc(Computes the multivariate log-gamma function with dimension 4 element-wise on the input tensor ``{0}`` for inputs greater than 1.5f. mvlgamma is referred to as multigammaln.)doc");
 
     detail::bind_unary_op_with_param(
         m_tensor, "softshrink", &softshrink,
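
Reviewer note (not part of the patch): _multigammaln_bw relies on the identity
d/dx mvlgamma(x, p) = sum_{i=0}^{p-1} digamma(x - i/2), instantiated here at p = 4.
A minimal PyTorch sketch checking that identity against autograd, with sample
values chosen inside the documented x > 2.5 range:

    import torch

    x = torch.tensor([3.0, 4.5, 7.25], requires_grad=True)
    grad = torch.ones_like(x)

    # Forward mvlgamma with p = 4, then backprop the incoming gradient.
    torch.mvlgamma(x, 4).backward(gradient=grad)

    # Closed form used by the kernel:
    # grad * (digamma(x) + digamma(x - 0.5) + digamma(x - 1.0) + digamma(x - 1.5))
    with torch.no_grad():
        expected = grad * (
            torch.digamma(x)
            + torch.digamma(x - 0.5)
            + torch.digamma(x - 1.0)
            + torch.digamma(x - 1.5)
        )

    assert torch.allclose(x.grad, expected)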