From 0cb25034fbfe3da035b1bd2109aea267c3655835 Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Tue, 9 Jul 2024 17:37:51 +0000
Subject: [PATCH] #10071: Merge relu6_bw to TTNN

---
 docs/source/ttnn/ttnn/api.rst                 |  1 +
 docs/source/ttnn/ttnn/dependencies/tt_lib.rst |  2 +-
 docs/source/ttnn/ttnn/ttnn/relu6_bw.rst       |  6 ++++
 .../backward_ops/test_backward_relu6.py       |  4 +--
 .../op_library/backward/backward_ops.cpp      | 36 ++++++++-----------
 .../op_library/backward/backward_ops.hpp      |  3 +-
 .../tt_lib_bindings_tensor_backward_ops.cpp   | 14 ++++----
 .../device/unary_backward_op.cpp              | 25 +++++++++++++
 .../device/unary_backward_op.hpp              |  1 +
 .../eltwise/unary_backward/unary_backward.hpp |  1 +
 .../unary_backward/unary_backward_pybind.hpp  |  5 +++
 11 files changed, 66 insertions(+), 32 deletions(-)
 create mode 100644 docs/source/ttnn/ttnn/ttnn/relu6_bw.rst

diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index 2fa518316980..7e7e74621ce2 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -196,6 +196,7 @@ Pointwise Unary
    ttnn/floor_bw
    ttnn/round_bw
    ttnn/log_bw
+   ttnn/relu6_bw
 
 Pointwise Binary
 ================
diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index c72c5199fa29..322eeffc8a43 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -880,7 +880,7 @@ Backward Operations
 
 .. autofunction:: tt_lib.tensor.reciprocal_bw
 
-.. autofunction:: tt_lib.tensor.relu6_bw
+.. autofunction:: tt_lib.tensor.rpow_bw
 
 .. autofunction:: tt_lib.tensor.silu_bw
diff --git a/docs/source/ttnn/ttnn/ttnn/relu6_bw.rst b/docs/source/ttnn/ttnn/ttnn/relu6_bw.rst
new file mode 100644
index 000000000000..1a61f8b3d709
--- /dev/null
+++ b/docs/source/ttnn/ttnn/ttnn/relu6_bw.rst
@@ -0,0 +1,6 @@
+.. _ttnn.relu6_bw:
+
+ttnn.relu6_bw
+##############
+
+.. autofunction:: ttnn.relu6_bw
diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py
index 2ee155ac4010..15aff1a6855c 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py
@@ -4,7 +4,7 @@
 
 import torch
 import pytest
-import tt_lib
+import ttnn
 from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
 
 
@@ -22,7 +22,7 @@ def test_bw_relu6(input_shapes, device):
 
     pyt_y = torch.nn.functional.relu6(in_data)
 
-    tt_output_tensor_on_device = tt_lib.tensor.relu6_bw(grad_tensor, input_tensor)
+    tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)
 
     in_data.retain_grad()
 
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index c4cc6f28821f..821cf9c54723 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -970,33 +970,27 @@ std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const
     return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config);
 }
 
-std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _rpow_bw(
+    const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    Tensor zero_tensor = zeros_like(input, output_mem_config);
-    Tensor one_tensor = ones_like(input, output_mem_config);
-    Tensor six_tensor = full_like(input, 6, output_mem_config);
-    Tensor grad_result =
-        where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
-    grad_result = where(
-        ttnn::logical_and(
-            ttnn::gtz(input, output_mem_config),
-            ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
-            std::nullopt,
-            output_mem_config),
-        grad,
-        grad_result,
-        output_mem_config);
-    grad_result =
-        where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);
-
+    float t_nan = std::nanf("");
+    Tensor grad_result = zeros_like(input, output_mem_config);
+    if (exponent != 0.0) {
+        grad_result =
+            ttnn::multiply(grad,
+                ttnn::multiply(pow(input, exponent - 1, output_mem_config), exponent, std::nullopt, output_mem_config),
+                std::nullopt,
+                output_mem_config);
+        grad_result = where(ttnn::ltz(input, output_mem_config), t_nan, grad_result, output_mem_config);
+    }
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
 
-std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config);
+std::vector<Tensor> rpow_bw(
+    const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
+    return operation::decorate_as_composite(__func__, _rpow_bw)(grad, input, exponent, output_mem_config);
 }
-
 // Silu
 // result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result))
 std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
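Note: `_rpow_bw` above implements the power-rule gradient. For y = x^k, dy/dx = k * x^(k-1), so grad_input = grad * k * x^(k-1); the gradient is NaN for x < 0 (where a real-valued fractional power is undefined) and identically zero when k == 0. For cross-checking, a minimal PyTorch sketch of the same rule; the helper name `rpow_bw_reference` is ours and not part of the patch:

import torch

def rpow_bw_reference(grad: torch.Tensor, x: torch.Tensor, exponent: float) -> torch.Tensor:
    # Zero gradient for exponent == 0, matching the early-out in _rpow_bw.
    if exponent == 0.0:
        return torch.zeros_like(x)
    # Power rule: d/dx x**k = k * x**(k-1), scaled by the incoming gradient.
    result = grad * exponent * torch.pow(x, exponent - 1)
    # _rpow_bw propagates NaN for negative inputs.
    return torch.where(x < 0, torch.full_like(x, float("nan")), result)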
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
index 76b150b1f36e..507400122852 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -284,9 +284,10 @@ std::vector<Tensor> reciprocal_bw(
     const Tensor& input,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
-std::vector<Tensor> relu6_bw(
+std::vector<Tensor> rpow_bw(
     const Tensor& grad,
     const Tensor& input,
+    float exponent,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
 std::vector<Tensor> silu_bw(
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
index f661e4e68415..2842e6235e3a 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -807,22 +807,22 @@ namespace tt::tt_metal::detail{
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");
 
-    m_tensor.def("relu6_bw", &tt::tt_metal::relu6_bw,
-            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
-            Returns an tensor of backward operation of relu6 for ``input`` tensor and ``grad`` tensor.
+    m_tensor.def("rpow_bw", &tt::tt_metal::rpow_bw,
+            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("exponent").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            Performs backward operations for rpow on ``input`` with the given ``exponent`` and ``grad``.
 
             Input tensors must have BFLOAT16 data type.
 
-            Output tensors will have BFLOAT16 data type.
+            Output tensor will have BFLOAT16 data type.
 
             .. csv-table::
                 :header: "Argument", "Description", "Data type", "Valid range", "Required"
 
                 "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
-                "input", "Tensor relu6 is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+                "input", "Input tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+                "exponent", "Exponent value", "float", ">0.0", "Yes"
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
-        )doc");
-
+        )doc");
 
     m_tensor.def("silu_bw", &tt::tt_metal::silu_bw,
             py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
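With the binding above, the tt_lib entry point now takes the exponent as a third positional argument. A hypothetical call shape, assuming `grad_tensor` and `input_tensor` are device tensors prepared as in the test file (the exponent value 2.0 is an arbitrary example):

import tt_lib

# Sketch only: follows the py::arg list in the binding above.
grad_in = tt_lib.tensor.rpow_bw(grad_tensor, input_tensor, 2.0)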
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index 3a21990e5acd..183c6343038a 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -388,6 +388,29 @@ std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const Memor
     return grad_tensor;
 }
 
+std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    Tensor zero_tensor = tt::tt_metal::zeros_like(input, output_mem_config);
+    Tensor one_tensor = tt::tt_metal::ones_like(input, output_mem_config);
+    Tensor six_tensor = tt::tt_metal::full_like(input, 6, output_mem_config);
+    Tensor grad_result =
+        where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
+    grad_result = where(
+        ttnn::logical_and(
+            ttnn::gtz(input, output_mem_config),
+            ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        grad,
+        grad_result,
+        output_mem_config);
+    grad_result =
+        where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);
+
+    grad_tensor.emplace_back(grad_result);
+    return grad_tensor;
+}
+
 std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){
     switch (OpType) {
         case UnaryBackwardOpType::ASSIGN_BW:
@@ -424,6 +447,8 @@ std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const MemoryCon
             return _round_bw;
         case UnaryBackwardOpType::LOG_BW:
             return _log_bw;
+        case UnaryBackwardOpType::RELU6_BW:
+            return _relu6_bw;
         default:
             TT_ASSERT(false && "Undefined op type");
             return 0;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index 5377b3b5fe9b..5cac6f627796 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -43,6 +43,7 @@ enum class UnaryBackwardOpType {
     FLOOR_BW,
     ROUND_BW,
     LOG_BW,
+    RELU6_BW
 };
 
 struct UnaryBackwardFunction{
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 8da1cd7c67cf..6b75af8a8c33 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -95,6 +95,7 @@ constexpr auto rpow_bw = ttnn::register_operation<operations::unary_backward::E
 constexpr auto floor_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::FLOOR_BW>>("ttnn::floor_bw");
 constexpr auto round_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ROUND_BW>>("ttnn::round_bw");
 constexpr auto log_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::LOG_BW>>("ttnn::log_bw");
+constexpr auto relu6_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RELU6_BW>>("ttnn::relu6_bw");
 
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index 6975ee4c5cf0..44ef0602ab69 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -313,6 +313,11 @@ void py_module(py::module& module) {
         ttnn::log_bw,
         R"doc(Performs backward operations for logarithm on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");
 
+    detail::bind_unary_backward(
+        module,
+        ttnn::relu6_bw,
+        R"doc(Performs backward operations for relu6 on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");
+
 }
 
 }  // namespace binary_backward
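For reference, the relu6 gradient that `_relu6_bw` implements passes the incoming gradient through only on the open interval 0 < x < 6 and is zero elsewhere, since relu6(x) = min(max(x, 0), 6) is flat outside that interval. A minimal PyTorch sketch of the rule, plus the new ttnn call shape mirroring test_bw_relu6 above; the helper name `relu6_bw_reference` is ours, not part of the patch:

import torch

def relu6_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Gradient flows only where 0 < x < 6; zero at and beyond both clamp points,
    # matching the where(...) cascade in _relu6_bw.
    return torch.where((x > 0) & (x < 6), grad, torch.zeros_like(x))

# Device-side call shape after this patch (tensors prepared as in test_bw_relu6):
#     tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)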