diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp index b5a78413068..c050d74aa16 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp @@ -13,7 +13,7 @@ namespace ttnn { namespace operations::binary_backward { -//OpHandler_binary_bw : get_function_binary_bw_type1 +//OpHandler_binary_bw : get_function_binary_bw template struct ExecuteBinaryBackwardType1 { @@ -30,13 +30,13 @@ struct ExecuteBinaryBackwardType1 { const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg, const std::optional &memory_config = std::nullopt) { - auto op_type = get_function_binary_bw_type1(); + auto op_type = get_function_binary_bw(); auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config); } }; -//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default +//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default template struct ExecuteBinaryBackwardOptionalFloatDefault { @@ -59,7 +59,7 @@ struct ExecuteBinaryBackwardOptionalFloatDefault { std::optional input_b_grad = std::nullopt) { auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); - auto op_type = get_function_binary_bw_type1_opt_float_default(); + auto op_type = get_function_binary_bw_opt_float_default(); return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad); } @@ -74,13 +74,13 @@ struct ExecuteBinaryBackwardOptionalFloatDefault { std::optional input_b_grad = std::nullopt) { auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); - auto op_type = get_function_binary_bw_type1_opt_float_default(); + auto op_type = get_function_binary_bw_opt_float_default(); return op_type(DefaultQueueId, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad); } }; -//OpHandler_binary_bw_float : get_function_binary_bw_type1_float +//OpHandler_binary_bw_float : get_function_binary_bw_float template struct ExecuteBinaryBackwardType1Float { @@ -98,7 +98,7 @@ struct ExecuteBinaryBackwardType1Float { const Tensor &input_tensor_b_arg, float parameter, const std::optional &memory_config = std::nullopt) { - auto op_type = get_function_binary_bw_type1(); + auto op_type = get_function_binary_bw_float(); auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config); } @@ -193,15 +193,15 @@ struct ExecuteBinaryBackward { } // operations::binary -//OpHandler_binary_bw : get_function_binary_bw_type1 +//OpHandler_binary_bw : get_function_binary_bw constexpr auto atan2_bw = ttnn::register_operation>("ttnn::atan2_bw"); constexpr auto rsub_bw = ttnn::register_operation>("ttnn::rsub_bw"); constexpr auto embedding_bw = ttnn::register_operation>("ttnn::embedding_bw"); -//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default +//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default constexpr auto addalpha_bw = ttnn::register_operation>("ttnn::addalpha_bw"); -//OpHandler_binary_bw_float : get_function_binary_bw_type1_float +//OpHandler_binary_bw_float : get_function_binary_bw_float constexpr auto subalpha_bw = ttnn::register_operation>("ttnn::subalpha_bw"); //type 1 diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp index 2565d867db8..8fbdca735b2 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp @@ -20,7 +20,7 @@ namespace binary_backward { namespace detail { -//OpHandler_binary_bw : get_function_binary_bw_type1 +//OpHandler_binary_bw : get_function_binary_bw template void bind_binary_backward_type_1(py::module& module, const binary_backward_operation_t& operation, const std::string& description) { auto doc = fmt::format( @@ -67,7 +67,7 @@ Keyword args: py::arg("memory_config") = std::nullopt}); } -//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default +//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default template void bind_binary_backward_opt_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string& description) { auto doc = fmt::format( @@ -129,7 +129,7 @@ void bind_binary_backward_opt_float_default(py::module& module, const binary_bac ); } -//OpHandler_binary_bw_float : get_function_binary_bw_type1_float +//OpHandler_binary_bw_float : get_function_binary_bw_float template void bind_binary_backward_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string& description) { auto doc = fmt::format( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp index b89594a6aa3..77a598c6bfd 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp @@ -43,22 +43,22 @@ enum class BinaryBackwardOpType { MUL_BW, }; struct BinaryBackwardFunction{ -static std::function(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType); //get_function_binary_bw_type1 +static std::function(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType); //get_function_binary_bw static std::function(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(BinaryBackwardOpType OpType); static std::function(const Tensor&, const Tensor&, const Tensor&, std::string, const MemoryConfig&)> get_function_type1_w_string(BinaryBackwardOpType OpType); static std::function>(uint8_t , const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector&, std::optional, std::optional)> get_function_type3(BinaryBackwardOpType OpType); static std::function>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector&, std::optional, std::optional)> get_function_type3_wo_qid(BinaryBackwardOpType OpType); }; -//OpHandler_binary_bw : get_function_binary_bw_type1 +//OpHandler_binary_bw : get_function_binary_bw std::vector _atan2_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config); std::vector _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config); std::vector _embedding_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config); -//OpHandler_binary_bw_float : get_function_binary_bw_type1_float +//OpHandler_binary_bw_float : get_function_binary_bw_float std::vector _subalpha_bw( const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0f, const std::optional& output_mem_config = std::nullopt); -//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default +//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default std::vector> _addalpha_bw( uint8_t queue_id, const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0f, const std::optional& output_mem_config = std::nullopt, const std::vector& are_required_outputs = std::vector{true, true}, std::optional input_grad = std::nullopt, std::optional other_grad = std::nullopt); // OpHandler struct template @@ -100,7 +100,7 @@ struct OpHandler_binary_bw { }; template <> -struct OpHandler_binary_bw { +struct OpHandler_binary_bw_float { static std::vector handle( const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const std::optional& output_mem_config ) { return _subalpha_bw(grad, input, other, alpha, output_mem_config); } @@ -108,17 +108,17 @@ struct OpHandler_binary_bw { // Template functions to get the function pointers template -auto get_function_binary_bw_type1() { +auto get_function_binary_bw() { return &OpHandler_binary_bw::handle; } template -auto get_function_binary_bw_type1_opt_float_default() { +auto get_function_binary_bw_opt_float_default() { return &OpHandler_binary_bw_opt_float_default::handle; } template -auto get_function_binary_bw_type1_float() { +auto get_function_binary_bw_float() { return &OpHandler_binary_bw_float::handle; }