diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index f29b309ef75e..2d866a4c7fbb 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -198,6 +198,7 @@ Pointwise Unary ttnn/log_bw ttnn/relu6_bw ttnn/abs_bw + ttnn/silu_bw Pointwise Binary ================ diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index 5acd9c68ced1..7bbd91c75461 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -880,8 +880,6 @@ Backward Operations .. autofunction:: tt_lib.tensor.rpow_bw -.. autofunction:: tt_lib.tensor.silu_bw - .. autofunction:: tt_lib.tensor.selu_bw .. autofunction:: tt_lib.tensor.square_bw diff --git a/docs/source/ttnn/ttnn/ttnn/silu_bw.rst b/docs/source/ttnn/ttnn/ttnn/silu_bw.rst new file mode 100644 index 000000000000..16c8ea8efe4f --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/silu_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.silu_bw: + + ttnn.silu_bw + ############# + + .. autofunction:: ttnn.silu_bw diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py index 455c87ce1913..c86af349dbe0 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py @@ -4,7 +4,7 @@ import torch import pytest -import tt_lib +import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc @@ -22,7 +22,7 @@ def test_bw_silu(input_shapes, device): pyt_y = torch.nn.functional.silu(in_data) - tt_output_tensor_on_device = tt_lib.tensor.silu_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.silu_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp index caee9737bf7f..fde73df19c90 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp @@ -1001,28 +1001,6 @@ std::vector rpow_bw( return operation::decorate_as_composite(__func__, _rpow_bw)(grad, input, exponent, output_mem_config); } -// Silu -// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result)) -std::vector _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config); - Tensor add_sub = ttnn::add( - ttnn::multiply(ttnn::subtract(ttnn::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config), - input, - std::nullopt, - output_mem_config), - 1.0f, - std::nullopt, - output_mem_config); - Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config); - - grad_tensor.emplace_back(grad_result); - return grad_tensor; -} -std::vector silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _silu_bw)(grad, input, output_mem_config); -} - // Selu // result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input)) std::vector _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp index 8d6606065d24..c1d60d32ea48 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp @@ -285,11 +285,6 @@ std::vector rpow_bw( float exponent, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -std::vector silu_bw( - const Tensor& grad, - const Tensor& input, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - std::vector selu_bw( const Tensor& grad, const Tensor& input, diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp index 0b39023b0bed..ea2dc4907a15 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp @@ -808,23 +808,6 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - m_tensor.def("silu_bw", &tt::tt_metal::silu_bw, - py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Performs backward operations for silu sin of ``input`` tensors with given ``grad``. - - Input tensors must have BFLOAT16 data type. - - Output tensors will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "input", "Tensor silu_bw is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - - m_tensor.def("selu_bw", &tt::tt_metal::selu_bw, py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( Performs backward operations for selu sin of ``input`` tensors with given ``grad``. diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp index 2b95d936e866..252d61eeb02b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp @@ -418,6 +418,23 @@ std::vector _abs_bw(const Tensor& grad, const Tensor& input, const Memor return grad_tensor; } +// Silu +// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result)) +std::vector _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor grad_sigmoid = ttnn::multiply(grad, sigmoid(input, output_mem_config), std::nullopt, output_mem_config); + Tensor add_sub = add1( + ttnn::multiply(sub_unary(1.0f, sigmoid(input, output_mem_config), output_mem_config), + input, + std::nullopt, + output_mem_config), + output_mem_config); + Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config); + + grad_tensor.emplace_back(grad_result); + return grad_tensor; +} + std::function(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){ switch (OpType) { case UnaryBackwardOpType::ASSIGN_BW: @@ -458,6 +475,8 @@ std::function(const Tensor&, const Tensor&, const Memo return _relu6_bw; case UnaryBackwardOpType::ABS_BW: return _abs_bw; + case UnaryBackwardOpType::SILU_BW: + return _silu_bw; default: TT_ASSERT(false && "Undefined op type"); return 0; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp index 3a938c7f1c91..7b6fb78fc210 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp @@ -45,6 +45,7 @@ enum class UnaryBackwardOpType { LOG_BW, RELU6_BW, ABS_BW, + SILU_BW }; struct UnaryBackwardFunction{ diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp index c523de91a4c6..f42d61b41972 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp @@ -97,6 +97,7 @@ constexpr auto round_bw = ttnn::register_operation>("ttnn::log_bw"); constexpr auto relu6_bw = ttnn::register_operation>("ttnn::relu6_bw"); constexpr auto abs_bw = ttnn::register_operation>("ttnn::abs_bw"); +constexpr auto silu_bw = ttnn::register_operation>("ttnn::silu_bw"); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp index 7144c851be83..047d0eb56b99 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp @@ -323,6 +323,11 @@ void py_module(py::module& module) { ttnn::abs_bw, R"doc(Performs backward operations for abs on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + detail::bind_unary_backward( + module, + ttnn::silu_bw, + R"doc(Performs backward operations for silu on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + } } // namespace binary_backward