#6123: Add support for backward mvlgamma

bharane-ab committed Mar 18, 2024
1 parent 1267b58 commit b1d3e1b
Showing 6 changed files with 84 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/ttnn/dependencies/tt_lib.rst
@@ -1058,6 +1058,8 @@ Backward Operations

.. autofunction:: tt_lib.tensor.real_bw

.. autofunction:: tt_lib.tensor.multigammaln_bw

Loss Functions
==============

@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
def test_bw_multigammaln(input_shapes, device):
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

    in_data = torch.Tensor(size=input_shapes).uniform_(3, 10)
    in_data.requires_grad = True
    input_tensor = (
        tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
    )

    pyt_y = torch.mvlgamma(in_data, 4)

    tt_output_tensor_on_device = tt_lib.tensor.multigammaln_bw(grad_tensor, input_tensor)

    in_data.retain_grad()

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad]
    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)

    assert comp_pass
21 changes: 21 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1688,6 +1688,27 @@ std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const Memor

#undef CHECK_FOR_COMPLEX

// Gradient of the dimension-4 multivariate log-gamma:
// d/dx mvlgamma(x, 4) = digamma(x) + digamma(x - 0.5) + digamma(x - 1.0) + digamma(x - 1.5)
std::vector<Tensor> _multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    Tensor digamma_result = mul(grad, digamma(input, output_mem_config), std::nullopt, output_mem_config);
    Tensor digamma_result_2 = mul(grad, digamma(add_unary(-0.5, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);

    Tensor grad_result = add(digamma_result, digamma_result_2, std::nullopt, output_mem_config);

    digamma_result = mul(grad, digamma(add_unary(-1.0, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
    grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);

    digamma_result = mul(grad, digamma(add_unary(-1.5, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
    grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);

    grad_tensor.emplace_back(grad_result);
    return grad_tensor;
}
std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
{
    return operation::decorate_as_composite(__func__, _multigammaln_bw)(grad, input, output_mem_config);
}

}//namespace tt_metal

}//namespace tt
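
For reference, the closed form used in ``_multigammaln_bw`` above (for p = 4, the derivative of mvlgamma(x, 4) is digamma(x) + digamma(x - 0.5) + digamma(x - 1.0) + digamma(x - 1.5)) can be sanity-checked on host against torch autograd. A minimal sketch, assuming only torch:

import torch

# Inputs must exceed (p - 1) / 2 = 1.5 for mvlgamma with p = 4.
x = torch.tensor([3.0, 5.5, 9.25], requires_grad=True)
grad = torch.rand_like(x)

# Autograd reference.
torch.mvlgamma(x, 4).backward(gradient=grad)

# Same gradient via the digamma sum that _multigammaln_bw accumulates.
manual = grad * (
    torch.digamma(x) + torch.digamma(x - 0.5) + torch.digamma(x - 1.0) + torch.digamma(x - 1.5)
)
assert torch.allclose(x.grad, manual.detach(), atol=1e-5)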
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -256,6 +256,7 @@ std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const Memor

std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
} //namespace tt_metal

} //namespace tt
20 changes: 20 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -1958,5 +1958,25 @@ namespace tt::tt_metal::detail{
"input", "Input Tensor", "Tensor", "Tensor of complex shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

    m_tensor.def("multigammaln_bw", &tt::tt_metal::multigammaln_bw,
        py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
        Performs backward operations for multigammaln of ``input`` tensors with given ``grad``; the value of ``p`` is taken as 4.
        mvlgamma is referred to as multigammaln.

        Input values must be greater than 2.5f.

        Input tensors must have BFLOAT16 data type.

        Output tensors will have BFLOAT16 data type.

        .. csv-table::
            :header: "Argument", "Description", "Data type", "Valid range", "Required"

            "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "input", "Tensor mvlgamma is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
    )doc");
}
}
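
A minimal end-to-end usage sketch of the new binding (an illustration only; the device setup and host/device conversion helpers mirror the patterns used in the unit test above and are not part of this commit):

import torch
import tt_lib

device = tt_lib.device.CreateDevice(0)

shape = (1, 1, 32, 32)
in_data = torch.empty(shape).uniform_(3, 10)   # inputs kept above 2.5 per the docstring
grad_data = torch.randn(shape)

def to_tt(t):
    # bfloat16, tilized, moved to device; same preparation as in the test above
    return tt_lib.tensor.Tensor(t, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)

# Returns a list with a single gradient tensor (p is fixed at 4).
tt_grads = tt_lib.tensor.multigammaln_bw(to_tt(grad_data), to_tt(in_data))
dx = tt_grads[0].cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()

tt_lib.device.CloseDevice(device)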
@@ -185,7 +185,7 @@ namespace tt::tt_metal::detail{
);
detail::bind_unary_op(m_tensor, "digamma", &digamma, R"doc(Computes the logarithmic derivative of the gamma function on input tensor ``{0}`` for the input range 1 to inf.)doc");
detail::bind_unary_op(m_tensor, "lgamma", &lgamma, R"doc(Computes the natural logarithm of the absolute value of the gamma function on the ``{0}`` tensor for inputs greater than 0.)doc");
detail::bind_unary_op(m_tensor, "multigammaln", &multigammaln, R"doc(Computes the multivariate log-gamma function with dimension 4 element-wise on the input tensor ``{0}`` for inputs greater than 1.5f.)doc");
detail::bind_unary_op(m_tensor, "multigammaln", &multigammaln, R"doc(Computes the multivariate log-gamma function with dimension 4 element-wise on the input tensor ``{0}`` for inputs greater than 1.5f. mvlgamma is referred to as multigammaln.)doc");
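
As a side note on the forward op documented here, the dimension-4 identity it computes can be checked on host with torch; a small sketch, assuming only torch and math:

import math
import torch

# mvlgamma(x, 4) = 3*log(pi) + lgamma(x) + lgamma(x - 0.5) + lgamma(x - 1.0) + lgamma(x - 1.5)
x = torch.tensor([2.0, 3.5, 7.0])   # inputs must be greater than 1.5
expected = torch.mvlgamma(x, 4)
manual = 3 * math.log(math.pi) + (
    torch.lgamma(x) + torch.lgamma(x - 0.5) + torch.lgamma(x - 1.0) + torch.lgamma(x - 1.5)
)
assert torch.allclose(expected, manual, atol=1e-5)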

detail::bind_unary_op_with_param(
m_tensor, "softshrink", &softshrink,