diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index 7291df8e70c7..ffcb12cfee92 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -194,6 +194,7 @@ Pointwise Unary
    ttnn/celu_bw
    ttnn/rpow_bw
    ttnn/floor_bw
+   ttnn/round_bw
 
 Pointwise Binary
 ================
diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index 54ab9221b0fe..1f9c3f0a6b63 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -928,8 +928,6 @@ Backward Operations
 
 .. autofunction:: tt_lib.tensor.repeat_bw
 
-.. autofunction:: tt_lib.tensor.round_bw
-
 .. autofunction:: tt_lib.tensor.unary_div_no_nan_bw
 
 Loss Functions
diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py
index cc9b7109a9cd..3d5f91223599 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py
@@ -5,6 +5,7 @@
 
 import torch
 import pytest
 import tt_lib
+import ttnn
 from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
 
@@ -21,7 +22,7 @@ def test_bw_round(input_shapes, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -200, 201, device, required_grad=True)
 
     pyt_y = torch.round(in_data)
 
-    tt_output_tensor_on_device = tt_lib.tensor.round_bw(grad_tensor)
+    tt_output_tensor_on_device = ttnn.round_bw(grad_tensor, input_tensor)
 
     in_data.retain_grad()
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index 329761eb192a..42591a78ace3 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1748,16 +1748,6 @@ std::vector<Tensor> repeat_bw(
     return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
 }
 
-std::vector<Tensor> _round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
-    Tensor t_zero = zeros_like(grad, output_mem_config);
-    grad_tensor.emplace_back(t_zero);
-    return grad_tensor;
-}
-std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _round_bw)(grad, output_mem_config);
-}
-
 std::vector<Tensor> _unary_div_no_nan_bw(
     const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
index d8c694030859..2e37d1ce0ad3 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -422,9 +422,6 @@ std::vector<Tensor> complex_sub_bw(
 std::vector<Tensor> repeat_bw(
     const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config);
 
-std::vector<Tensor> round_bw(
-    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
 std::vector<Tensor> unary_div_no_nan_bw(
     const Tensor& grad,
     const Tensor& input,
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
index a4cbf453f9d7..50208801e9ca 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -1153,21 +1153,6 @@ namespace tt::tt_metal::detail{
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");
 
-        m_tensor.def("round_bw", &tt::tt_metal::round_bw,
-            py::arg("grad").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
-            Returns an tensor of zeros like ``grad`` tensor
-
-            Input tensor must have BFLOAT16 data type.
-
-            Output tensor will have BFLOAT16 data type.
-
-            .. csv-table::
-                :header: "Argument", "Description", "Data type", "Valid range", "Required"
-
-                "grad", "Gradient Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
-                "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
-        )doc");
-
         m_tensor.def("unary_div_no_nan_bw", &tt::tt_metal::unary_div_no_nan_bw,
             py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
             Performs backward operations for division with given ``grad`` and ``scalar`` with no nan.
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index 89fce03079e4..bc312c32f6f2 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -264,7 +264,6 @@ std::vector<Tensor> _logit_bw(const Tensor& grad, const Tensor& input, const Mem
     return grad_tensor;
 }
 
-
 std::vector<Tensor> _hardshrink_bw(
     const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
@@ -358,6 +357,20 @@ std::vector<Tensor> _rpow_bw(
     return grad_tensor;
 }
 
+std::vector<Tensor> _floor_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    Tensor t_zero = tt::tt_metal::zeros_like(grad, output_mem_config);
+    grad_tensor.emplace_back(t_zero);
+    return grad_tensor;
+}
+
+std::vector<Tensor> _round_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    Tensor t_zero = tt::tt_metal::zeros_like(grad, output_mem_config);
+    grad_tensor.emplace_back(t_zero);
+    return grad_tensor;
+}
+
 std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){
     switch (OpType) {
         case UnaryBackwardOpType::ASSIGN_BW:
@@ -390,6 +403,8 @@ std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const Memo
             return _logit_bw;
         case UnaryBackwardOpType::FLOOR_BW:
             return _floor_bw;
+        case UnaryBackwardOpType::ROUND_BW:
+            return _round_bw;
         default:
             TT_ASSERT(false && "Undefined op type");
             return 0;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index 410fb4bd098e..4fb7556cc23c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -40,6 +40,8 @@ enum class UnaryBackwardOpType {
     ELU_BW,
     CELU_BW,
     RPOW_BW,
+    FLOOR_BW,
+    ROUND_BW,
 };
 
 struct UnaryBackwardFunction{
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index ca74b38c4853..ef34bd3d231b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -93,6 +93,7 @@ constexpr auto elu_bw = ttnn::register_operation<operations::unary_backward::Ex
 constexpr auto celu_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CELU_BW>>("ttnn::celu_bw");
 constexpr auto rpow_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RPOW_BW>>("ttnn::rpow_bw");
 constexpr auto floor_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::FLOOR_BW>>("ttnn::floor_bw");
+constexpr auto round_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ROUND_BW>>("ttnn::round_bw");
 
 }  // namespace ttnn
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index 9f7caca571d1..23cc56876389 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -302,6 +302,12 @@ void py_module(py::module& module) {
         module,
         ttnn::rpow_bw,
         R"doc(Performs backward operations for rpow on :attr:`input_tensor`, :attr:`exponent` with given :attr:`grad_tensor`.)doc");
+
+    detail::bind_unary_backward(
+        module,
+        ttnn::round_bw,
+        R"doc(Performs backward operations for round on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");
+
 }
 
 }  // namespace binary_backward
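
---

For reviewers, a minimal standalone sketch of the migrated entry point, mirroring `test_bw_round` above but without the pytest harness or the repo's `data_gen_with_range`/`compare_pcc` helpers. It is not part of the patch: the device id, the `from_torch`/`to_torch` conversion keywords, and the `[0]` indexing into the result (assuming the op surfaces its `std::vector<Tensor>` return as a Python list) are assumptions; the gradient-first, input-second argument order and the zero-gradient semantics come from the diff itself.

```python
import torch
import ttnn

# Hypothetical device id; any available TT device works.
device = ttnn.open_device(device_id=0)

torch.manual_seed(0)
shape = (1, 1, 32, 32)  # one 32x32 tile; BFLOAT16/TILE constraints per the removed tt_lib docstring
in_data = torch.randn(shape, dtype=torch.bfloat16, requires_grad=True)
grad_data = torch.randn(shape, dtype=torch.bfloat16)

# PyTorch golden: round() is piecewise constant, so autograd yields all zeros.
torch.round(in_data).backward(gradient=grad_data)
golden = in_data.grad

input_tensor = ttnn.from_torch(in_data.detach(), layout=ttnn.TILE_LAYOUT, device=device)
grad_tensor = ttnn.from_torch(grad_data, layout=ttnn.TILE_LAYOUT, device=device)

# New-style call: gradient first, then the forward input. The removed
# tt_lib.tensor.round_bw took only the gradient; ttnn.round_bw follows the
# unary_backward convention and (assumed) returns a list of gradient tensors.
tt_grad = ttnn.round_bw(grad_tensor, input_tensor)[0]

assert torch.equal(ttnn.to_torch(tt_grad), golden)
ttnn.close_device(device)
```

The exact-equality check is deliberate: both sides are exact zeros (`zeros_like` on device, a zero gradient from autograd), so no PCC tolerance is needed here, unlike the range-based data in the unit test.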