#10071: Merge relu6_bw to TTNN
mouliraj-mcw committed Jul 10, 2024
1 parent 5e6975c commit 32da697
Showing 11 changed files with 41 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -176,6 +176,7 @@ Pointwise Unary
ttnn/floor_bw
ttnn/round_bw
ttnn/log_bw
ttnn/relu6_bw

Pointwise Binary
================
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -904,8 +904,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.reciprocal_bw

.. autofunction:: tt_lib.tensor.relu6_bw

.. autofunction:: tt_lib.tensor.rpow_bw

.. autofunction:: tt_lib.tensor.silu_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/relu6_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.relu6_bw:

ttnn.relu6_bw
##############

.. autofunction:: ttnn.relu6_bw
@@ -4,7 +4,7 @@

import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range


@@ -22,7 +22,7 @@ def test_bw_relu6(input_shapes, device):

pyt_y = torch.nn.functional.relu6(in_data)

tt_output_tensor_on_device = tt_lib.tensor.relu6_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)

in_data.retain_grad()

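
For readers following the migration, the updated test boils down to comparing ttnn.relu6_bw against torch autograd. A minimal sketch of that flow, assuming data_gen_with_range returns a (torch, ttnn) tensor pair and compare_pcc accepts the device output plus a list of golden tensors — both helpers are imported in the diff above, but their signatures are not shown, so the ranges and argument order here are assumptions:

import torch
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range

def check_relu6_bw(input_shapes, device):
    # Assumed helper signature: (shapes, low, high, device, required_grad) -> (torch_tensor, ttnn_tensor)
    in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)

    # New TTNN entry point; the call shape is confirmed by the diff above.
    tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)

    # Golden gradient via torch autograd, mirroring the test body.
    in_data.retain_grad()
    pyt_y = torch.nn.functional.relu6(in_data)
    pyt_y.backward(gradient=grad_data)

    return compare_pcc(tt_output_tensor_on_device, [in_data.grad])
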
26 changes: 0 additions & 26 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1182,32 +1182,6 @@ std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const
return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = zeros_like(input, output_mem_config);
Tensor one_tensor = ones_like(input, output_mem_config);
Tensor six_tensor = full_like(input, 6, output_mem_config);
Tensor grad_result =
where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
grad_result = where(
ttnn::logical_and(
gtz(input, output_mem_config),
ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
grad,
grad_result,
output_mem_config);
grad_result =
where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _rpow_bw(
const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
5 changes: 0 additions & 5 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -354,11 +354,6 @@ std::vector<Tensor> reciprocal_bw(
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> relu6_bw(
const Tensor& grad,
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> rpow_bw(
const Tensor& grad,
const Tensor& input,
16 changes: 0 additions & 16 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -1034,22 +1034,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("relu6_bw", &tt::tt_metal::relu6_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns an tensor of backward operation of relu6 for ``input`` tensor and ``grad`` tensor.
Input tensors must have BFLOAT16 data type.
Output tensors will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor relu6 is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("rpow_bw", &tt::tt_metal::rpow_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("exponent").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for rpow for the ``input`` and ``exponent`` with given ``grad``
@@ -124,6 +124,29 @@ std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const Memor
return grad_tensor;
}

std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = tt::tt_metal::zeros_like(input, output_mem_config);
Tensor one_tensor = tt::tt_metal::ones_like(input, output_mem_config);
Tensor six_tensor = tt::tt_metal::full_like(input, 6, output_mem_config);
Tensor grad_result =
where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
grad_result = where(
ttnn::logical_and(
ttnn::gtz(input, output_mem_config),
ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
grad,
grad_result,
output_mem_config);
grad_result =
where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
@@ -137,6 +160,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Memo
return _round_bw;
case UnaryBackwardOpType::LOG_BW:
return _log_bw;
case UnaryBackwardOpType::RELU6_BW:
return _relu6_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
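
The three where() passes in _relu6_bw reduce to a single mask: the gradient flows only on the open interval 0 < input < 6 and is zeroed at and beyond both clamp boundaries (the six_tensor value written by the first pass is always overwritten, and one_tensor is never used). A one-line torch reference the kernel output can be checked against — a sketch, not part of the diff:

import torch

def relu6_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # relu6(x) = min(max(x, 0), 6), so d(relu6)/dx is 1 on the open
    # interval (0, 6) and 0 at or beyond either clamp boundary.
    return torch.where((x > 0) & (x < 6), grad, torch.zeros_like(grad))
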
@@ -23,6 +23,7 @@ enum class UnaryBackwardOpType {
FLOOR_BW,
ROUND_BW,
LOG_BW,
RELU6_BW
};

struct UnaryBackwardFunction{
@@ -75,6 +75,7 @@ constexpr auto eq_bw = ttnn::register_operation<operations::unary_backward::Exec
constexpr auto floor_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::FLOOR_BW>>("ttnn::floor_bw");
constexpr auto round_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ROUND_BW>>("ttnn::round_bw");
constexpr auto log_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::LOG_BW>>("ttnn::log_bw");
constexpr auto relu6_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RELU6_BW>>("ttnn::relu6_bw");


} // namespace ttnn
@@ -210,6 +210,11 @@ void py_module(py::module& module) {
ttnn::log_bw,
R"doc(Performs backward operations for logarithm on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

detail::bind_unary_backward(
module,
ttnn::relu6_bw,
R"doc(Performs backward operations for relu6 on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

}

} // namespace binary_backward
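
With the binding in place, the op is reachable directly from Python. A usage sketch under the standard ttnn host APIs — the exact device-management calls may vary by version, and the list return is an assumption mirroring the C++ std::vector<Tensor>:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Tile-sized bfloat16 inputs; shapes follow the [W, Z, Y, X] convention above.
x = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
g = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(x, layout=ttnn.TILE_LAYOUT, device=device)
grad_tensor = ttnn.from_torch(g, layout=ttnn.TILE_LAYOUT, device=device)

# Call shape confirmed by the test diff; the result is a list of gradients.
dx = ttnn.relu6_bw(grad_tensor, input_tensor)
dx_torch = ttnn.to_torch(dx[0])

ttnn.close_device(device)
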
