From c88bb749feeae7584402e9417fb01c72f843ad7f Mon Sep 17 00:00:00 2001
From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com>
Date: Sat, 20 Jul 2024 13:15:08 +0530
Subject: [PATCH] #9628: Remove std::function for BW Binary ops (#10492)

* #9628: Remove std::function for atan2_bw
* #9628: Remove std::function for addalpha_bw
* #9628: Update rsub_bw
* #9628: Update embedding_bw
---
 .../binary_backward/binary_backward.hpp       |  79 +++++++++++-
 .../binary_backward_pybind.hpp                | 118 +++++++++++++++++-
 .../device/binary_backward_op.cpp             |  65 ++--------
 .../device/binary_backward_op.hpp             |  57 ++++++++-
 .../unary_backward/unary_backward_pybind.hpp  |   2 +-
 5 files changed, 258 insertions(+), 63 deletions(-)
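Editorial note before the diff: the commit message describes replacing std::function-based dispatch with templated handlers. A minimal, self-contained sketch of the general before/after shape of that refactor; the enum values, Handler, and get_handler names below are illustrative placeholders, not ttnn symbols.

```cpp
#include <functional>
#include <vector>

enum class OpType { Atan2Bw, RsubBw };
struct Tensor {};  // stand-in for ttnn::Tensor

std::vector<Tensor> atan2_bw_impl(const Tensor&, const Tensor&, const Tensor&) { return {}; }
std::vector<Tensor> rsub_bw_impl(const Tensor&, const Tensor&, const Tensor&) { return {}; }

// Old style: runtime switch returning a type-erased std::function.
std::function<std::vector<Tensor>(const Tensor&, const Tensor&, const Tensor&)>
get_function(OpType t) {
    switch (t) {
        case OpType::Atan2Bw: return atan2_bw_impl;
        case OpType::RsubBw:  return rsub_bw_impl;
    }
    return nullptr;
}

// New style: one handler specialization per op; dispatch is resolved at compile
// time and yields a plain function pointer (no std::function type erasure).
template <OpType T> struct Handler;
template <> struct Handler<OpType::Atan2Bw> {
    static std::vector<Tensor> handle(const Tensor& g, const Tensor& a, const Tensor& b) {
        return atan2_bw_impl(g, a, b);
    }
};
template <> struct Handler<OpType::RsubBw> {
    static std::vector<Tensor> handle(const Tensor& g, const Tensor& a, const Tensor& b) {
        return rsub_bw_impl(g, a, b);
    }
};
template <OpType T> auto get_handler() { return &Handler<T>::handle; }

int main() {
    Tensor g, a, b;
    auto f = get_handler<OpType::Atan2Bw>();  // plain function pointer, chosen at compile time
    auto grads = f(g, a, b);
    return static_cast<int>(grads.size());
}
```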
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
index d55aac7afe6..a9b6c2535c9 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
@@ -13,6 +13,73 @@ namespace ttnn {
 
 namespace operations::binary_backward {
 
+//OpHandler_binary_bw : get_function_binary_bw_type1
+template
+struct ExecuteBinaryBackwardType1 {
+
+    static inline std::vector create_async_output_tensors(
+        const std::vector &input_tensors, const std::vector>& optional_inputs) {
+        const auto& input_tensor = input_tensors.at(0);
+        return {Tensor(operation::get_workers_for_op_output({input_tensor})),
+                Tensor(operation::get_workers_for_op_output({input_tensor}))};
+    }
+
+    //Type 1: 2 input tensors, 1 grad tensor, optional output memory config
+    static std::vector execute_on_main_thread(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        const std::optional &memory_config = std::nullopt) {
+        auto op_type = get_function_binary_bw_type1();
+        auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
+        return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config);
+    }
+};
+
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+template
+struct ExecuteBinaryBackwardOptionalFloatDefault {
+
+    static inline std::vector create_async_output_tensors(
+        const std::vector &input_tensors, const std::vector>& optional_inputs) {
+        const auto& input_tensor = input_tensors.at(0);
+        return {Tensor(operation::get_workers_for_op_output({input_tensor})),
+                Tensor(operation::get_workers_for_op_output({input_tensor}))};
+    }
+
+    static std::vector> execute_on_main_thread(
+        uint8_t queue_id,
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        float parameter,
+        const std::optional &memory_config = std::nullopt,
+        const std::vector& are_required_outputs = std::vector{true, true},
+        std::optional input_a_grad = std::nullopt,
+        std::optional input_b_grad = std::nullopt) {
+
+        auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
+        auto op_type = get_function_binary_bw_type1_opt_float_default();
+        return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
+    }
+
+    static std::vector> execute_on_main_thread(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        float parameter,
+        const std::optional &memory_config = std::nullopt,
+        const std::vector& are_required_outputs = std::vector{true, true},
+        std::optional input_a_grad = std::nullopt,
+        std::optional input_b_grad = std::nullopt) {
+
+        auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
+        auto op_type = get_function_binary_bw_type1_opt_float_default();
+        return op_type(DefaultQueueId, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
+    }
+
+};
+
 template
 struct ExecuteBinaryBackward {
     static inline std::vector create_async_output_tensors(
@@ -138,9 +205,15 @@ struct ExecuteBinaryBackward {
 
 }  // operations::binary
 
+//OpHandler_binary_bw : get_function_binary_bw_type1
+constexpr auto atan2_bw = ttnn::register_operation>("ttnn::atan2_bw");
+constexpr auto rsub_bw = ttnn::register_operation>("ttnn::rsub_bw");
+constexpr auto embedding_bw = ttnn::register_operation>("ttnn::embedding_bw");
+
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+constexpr auto addalpha_bw = ttnn::register_operation>("ttnn::addalpha_bw");
+
 //type 1
-constexpr auto atan2_bw = ttnn::register_operation>("ttnn::atan2_bw");
-constexpr auto embedding_bw = ttnn::register_operation>("ttnn::embedding_bw");
 constexpr auto subalpha_bw = ttnn::register_operation>("ttnn::subalpha_bw");
 constexpr auto xlogy_bw = ttnn::register_operation>("ttnn::xlogy_bw");
 constexpr auto hypot_bw = ttnn::register_operation>("ttnn::hypot_bw");
@@ -149,12 +222,10 @@ constexpr auto logaddexp_bw = ttnn::register_operation>("ttnn::logaddexp2_bw");
 constexpr auto squared_difference_bw = ttnn::register_operation>("ttnn::squared_difference_bw");
 constexpr auto concat_bw = ttnn::register_operation>("ttnn::concat_bw");
-constexpr auto rsub_bw = ttnn::register_operation>("ttnn::rsub_bw");
 constexpr auto min_bw = ttnn::register_operation>("ttnn::min_bw");
 constexpr auto max_bw = ttnn::register_operation>("ttnn::max_bw");
 constexpr auto lerp_bw = ttnn::register_operation>("ttnn::lerp_bw");
 //type 2
-constexpr auto addalpha_bw = ttnn::register_operation>("ttnn::addalpha_bw");
 
 }  // namespace ttnn
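The two execute_on_main_thread overloads added above differ only in whether the caller supplies a queue id; the second simply forwards a default to the first. A minimal, self-contained sketch of that forwarding idiom; the constant value and the Op/run names are placeholders, not the ttnn definitions.

```cpp
#include <cstdint>
#include <iostream>

constexpr uint8_t DefaultQueueId = 0;  // placeholder for the ttnn constant

struct Op {
    // Full form: caller picks the command queue explicitly.
    static int run(uint8_t queue_id, int grad, int input) {
        return queue_id + grad + input;  // stand-in for real work
    }
    // Convenience form: same arguments minus the queue id, forwarded to the full form.
    static int run(int grad, int input) {
        return run(DefaultQueueId, grad, input);
    }
};

int main() {
    std::cout << Op::run(1, 2, 3) << " " << Op::run(2, 3) << "\n";
    return 0;
}
```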
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
index 8f52005f367..9c0488e01ef 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
@@ -20,6 +20,115 @@ namespace binary_backward {
 
 namespace detail {
 
+//OpHandler_binary_bw : get_function_binary_bw_type1
+template
+void bind_binary_backward_type_1(py::module& module, const binary_backward_operation_t& operation, const std::string& description) {
+    auto doc = fmt::format(
+        R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector
+
+{2}
+
+Args:
+    * :attr:`grad_tensor`
+    * :attr:`input_tensor_a`
+    * :attr:`input_tensor_b`
+
+Keyword args:
+    * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+
+Example:
+
+    >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+    >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+    >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
+    >>> output = {1}(grad_tensor, tensor1, tensor2)
+)doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        description);
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor_a,
+               const ttnn::Tensor& input_tensor_b,
+               const std::optional& memory_config) -> std::vector {
+                auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
+                return self(grad_tensor, input_tensor_a, input_tensor_b, output_memory_config);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+template
+void bind_binary_backward_opt_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string& description) {
+    auto doc = fmt::format(
+        R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, {2}: float, *, memory_config: ttnn.MemoryConfig) -> std::vector
+
+        {5}
+
+        Args:
+            * :attr:`grad_tensor`
+            * :attr:`input_tensor_a`
+            * :attr:`input_tensor_b`
+            * :attr:`{3}` (float): Default value = {4}
+
+        Keyword args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+
+        Example:
+
+            >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
+            >>> output = {1}(grad_tensor, tensor1, tensor2, float)
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        parameter_name,
+        parameter_doc,
+        parameter_value,
+        description);
+
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor_a,
+               const ttnn::Tensor& input_tensor_b,
+               float parameter,
+               const std::optional& memory_config,
+               const std::vector& are_required_outputs,
+               const std::optional& input_a_grad,
+               const std::optional& input_b_grad,
+               const uint8_t& queue_id) -> std::vector> {
+                return self(queue_id, grad_tensor, input_tensor_a, input_tensor_b, parameter, memory_config, are_required_outputs, input_a_grad, input_b_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::arg(parameter_name.c_str()) = parameter_value,
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt,
+            py::arg("are_required_outputs") = std::vector{true, true},
+            py::arg("input_a_grad") = std::nullopt,
+            py::arg("input_b_grad") = std::nullopt,
+            py::arg("queue_id") = 0}
+    );
+}
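The binders above follow the usual pybind11 pattern of wrapping the C++ call in a lambda and declaring defaulted, keyword-only arguments. A stripped-down, self-contained sketch of that pattern; the module, function, and argument names here are illustrative, not part of ttnn.

```cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <optional>
#include <vector>

namespace py = pybind11;

// A free function standing in for the backward-op implementation.
std::vector<float> addalpha_bw_impl(float grad, float a, float b, float alpha) {
    return {grad, grad * alpha};  // d/da and d/db of a + alpha * b
}

PYBIND11_MODULE(example, m) {
    // Lambda adapter: positional values first, then keyword-only options,
    // mirroring how the binders above wire their lambdas.
    m.def(
        "addalpha_bw",
        [](float grad, float a, float b, float alpha,
           std::optional<int> queue_id) -> std::vector<float> {
            return addalpha_bw_impl(grad, a, b, alpha);
        },
        py::arg("grad"),
        py::arg("a"),
        py::arg("b"),
        py::arg("alpha") = 1.0f,
        py::kw_only(),
        py::arg("queue_id") = std::nullopt);
}
```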
 template
 void bind_binary_backward(py::module& module, const binary_backward_operation_t& operation, const std::string& description) {
     auto doc = fmt::format(
@@ -212,12 +321,12 @@ Keyword args:
 
 void py_module(py::module& module) {
-    detail::bind_binary_backward(
+    detail::bind_binary_backward_type_1(
         module,
         ttnn::atan2_bw,
         R"doc(Performs backward operations for atan2 of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
 
-    detail::bind_binary_backward(
+    detail::bind_binary_backward_type_1(
         module,
         ttnn::embedding_bw,
         R"doc(Performs backward operations for embedding_bw function and returns specific indices of the embedding table specified by the :attr:`grad_tensor`.
@@ -228,9 +337,10 @@ void py_module(py::module& module) {
         ttnn::subalpha_bw,
         R"doc(Performs backward operations for subalpha of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
 
-    detail::bind_binary_backward(
+    detail::bind_binary_backward_opt_float_default(
         module,
         ttnn::addalpha_bw,
+        "alpha", "Alpha value", 1.0f,
         R"doc(Performs backward operations for addalpha on :attr:`input_tensor_b`, :attr:`input_tensor_a` and :attr:`alpha` with given :attr:`grad_tensor`.)doc");
 
     detail::bind_binary_backward(
@@ -273,7 +383,7 @@ void py_module(py::module& module) {
         ttnn::concat_bw,
         R"doc(Performs backward operations for concat on :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
 
-    detail::bind_binary_backward(
+    detail::bind_binary_backward_type_1(
         module,
         ttnn::rsub_bw,
         R"doc(Performs backward operations for subtraction of :attr:`input_tensor_a` from :attr:`input_tensor_b` with given :attr:`grad_tensor` (reversed order of subtraction operator).)doc");
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
index 3cb3475d41f..797bec43d5f 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
@@ -23,8 +23,9 @@ namespace ttnn::operations::binary_backward {
 
 std::vector _atan2_bw(
-    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
+    const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config) {
     std::vector grad_tensor;
+    auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
     float t_nan = std::nanf("");
     using ttnn::operations::unary::UnaryWithParam;
     using ttnn::operations::unary::UnaryOpType;
@@ -32,13 +33,13 @@ std::vector _atan2_bw(
         UnaryWithParam {UnaryOpType::SQUARE},
         UnaryWithParam {UnaryOpType::RECIP}};
     Tensor recip_mul =
-        ttnn::multiply(grad, ttnn::unary_chain(hypot(input, other), ops_chain, output_mem_config), std::nullopt, output_mem_config);
-    Tensor grad_a = ttnn::multiply(other, recip_mul, std::nullopt, output_mem_config);
-    Tensor cond = ttnn::logical_and(ttnn::eqz(input, output_mem_config), ttnn::eqz(other, output_mem_config));
-    grad_a = where(cond, t_nan, grad_a, output_mem_config);
+        ttnn::multiply(grad, ttnn::unary_chain(hypot(input, other), ops_chain, output_memory_config), std::nullopt, output_memory_config);
+    Tensor grad_a = ttnn::multiply(other, recip_mul, std::nullopt, output_memory_config);
+    Tensor cond = ttnn::logical_and(ttnn::eqz(input, output_memory_config), ttnn::eqz(other, output_memory_config));
+    grad_a = where(cond, t_nan, grad_a, output_memory_config);
     grad_tensor.emplace_back(grad_a);
-    Tensor grad_b = ttnn::multiply(ttnn::neg(input), recip_mul, std::nullopt, output_mem_config);
-    grad_b = where(cond, t_nan, grad_b, output_mem_config);
+    Tensor grad_b = ttnn::multiply(ttnn::neg(input), recip_mul, std::nullopt, output_memory_config);
+    grad_b = where(cond, t_nan, grad_b, output_memory_config);
     recip_mul.deallocate();
     cond.deallocate();
     grad_tensor.emplace_back(grad_b);
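The _atan2_bw hunk above only re-plumbs the memory config; the gradient math is unchanged. As an editorial reference (not part of the patch), reading the code's argument order as z = atan2(input, other), the partial derivatives it implements are:

```latex
\frac{\partial z}{\partial\,\mathrm{input}} = \frac{\mathrm{other}}{\mathrm{input}^2+\mathrm{other}^2},
\qquad
\frac{\partial z}{\partial\,\mathrm{other}} = \frac{-\,\mathrm{input}}{\mathrm{input}^2+\mathrm{other}^2}
```

In the code, recip_mul is grad divided by hypot(input, other) squared, so grad_a = other * recip_mul and grad_b = -input * recip_mul, with both forced to NaN where input and other are simultaneously zero (the where(cond, t_nan, ...) calls).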
@@ -47,7 +48,7 @@ std::vector _embedding_bw(
-    const Tensor& grad, const Tensor& input, const Tensor& weight, const MemoryConfig& output_mem_config) {
+    const Tensor& grad, const Tensor& input, const Tensor& weight, const std::optional& output_mem_config) {
     TT_FATAL(input.get_dtype() == DataType::UINT32, "Input must be UINT32");
     TT_FATAL(
         grad.get_legacy_shape()[0] == 1 && grad.get_legacy_shape()[1] == 1,
@@ -68,7 +69,7 @@ std::vector> _addalpha_bw(
     const Tensor& input,
     const Tensor& other,
     float alpha,
-    const MemoryConfig& output_mem_config,
+    const std::optional& output_mem_config,
     const std::vector& are_required_outputs,
     std::optional input_grad,
     std::optional other_grad) {
@@ -99,38 +100,6 @@ std::vector> _addalpha_bw(
 }
 
-std::vector _addalpha_bw_inter(
-    const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
-
-    auto result = _addalpha_bw(0, grad, input, other, alpha, output_mem_config, {true, true}, std::nullopt, std::nullopt);
-
-    std::vector output_tensors;
-    output_tensors.reserve(result.size());
-
-    for (const auto& opt_tensor : result) {
-        if (opt_tensor) {
-            output_tensors.emplace_back(*opt_tensor);
-        } else {
-            output_tensors.emplace_back();
-        }
-    }
-    return output_tensors;
-}
-
-
-std::vector> _addalpha_bw_overload(
-    const Tensor& grad,
-    const Tensor& input,
-    const Tensor& other,
-    float alpha,
-    const MemoryConfig& output_mem_config,
-    const std::vector& are_required_outputs,
-    std::optional input_grad,
-    std::optional other_grad) {
-    uint8_t default_queue_id = 0;
-    return _addalpha_bw(default_queue_id, grad, input, other, alpha, output_mem_config, are_required_outputs, input_grad, other_grad);
-}
-
 std::vector _subalpha_bw(
     const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
     std::vector grad_tensor;
@@ -413,8 +382,8 @@ std::vector _binary_comp_bw(const Tensor& grad, const Tensor& input, con
     return grad_tensor;
 }
 
-std::vector _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
-    std::vector grad_tensor = _subalpha_bw(grad, input, other, 1.0f, output_mem_config);
+std::vector _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config) {
+    std::vector grad_tensor = ttnn::operations::binary_backward::_subalpha_bw(grad, input, other, 1.0f, output_mem_config.value_or(input.memory_config()));
     std::swap(grad_tensor[0], grad_tensor[1]);
     return grad_tensor;
 }
@@ -623,8 +592,6 @@ std::vector> _mul_bw_overload(
 
 std::function(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> BinaryBackwardFunction::get_function_type1(BinaryBackwardOpType OpType){
     switch (OpType) {
-        case BinaryBackwardOpType::ATAN2_BW:
-            return _atan2_bw;
         case BinaryBackwardOpType::EMBEDDING_BW:
             return _embedding_bw;
         case BinaryBackwardOpType::SUB_BW:
@@ -649,8 +616,6 @@ std::function(const Tensor&, const Tensor&, const Tens
             return _assign_bw;
         case BinaryBackwardOpType::LE_BW:
             return _le_bw;
-        case BinaryBackwardOpType::RSUB_BW:
-            return _rsub_bw;
         case BinaryBackwardOpType::GT_BW:
             return _gt_bw;
         case BinaryBackwardOpType::LT_BW:
@@ -675,8 +640,6 @@ std::function(const Tensor&, const Tensor&, const Tensor&, f
     switch (OpType) {
         case BinaryBackwardOpType::SUBALPHA_BW:
             return _subalpha_bw;
-        case BinaryBackwardOpType::ADDALPHA_BW:
-            return _addalpha_bw_inter;
         case BinaryBackwardOpType::CONCAT_BW:
             return _concat_bw;
         case BinaryBackwardOpType::LERP_BW:
             return _lerp_bw;
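An editorial note on the gradients behind the _addalpha_bw and _rsub_bw functions touched in this hunk, with g denoting the incoming gradient:

```latex
y = a + \alpha b \;\Rightarrow\; \frac{\partial y}{\partial a} = 1,\ \frac{\partial y}{\partial b} = \alpha
\quad\text{(gradients } g,\ \alpha g\text{)};
\qquad
y = a - \alpha b \;\Rightarrow\; \text{gradients } (g,\ -\alpha g).
```

rsub computes b - a, whose gradients (-g, g) are exactly the subalpha gradients with alpha = 1 in swapped order; that is why _rsub_bw can delegate to _subalpha_bw with 1.0f and then std::swap the two results.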
@@ -701,8 +664,6 @@ std::function(const Tensor&, const Tensor&, const Tens
 std::function>(uint8_t , const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector&, std::optional, std::optional)> BinaryBackwardFunction::get_function_type2(BinaryBackwardOpType OpType){
     switch (OpType) {
-        case BinaryBackwardOpType::ADDALPHA_BW:
-            return _addalpha_bw;
         default:
             TT_ASSERT(false && "Undefined op type");
             return 0;
@@ -711,8 +672,6 @@ std::function>(uint8_t , const Tensor&,
 std::function>(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector&, std::optional, std::optional)> BinaryBackwardFunction::get_function_type2_wo_qid(BinaryBackwardOpType OpType){
     switch (OpType) {
-        case BinaryBackwardOpType::ADDALPHA_BW:
-            return _addalpha_bw_overload;
         default:
             TT_ASSERT(false && "Undefined op type");
             return 0;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
index a219f33fd18..32efb0a04b2 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
@@ -43,7 +43,7 @@ enum class BinaryBackwardOpType {
     MUL_BW,
 };
 struct BinaryBackwardFunction{
-static std::function(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType);
+static std::function(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType); //get_function_binary_bw_type1
 static std::function(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(BinaryBackwardOpType OpType);
 static std::function(const Tensor&, const Tensor&, const Tensor&, std::string, const MemoryConfig&)> get_function_type1_w_string(BinaryBackwardOpType OpType);
 static std::function>(uint8_t , const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector&, std::optional, std::optional)> get_function_type2(BinaryBackwardOpType OpType);
@@ -51,4 +51,59 @@ static std::function>(const Tensor&, con
 static std::function>(uint8_t , const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector&, std::optional, std::optional)> get_function_type3(BinaryBackwardOpType OpType);
 static std::function>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector&, std::optional, std::optional)> get_function_type3_wo_qid(BinaryBackwardOpType OpType);
 };
+
+//OpHandler_binary_bw : get_function_binary_bw_type1
+std::vector _atan2_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config);
+std::vector _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config);
+std::vector _embedding_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config);
+
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+std::vector> _addalpha_bw( uint8_t queue_id, const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0f, const std::optional& output_mem_config = std::nullopt, const std::vector& are_required_outputs = std::vector{true, true}, std::optional input_grad = std::nullopt, std::optional other_grad = std::nullopt);
+
+// OpHandler struct template
+template
+struct OpHandler_binary_bw;
+
+template
+struct OpHandler_binary_bw_opt_float_default;
+
+template <>
+struct OpHandler_binary_bw {
+    static std::vector handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config ) {
+        return _atan2_bw(grad, input, other, output_mem_config);
+    }
+};
+
+template <>
+struct OpHandler_binary_bw {
+    static std::vector handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config ) {
+        return _rsub_bw(grad, input, other, output_mem_config);
+    }
+};
+
+template <>
+struct OpHandler_binary_bw_opt_float_default {
+    static std::vector> handle( uint8_t queue_id, const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config, const std::vector& are_required_outputs, std::optional input_grad, std::optional other_grad ) {
+        return _addalpha_bw( queue_id, grad, input, other, alpha, output_mem_config, are_required_outputs, input_grad, other_grad);
+    }
+};
+
+template <>
+struct OpHandler_binary_bw {
+    static std::vector handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config ) {
+        return _embedding_bw(grad, input, other, output_mem_config);
+    }
+};
+
+// Template functions to get the function pointers
+template
+auto get_function_binary_bw_type1() {
+    return &OpHandler_binary_bw::handle;
+}
+
+template
+auto get_function_binary_bw_type1_opt_float_default() {
+    return &OpHandler_binary_bw_opt_float_default::handle;
+}
+
 }  // namespace ttnn::operations::binary_backward
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index b8e33b82d84..7a5ec582a28 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -192,7 +192,7 @@ void bind_unary_backward_float_string_default(py::module& module, const unary_ba
 Args:
     * :attr:`grad_tensor`
     * :attr:`input_tensor_a` or :attr:`input_tensor`
-    * :attr:`input_tensor_b` or:attr:`{2}` (float): {3}
+    * :attr:`input_tensor_b` or :attr:`{2}` (float): {3}
 
 Keyword args:
     * :attr:`{4}` (string): {5} , Default value = {6}
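Closing editorial note: the vector-of-optional-Tensor return type and the are_required_outputs/input_grad parameters used throughout the patch implement an optional-output idiom, where callers can skip a gradient entirely or have it written into a preallocated tensor. A compact, self-contained sketch of that idiom with a stand-in Tensor type (not the ttnn implementation):

```cpp
#include <iostream>
#include <optional>
#include <vector>

struct Tensor { float value = 0.f; };  // stand-in for ttnn::Tensor

// Computes up to two gradients of y = a + alpha * b; each output is produced
// only if requested, and reuses a preallocated tensor when one is passed in.
std::vector<std::optional<Tensor>> addalpha_bw(
    const Tensor& grad, const Tensor& a, const Tensor& b, float alpha,
    const std::vector<bool>& are_required_outputs = {true, true},
    std::optional<Tensor> input_grad = std::nullopt,
    std::optional<Tensor> other_grad = std::nullopt) {
    std::vector<std::optional<Tensor>> result(2);
    if (are_required_outputs.at(0)) {
        if (!input_grad) input_grad = Tensor{};
        input_grad->value = grad.value;          // d/da (a + alpha*b) = 1
        result[0] = input_grad;
    }
    if (are_required_outputs.at(1)) {
        if (!other_grad) other_grad = Tensor{};
        other_grad->value = grad.value * alpha;  // d/db (a + alpha*b) = alpha
        result[1] = other_grad;
    }
    return result;
}

int main() {
    Tensor grad{1.f}, a{2.f}, b{3.f};
    auto grads = addalpha_bw(grad, a, b, /*alpha=*/2.f, {true, false});
    std::cout << grads[0]->value << " " << (grads[1].has_value() ? "set" : "skipped") << "\n";
    return 0;
}
```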