#10071: Merge relu6_bw to TTNN
mouliraj-mcw committed Jul 11, 2024
1 parent 58921c2 commit 0cb2503
Showing 11 changed files with 66 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -196,6 +196,7 @@ Pointwise Unary
ttnn/floor_bw
ttnn/round_bw
ttnn/log_bw
ttnn/relu6_bw

Pointwise Binary
================
2 changes: 1 addition & 1 deletion docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -880,7 +880,7 @@ Backward Operations

.. autofunction:: tt_lib.tensor.reciprocal_bw

.. autofunction:: tt_lib.tensor.relu6_bw
.. autofunction:: tt_lib.tensor.rpow_bw

.. autofunction:: tt_lib.tensor.silu_bw

6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/relu6_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.relu6_bw:

ttnn.relu6_bw
##############

.. autofunction:: ttnn.relu6_bw
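A minimal usage sketch of the new ttnn.relu6_bw entry point (the shapes, the device handle, and the from_torch/to_torch conversions here are illustrative assumptions; the unit test below uses the repo's data_gen_with_range helper instead):

import torch
import ttnn

# assumes `device` was opened earlier, e.g. device = ttnn.open_device(device_id=0)
torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
torch_grad = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
grad_tensor = ttnn.from_torch(torch_grad, layout=ttnn.TILE_LAYOUT, device=device)

# backward ops return a list of gradient tensors; relu6_bw yields one entry
tt_grad = ttnn.relu6_bw(grad_tensor, input_tensor)
result = ttnn.to_torch(tt_grad[0])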
@@ -4,7 +4,7 @@

import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range


@@ -22,7 +22,7 @@ def test_bw_relu6(input_shapes, device):

pyt_y = torch.nn.functional.relu6(in_data)

tt_output_tensor_on_device = tt_lib.tensor.relu6_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)

in_data.retain_grad()

36 changes: 15 additions & 21 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -970,33 +970,27 @@ std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const
return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> _rpow_bw(
const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = zeros_like(input, output_mem_config);
Tensor one_tensor = ones_like(input, output_mem_config);
Tensor six_tensor = full_like(input, 6, output_mem_config);
Tensor grad_result =
where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
grad_result = where(
ttnn::logical_and(
ttnn::gtz(input, output_mem_config),
ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
grad,
grad_result,
output_mem_config);
grad_result =
where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);

float t_nan = std::nanf("");
Tensor grad_result = zeros_like(input, output_mem_config);
if (exponent != 0.0) {
grad_result =
ttnn::multiply(grad,
ttnn::multiply(pow(input, exponent - 1, output_mem_config), exponent, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config);
grad_result = where(ttnn::ltz(input, output_mem_config), t_nan, grad_result, output_mem_config);
}
grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config);
std::vector<Tensor> rpow_bw(
const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _rpow_bw)(grad, input, exponent, output_mem_config);
}
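For reference, the gradient computed by the new _rpow_bw above, expressed as a small PyTorch sketch (the helper name is hypothetical):

import torch

def rpow_bw_reference(grad, x, exponent):
    # zero gradient when exponent == 0; otherwise grad * exponent * x**(exponent - 1),
    # with NaN propagated for negative inputs, mirroring the where(ltz(input), ...) branch above
    if exponent == 0.0:
        return torch.zeros_like(x)
    result = grad * exponent * torch.pow(x, exponent - 1)
    return torch.where(x < 0, torch.full_like(x, float("nan")), result)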


// Silu
// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result))
std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
3 changes: 2 additions & 1 deletion tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -284,9 +284,10 @@ std::vector<Tensor> reciprocal_bw(
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> relu6_bw(
std::vector<Tensor> rpow_bw(
const Tensor& grad,
const Tensor& input,
float exponent,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> silu_bw(
14 changes: 7 additions & 7 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -807,22 +807,22 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("relu6_bw", &tt::tt_metal::relu6_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns a tensor of the backward operation of relu6 for the ``input`` tensor and ``grad`` tensor.
m_tensor.def("rpow_bw", &tt::tt_metal::rpow_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("exponent").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for rpow on the ``input`` and ``exponent`` with the given ``grad``.
Input tensors must have BFLOAT16 data type.
Output tensors will have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor relu6 is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"exponent", "exponent", "float", ">0.0", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

)doc");

m_tensor.def("silu_bw", &tt::tt_metal::silu_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
@@ -388,6 +388,29 @@ std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const Memor
return grad_tensor;
}

std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = tt::tt_metal::zeros_like(input, output_mem_config);
Tensor one_tensor = tt::tt_metal::ones_like(input, output_mem_config);
Tensor six_tensor = tt::tt_metal::full_like(input, 6, output_mem_config);
Tensor grad_result =
where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
grad_result = where(
ttnn::logical_and(
ttnn::gtz(input, output_mem_config),
ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
grad,
grad_result,
output_mem_config);
grad_result =
where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
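The where() chain above reduces to a single mask: the incoming gradient passes through only where 0 < input < 6. An equivalent PyTorch sketch (the function name is hypothetical):

import torch

def relu6_bw_reference(grad, x):
    # zero gradient for x <= 0 and x >= 6; pass grad through in between
    return torch.where((x > 0) & (x < 6), grad, torch.zeros_like(x))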

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
case UnaryBackwardOpType::ASSIGN_BW:
@@ -424,6 +447,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Memo
return _round_bw;
case UnaryBackwardOpType::LOG_BW:
return _log_bw;
case UnaryBackwardOpType::RELU6_BW:
return _relu6_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
@@ -43,6 +43,7 @@ enum class UnaryBackwardOpType {
FLOOR_BW,
ROUND_BW,
LOG_BW,
RELU6_BW
};

struct UnaryBackwardFunction{
@@ -95,6 +95,7 @@ constexpr auto rpow_bw = ttnn::register_operation<operations::unary_backward::Ex
constexpr auto floor_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::FLOOR_BW>>("ttnn::floor_bw");
constexpr auto round_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ROUND_BW>>("ttnn::round_bw");
constexpr auto log_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::LOG_BW>>("ttnn::log_bw");
constexpr auto relu6_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RELU6_BW>>("ttnn::relu6_bw");


} // namespace ttnn
@@ -313,6 +313,11 @@ void py_module(py::module& module) {
ttnn::log_bw,
R"doc(Performs backward operations for logarithm on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

detail::bind_unary_backward(
module,
ttnn::relu6_bw,
R"doc(Performs backward operations for relu6 on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

}

} // namespace binary_backward
