#6234: Add backward support for rdiv

Aswinmcw committed Mar 11, 2024
1 parent 49a6887 commit 3602407

Showing 5 changed files with 83 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/ttnn/dependencies/tt_lib.rst
@@ -846,6 +846,8 @@ Backward Operations

.. autofunction:: tt_lib.tensor.div_bw

.. autofunction:: tt_lib.tensor.rdiv_bw

.. autofunction:: tt_lib.tensor.sqrt_bw

.. autofunction:: tt_lib.tensor.mul_bw
45 changes: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "round_mode",
    (
        "None",
        "trunc",
        "floor",
    ),
)
@pytest.mark.parametrize("scalar", [0.05, 1.0, 0.5, 0.12])
def test_bw_rdiv(input_shapes, scalar, round_mode, device):
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

    # Device result: gradient of (scalar / input) with respect to input
    tt_output_tensor_on_device = tt_lib.tensor.rdiv_bw(grad_tensor, input_tensor, scalar=scalar, round_mode=round_mode)

    in_data.retain_grad()

    # The device op takes the string "None"; torch expects the Python None object
    if round_mode == "None":
        round_mode = None
    pyt_y = torch.div(scalar, in_data, rounding_mode=round_mode)

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad]

    status = compare_results(tt_output_tensor_on_device, golden_tensor)
    assert status
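
Aside (not part of this commit): a minimal PyTorch-only sanity check of the golden path above, showing that autograd's gradient for y = scalar / x matches the analytic form -scalar / x**2 when no rounding mode is used. All names here are local to this sketch.

import torch

scalar = 0.5
x = torch.rand(1, 1, 32, 32) + 0.1  # keep values away from zero so 1/x**2 stays moderate
x.requires_grad_(True)
grad = torch.ones_like(x)

y = torch.div(scalar, x)  # same golden op as in the test, round_mode None
y.backward(gradient=grad)

analytic = grad * (-scalar / x.detach() ** 2)
assert torch.allclose(x.grad, analytic)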
17 changes: 17 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -177,6 +177,23 @@ std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor
return operation::decorate_as_composite(__func__, _div_bw)(grad, input, other, output_mem_config);
}

std::vector<Tensor> _rdiv_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    if (round_mode == "None") {
        // grad_input = -grad * scalar / input^2
        Tensor result = mul(neg(grad, output_mem_config), mul_unary(recip(square(input, output_mem_config)), scalar, output_mem_config), std::nullopt, output_mem_config);
        grad_tensor.emplace_back(result);
    } else {
        // "trunc" and "floor" produce a piecewise-constant result, so the gradient is zero
        Tensor result = zeros_like(grad, output_mem_config);
        grad_tensor.emplace_back(result);
    }
    return grad_tensor;
}
std::vector<Tensor> rdiv_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _rdiv_bw)(grad, input, scalar, round_mode, output_mem_config);
}
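
Aside (not part of the diff): the "None" branch above is the composite form of the analytic derivative:

    y = s / x
    dy/dx = -s / x^2      (e.g. s = 2, x = 4 gives -2/16 = -0.125)
    grad_input = grad * (-s / x^2) = neg(grad) * s * recip(square(x))

For "trunc" and "floor" the rounded output is piecewise constant in x, so the true gradient is zero almost everywhere, which is what the zeros_like branch returns and what PyTorch's golden produces in the unit test.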


std::vector<Tensor> _tanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -45,6 +45,8 @@ std::vector<Tensor> unary_div_bw(const Tensor& grad, const Tensor& input, float

std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> rdiv_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> max_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> min_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
17 changes: 17 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -350,6 +350,23 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("rdiv_bw", &tt::tt_metal::rdiv_bw,
    py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("round_mode") = "None", py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
    Performs backward operations for division of ``scalar`` by ``input`` tensor with given ``grad``.

    Input tensors must have BFLOAT16 data type.

    Output tensor will have BFLOAT16 data type.

    .. csv-table::
        :header: "Argument", "Description", "Data type", "Valid range", "Required"

        "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
        "input", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
        "scalar", "Scalar value", "float", "Default value is 1.0f", "No"
        "round_mode", "Round mode for the operation", "String", "Default value is None", "No"
        "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("min_bw", &tt::tt_metal::min_bw,
py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for minimum of ``input_b`` with given ``grad``.
