From a58b4f8bfdafe1cd3bd797462d6cec522aac6648 Mon Sep 17 00:00:00 2001
From: VirdhatchaniKN
Date: Thu, 4 Jul 2024 08:19:46 +0000
Subject: [PATCH] #9874: Merge clamp_min_bw to TTNN

---
 docs/source/ttnn/ttnn/api.rst                      |  1 +
 docs/source/ttnn/ttnn/dependencies/tt_lib.rst      |  2 --
 docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst        |  6 ++++++
 .../backward_ops/test_backward_clamp_min.py        |  3 ++-
 .../tt_dnn/op_library/backward/backward_ops.cpp    | 13 -------------
 .../tt_dnn/op_library/backward/backward_ops.hpp    |  6 ------
 .../tt_lib_bindings_tensor_backward_ops.cpp        | 17 -----------------
 .../unary_backward/device/unary_backward_op.cpp    | 11 +++++++++++
 .../unary_backward/device/unary_backward_op.hpp    |  1 +
 .../eltwise/unary_backward/unary_backward.hpp      |  1 +
 .../unary_backward/unary_backward_pybind.hpp       |  7 ++++++-
 11 files changed, 28 insertions(+), 40 deletions(-)
 create mode 100644 docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst

diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index 78b8c068bec4..09ef2c44cf8d 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -167,6 +167,7 @@ Pointwise Unary
    ttnn/tanhshrink
    ttnn/threshold
    ttnn/unary_mul_bw
+   ttnn/clamp_min_bw
 
 Pointwise Binary
 ================
diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index b1265383dc8b..ce6c7566ff6b 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -871,8 +871,6 @@ Backward Operations
 
 .. autofunction:: tt_lib.tensor.clamp_bw
 
-.. autofunction:: tt_lib.tensor.clamp_min_bw
-
 .. autofunction:: tt_lib.tensor.clamp_max_bw
 
 .. autofunction:: tt_lib.tensor.gelu_bw
diff --git a/docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst b/docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst
new file mode 100644
index 000000000000..34d578c84954
--- /dev/null
+++ b/docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst
@@ -0,0 +1,6 @@
+.. _ttnn.clamp_min_bw:
+
+ttnn.clamp_min_bw
+#################
+
+.. autofunction:: ttnn.clamp_min_bw
diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py
index 9dedf4490eef..65fef6bb6924 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py
@@ -5,6 +5,7 @@
 import torch
 import pytest
 import tt_lib
+import ttnn
 from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
 
 
@@ -23,7 +24,7 @@ def test_bw_clamp_min(input_shapes, min, device):
 
     pyt_y = torch.clamp(in_data, min=min)
 
-    tt_output_tensor_on_device = tt_lib.tensor.clamp_min_bw(grad_tensor, input_tensor, min)
+    tt_output_tensor_on_device = ttnn.clamp_min_bw(grad_tensor, input_tensor, min)
 
     in_data.retain_grad()
 
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index 3c9c218d017b..755b08feaf0f 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -429,19 +429,6 @@ std::vector<Tensor> clamp_bw(
     return operation::decorate_as_composite(__func__, _clamp_bw)(grad, input, min, max, output_mem_config);
 }
 
-std::vector<Tensor> _clamp_min_bw(
-    const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
-    Tensor minT = gte_unary(input, min, output_mem_config);
-    Tensor result = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
-    grad_tensor.emplace_back(result);
-    return grad_tensor;
-}
-std::vector<Tensor> clamp_min_bw(
-    const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _clamp_min_bw)(grad, input, min, output_mem_config);
-}
-
 std::vector<Tensor> _clamp_max_bw(
     const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
index e3ab268def63..f2da4c40bf60 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -182,12 +182,6 @@ std::vector<Tensor> clamp_bw(
     float max,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
-std::vector<Tensor> clamp_min_bw(
-    const Tensor& grad,
-    const Tensor& input,
-    float min,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
 std::vector<Tensor> clamp_max_bw(
     const Tensor& grad,
     const Tensor& input,
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
index 3441bd236f04..66804539ba94 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -541,23 +541,6 @@ namespace tt::tt_metal::detail{
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");
 
-    m_tensor.def("clamp_min_bw", &tt::tt_metal::clamp_min_bw,
-            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("min").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
-        Performs backward operations for clamp min of ``input`` tensors with given ``grad``.
-
-        Input tensors must have BFLOAT16 data type.
-
-        Output tensor will have BFLOAT16 data type.
-
-        .. csv-table::
-            :header: "Argument", "Description", "Data type", "Valid range", "Required"
-
-            "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
-            "input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
-            "min", "Minimum Value", "float", , "Yes"
-            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
-    )doc");
-
     m_tensor.def("clamp_max_bw", &tt::tt_metal::clamp_max_bw,
             py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("max").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
         Performs backward operations for clamp max of ``input`` tensors with given ``grad``.
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index ea7482bcff86..982a00a8beec 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -27,6 +27,15 @@ std::vector<Tensor> _unary_mul_bw(
     return grad_tensor;
 }
 
+std::vector<Tensor> _clamp_min_bw(
+    const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    Tensor minT = gte_unary(input, min, output_mem_config);
+    Tensor result = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
+    grad_tensor.emplace_back(result);
+    return grad_tensor;
+}
+
 std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){
     switch (OpType) {
@@ -40,6 +49,8 @@ std::function<std::vector<Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type2(UnaryBackwardOpType OpType){
     switch (OpType) {
         case UnaryBackwardOpType::UNARY_MUL_BW:
             return _unary_mul_bw;
+        case UnaryBackwardOpType::CLAMP_MIN_BW:
+            return _clamp_min_bw;
         default:
             TT_ASSERT(false && "Undefined op type");
             return 0;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index b202708fcf78..091b44b2a39b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -14,6 +14,7 @@ namespace ttnn::operations::unary_backward {
 
 constexpr uint8_t DefaultQueueId = 0;
 enum class UnaryBackwardOpType {
     UNARY_MUL_BW,
+    CLAMP_MIN_BW,
 };
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 000b20c97835..344dacfcd69a 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -80,5 +80,6 @@ struct ExecuteUnaryBackward {
 
 //type 1
 constexpr auto unary_mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_MUL_BW>>("ttnn::unary_mul_bw");
+constexpr auto clamp_min_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MIN_BW>>("ttnn::clamp_min_bw");
 
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index d023e054978a..1d431dd72bbc 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -37,7 +37,7 @@ Keyword args:
 
             >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
             >>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
->>> output = {1}(grad_tensor, input)
+            >>> output = {1}(grad_tensor, input)
 )doc",
         operation.name(),
         operation.python_fully_qualified_name(),
@@ -86,6 +86,11 @@ void py_module(py::module& module) {
         ttnn::unary_mul_bw,
         R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");
 
+    detail::bind_unary_backward(
+        module,
+        ttnn::clamp_min_bw,
+        R"doc(Performs backward operations for clamp min value on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");
+
 }
 
 }  // namespace binary_backward
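
Note on the gradient semantics ported here: _clamp_min_bw builds a mask with
gte_unary (1.0 where input >= min, 0.0 elsewhere) and multiplies the incoming
gradient by it, which matches the backward of torch.clamp(..., min=...). A
minimal PyTorch sketch of that behavior follows; the tensor values and the
0.0 threshold are illustrative only, not taken from this patch:

    import torch

    # Reference semantics for clamp_min_bw: grad_input = grad * (input >= min).
    # Illustrative values; not part of the patch.
    grad = torch.tensor([1.0, 2.0, 3.0])
    inp = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)

    out = torch.clamp(inp, min=0.0)  # forward clamp_min
    out.backward(grad)               # backward: mask-and-multiply

    # Gradient is blocked exactly where the clamp was active.
    assert torch.equal(inp.grad, grad * (inp >= 0.0))  # [0.0, 2.0, 3.0]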