diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_unary_mul.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_unary_mul.py
index bf5217cad7da..9d4085b0ffbf 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_unary_mul.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_unary_mul.py
@@ -22,7 +22,7 @@ def test_bw_unary_mul(input_shapes, scalar, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
 
-    tt_output_tensor_on_device = ttnn.unary_mul_bw(grad_tensor, input_tensor, scalar)
+    tt_output_tensor_on_device = ttnn.mul_bw(grad_tensor, input_tensor, scalar)
 
     in_data.retain_grad()
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
index b62a0b5ed88a..ed2afd5623cb 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
@@ -180,6 +180,4 @@ constexpr auto addalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::ADDALPHA_BW>>("ttnn::addalpha_bw");
 constexpr auto add_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::ADD_BW>>("ttnn::add_bw");
 constexpr auto binary_eq_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::BINARY_EQ_BW>>("ttnn::binary_eq_bw");
-constexpr auto mul_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::MUL_BW>>("ttnn::mul_bw");
-
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
index a205b00a631c..030f9a4c29e6 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
@@ -346,11 +346,6 @@ void py_module(py::module& module) {
         ttnn::div_bw,
         R"doc(Performs backward operations for divide of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
 
-    detail::bind_binary_backward(
-        module,
-        ttnn::mul_bw,
-        R"doc(Performs backward operations for multiply on :attr:`input_tensor_b` , attr:`input_tensor_a` with given attr:`grad_tensor`.)doc");
-
 }
 
 }  // namespace binary_backward
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index 982a00a8beec..893e842c8b55 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -47,7 +47,7 @@ std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const Memo
 
 std::function<std::vector<Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType){
     switch (OpType) {
-        case UnaryBackwardOpType::UNARY_MUL_BW:
+        case UnaryBackwardOpType::MUL_BW:
             return _unary_mul_bw;
         case UnaryBackwardOpType::CLAMP_MIN_BW:
             return _clamp_min_bw;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index 091b44b2a39b..ff25b4d33c38 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -13,7 +13,7 @@ namespace ttnn::operations::unary_backward {
 constexpr uint8_t DefaultQueueId = 0;
 
 enum class UnaryBackwardOpType {
-    UNARY_MUL_BW,
+    MUL_BW,
     CLAMP_MIN_BW,
 };
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 344dacfcd69a..7ce21f9af212 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -16,28 +16,6 @@ namespace operations::unary_backward {
 
 template <UnaryBackwardOpType unary_backward_op_type>
 struct ExecuteUnaryBackward {
-    static inline const std::array<TensorSchema, 2> input_tensor_schemas() {
-        return {
-            ttnn::TensorSchema{
-                2,
-                4,
-                {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16},
-                {ttnn::TILE_LAYOUT},
-                true,
-                false,
-                false,
-                false},
-            ttnn::TensorSchema{
-                2,
-                4,
-                {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16},
-                {ttnn::TILE_LAYOUT},
-                true,
-                false,
-                false,
-                false}};
-    }
-
     static inline std::vector<Tensor> create_async_output_tensors(
         const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
         const auto& input_tensor = input_tensors.at(0);
@@ -45,10 +23,6 @@ struct ExecuteUnaryBackward {
     }
 
     //Type 1: 2 inputs, 1 grad tensor
-    template <typename... Args>
-    static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
-        return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b);
-    }
 
     static std::vector<Tensor> execute_on_worker_thread(
         const Tensor &grad_tensor_arg,
@@ -61,7 +35,6 @@ struct ExecuteUnaryBackward {
     }
 
     //Type 1: Type 1 with 1 float
-    template <typename... Args>
 
     static std::vector<Tensor> execute_on_worker_thread(
         const Tensor &grad_tensor_arg,
@@ -79,7 +52,7 @@ struct ExecuteUnaryBackward {
 }  // operations::unary
 
 //type 1
-constexpr auto unary_mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_MUL_BW>>("ttnn::unary_mul_bw");
+constexpr auto mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::MUL_BW>>("ttnn::mul_bw");
 constexpr auto clamp_min_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MIN_BW>>("ttnn::clamp_min_bw");
 
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index 1d431dd72bbc..89f2044ed2ef 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -9,6 +9,7 @@
 
 #include "ttnn/cpp/pybind11/decorators.hpp"
 #include "ttnn/operations/eltwise/unary_backward/unary_backward.hpp"
+#include "ttnn/operations/eltwise/binary_backward/binary_backward.hpp"
 #include "ttnn/types.hpp"
 
 namespace py = pybind11;
@@ -39,7 +40,7 @@ Keyword args:
 
             >>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
             >>> output = {1}(grad_tensor, input)
         )doc",
-        operation.name(),
+        operation.base_name(),
         operation.python_fully_qualified_name(),
         description);
@@ -47,6 +48,53 @@ Keyword args:
         module,
         operation,
         doc,
+        ttnn::pybind_overload_t{
+            [](const unary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor_a,
+               const ttnn::Tensor& input_tensor_b,
+               const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> {
+                auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
+                // mul_bw keeps its tensor-tensor variant in the binary_backward
+                // implementation, so dispatch on the op's base name. A `using`
+                // alias declared inside an if-block is not visible outside it,
+                // so the forwarding call happens inside the branch.
+                if (self.base_name() == "mul_bw") {
+                    using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<
+                        ttnn::operations::binary_backward::BinaryBackwardOpType::MUL_BW>;
+                    return BinaryBackwardOp::execute_on_worker_thread(
+                        grad_tensor, input_tensor_a, output_memory_config, input_tensor_b);
+                }
+                TT_THROW("Unsupported operation for the tensor-tensor overload");
+            },
+            py::arg("grad_tensor"),
py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}, + + ttnn::pybind_overload_t{ + [](const unary_backward_operation_t& self, + const ttnn::Tensor& grad_tensor, + const ttnn::Tensor& input_tensor_a, + const ttnn::Tensor& input_tensor_b, + const std::optional& memory_config, + const std::vector& are_required_outputs, + const std::optional& input_a_grad, + const std::optional& input_b_grad, + const uint8_t& queue_id) -> std::vector> { + using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward; + if(operation.base_name()=="mul_bw"){ + using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward; + } + return BinaryBackwardOp::execute_on_main_thread(queue_id, grad_tensor, input_tensor_a, input_tensor_b, memory_config, are_required_outputs, input_a_grad, input_b_grad); + }, + py::arg("grad_tensor"), + py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("are_required_outputs") = std::vector{true, true}, + py::arg("input_a_grad") = std::nullopt, + py::arg("input_b_grad") = std::nullopt, + py::arg("queue_id") = 0}, + ttnn::pybind_overload_t{ [](const unary_backward_operation_t& self, const ttnn::Tensor& grad_tensor, @@ -83,7 +131,7 @@ Keyword args: void py_module(py::module& module) { detail::bind_unary_backward( module, - ttnn::unary_mul_bw, + ttnn::mul_bw, R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc"); detail::bind_unary_backward(