#6123: Add support for backward mvlgamma
bharane-ab committed Mar 13, 2024
1 parent 8c9c00a commit 4cfeddb
Showing 6 changed files with 82 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/ttnn/dependencies/tt_lib.rst
@@ -1052,6 +1052,8 @@ Backward Operations

.. autofunction:: tt_lib.tensor.unary_remainder_bw

.. autofunction:: tt_lib.tensor.multigammaln_bw

Loss Functions
==============

@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_multigammaln(input_shapes, device):
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

in_data = torch.Tensor(size=input_shapes).uniform_(3, 10)
in_data.requires_grad = True
input_tensor = (
tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)

pyt_y = torch.mvlgamma(in_data, 4)

tt_output_tensor_on_device = tt_lib.tensor.multigammaln_bw(grad_tensor, input_tensor)

in_data.retain_grad()

pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
21 changes: 21 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1637,6 +1637,27 @@ std::vector<Tensor> unary_remainder_bw(const Tensor& grad, const Tensor& input,
return operation::decorate_as_composite(__func__, _unary_remainder_bw)(grad, input, scalar, output_mem_config);
}

std::vector<Tensor> _multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor digamma_result = mul(grad, digamma(input, output_mem_config), std::nullopt, output_mem_config);
Tensor digamma_result_2 = mul(grad, digamma(add_unary(-0.5 , input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);

Tensor grad_result = add(digamma_result, digamma_result_2, std::nullopt, output_mem_config);

digamma_result = mul(grad, digamma(add_unary(-1.0 , input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);

digamma_result = mul(grad, digamma(add_unary(-1.5 , input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
{
return operation::decorate_as_composite(__func__, _multigammaln_bw)(grad, input, output_mem_config);
}

}//namespace tt_metal

}//namespace tt
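
The composite above implements the analytic gradient of the order-4 multivariate log-gamma: d/dx mvlgamma(x, 4) = digamma(x) + digamma(x - 0.5) + digamma(x - 1.0) + digamma(x - 1.5), with each term scaled element-wise by the incoming grad. A minimal PyTorch sketch (not part of this commit) that checks the same closed form against autograd:

import torch

# Inputs drawn from the same range the unit test uses (well above the 2.5 lower bound).
x = torch.empty(1, 1, 32, 32).uniform_(3, 10).requires_grad_()
grad = torch.randn_like(x)

# Autograd reference gradient for mvlgamma with p = 4.
torch.mvlgamma(x, 4).backward(gradient=grad)

# Closed-form gradient mirroring the four digamma terms in _multigammaln_bw.
manual = grad * sum(torch.digamma(x.detach() - 0.5 * i) for i in range(4))

assert torch.allclose(x.grad, manual, atol=1e-4)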
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -249,6 +249,7 @@ std::vector<Tensor> unary_fmod_bw(const Tensor& grad, const Tensor& input, float

std::vector<Tensor> unary_remainder_bw(const Tensor& grad, const Tensor& input, float eps=0.0f, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
} //namespace tt_metal

} //namespace tt
18 changes: 18 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -1911,5 +1911,23 @@ namespace tt::tt_metal::detail{
"scalar", "scalar value", "float", "float", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("multigammaln_bw", &tt::tt_metal::multigammaln_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for multigammaln of the ``input`` tensor with given ``grad``; the order ``p`` is fixed at 4.
mvlgamma is referred to as multigammaln.
Input values must be greater than 2.5f.
Input tensors must have BFLOAT16 data type.
Output tensors will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor mvlgamma is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");
}
}
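
A hedged usage sketch of the new binding, mirroring the unit test earlier in this commit; it assumes a device handle has already been opened through the usual tt_lib device APIs:

import torch
import tt_lib

in_data = torch.empty(1, 1, 32, 32).uniform_(3, 10)
grad_data = torch.randn(1, 1, 32, 32)

# Host tensors -> BFLOAT16, tiled, then moved to the (assumed already open) device.
input_tensor = tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
grad_tensor = tt_lib.tensor.Tensor(grad_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)

# multigammaln_bw returns a list with a single tensor: the gradient w.r.t. input.
(grad_input,) = tt_lib.tensor.multigammaln_bw(grad_tensor, input_tensor)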
@@ -185,7 +185,7 @@ namespace tt::tt_metal::detail{
);
detail::bind_unary_op(m_tensor, "digamma", &digamma, R"doc(Computes the logarithmic derivative of the gamma function on input tensor ``{0}`` for the input range 1 to inf.)doc");
detail::bind_unary_op(m_tensor, "lgamma", &lgamma, R"doc(Computes the natural logarithm of the absolute value of the gamma function on the ``{0}`` tensor for inputs greater than 0.)doc");
detail::bind_unary_op(m_tensor, "multigammaln", &multigammaln, R"doc(Computes the multivariate log-gamma function with dimension 4 element-wise on the input tensor ``{0}`` for inputs greater than 1.5f.)doc");
detail::bind_unary_op(m_tensor, "multigammaln", &multigammaln, R"doc(Computes the multivariate log-gamma function with dimension 4 element-wise on the input tensor ``{0}`` for inputs greater than 1.5f. mvlgamma is refered as multigammaln.)doc");

detail::bind_unary_op_with_param(
m_tensor, "softshrink", &softshrink,
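
For reference, the forward multigammaln documented above corresponds to torch.mvlgamma with fixed order p = 4, defined for inputs greater than (p - 1)/2 = 1.5. An illustrative PyTorch check of the equivalent closed form:

import math
import torch

x = torch.tensor([2.0, 5.0, 10.0])

# mvlgamma(x, p) = p*(p-1)/4 * log(pi) + sum_{i=0}^{p-1} lgamma(x - i/2); for p = 4 the constant is 3*log(pi).
manual = 3.0 * math.log(math.pi) + sum(torch.lgamma(x - 0.5 * i) for i in range(4))
assert torch.allclose(torch.mvlgamma(x, 4), manual, atol=1e-4)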
