From d8706ff89f23600f7dc0b1033d0da3f299bfd06e Mon Sep 17 00:00:00 2001
From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com>
Date: Mon, 30 Sep 2024 07:52:44 +0530
Subject: [PATCH] #13242: Cleanup set-5 unary backward ops (#13243)

---
 .../device/unary_backward_op.cpp              |  24 +--
 .../device/unary_backward_op.hpp              | 109 -----------
 .../eltwise/unary_backward/unary_backward.hpp | 177 +++++++++---------
 3 files changed, 103 insertions(+), 207 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index 53764850a3e..7eb1db95cae 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -298,13 +298,13 @@ std::vector<Tensor> _sub_bw(const Tensor& grad, const Tensor& input, float alpha
     return grad_tensor;
 }
 
-std::vector<Tensor> _frac_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardFrac::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     grad_tensor.emplace_back(grad);
     return grad_tensor;
 }
 
-std::vector<Tensor> _trunc_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardTrunc::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_result = ttnn::zeros_like(grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_result);
@@ -313,7 +313,7 @@
 
 // return: grad_output * (max_deriv - sign * (z / (1 + z)))
 // z = exp(-abs(input))
-std::vector<Tensor> _log_sigmoid_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardLogSigmoid::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor max_deriv = ttnn::where(ttnn::ltz(input, output_mem_config), 1, 0, output_mem_config);
     Tensor in_sign = ttnn::where(ttnn::ltz(input, output_mem_config), 1, -1, output_mem_config);
@@ -330,14 +330,14 @@
     return grad_tensor;
 }
 
-std::vector<Tensor> _fill_zero_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardFillZero::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor result = ttnn::zeros_like(grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config);
     grad_tensor.emplace_back(result);
     return grad_tensor;
 }
 
-std::vector<Tensor> _i0_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardI0::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     float t_inf = std::numeric_limits<float>::infinity();
     Tensor value = ttnn::multiply(
@@ -367,7 +367,7 @@
     return grad_tensor;
 }
 
-std::vector<Tensor> _tan_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardTan::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor tan_result = ttnn::tan(input, output_mem_config);
     Tensor result =
@@ -506,7 +506,7 @@ std::vector<Tensor> ExecuteUnaryBackwardAcosh::invoke(const Tensor& grad, const
 
 // # - name: acos(Tensor self) -> Tensor
 // # self: grad * -((-self * self + 1).rsqrt())
-std::vector<Tensor> _acos_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardAcos::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor neg_in = ttnn::neg(input, output_mem_config);
     Tensor in_rsqrt =
@@ -534,7 +534,7 @@
     return grad_tensor;
 }
 
-std::vector<Tensor> _atan_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardAtan::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     using ttnn::operations::unary::UnaryWithParam;
     using ttnn::operations::unary::UnaryOpType;
@@ -547,7 +547,7 @@
     return grad_tensor;
 }
 
-std::vector<Tensor> _rad2deg_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardRad2deg::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     float M_180_PI = 180 / M_PI;
     Tensor grad_result = ttnn::multiply(grad, M_180_PI, std::nullopt, output_mem_config);
@@ -584,7 +584,7 @@ std::vector<Tensor> ExecuteUnaryBackwardLogit::invoke(const Tensor& grad, const
 }
 // square
 // result: 2 * input * grad_data
-std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardSquare::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_result = ttnn::multiply(ttnn::multiply(grad, 2.0f, std::nullopt, output_mem_config), input, std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_result);
@@ -715,7 +715,7 @@ std::vector<Tensor> ExecuteUnaryBackwardLog::invoke(const Tensor& grad, const Te
     return grad_tensor;
 }
 
-std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardRelu6::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor zero_tensor = ttnn::zeros_like(input);
     Tensor one_tensor = ttnn::ones_like(input);
@@ -772,7 +772,7 @@ std::vector<std::optional<Tensor>> ExecuteUnaryBackwardSilu::invoke(const Tensor
 
 // Selu
 // result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input))
-std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardSelu::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config);
     Tensor grad_result = where(
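
Note: the hunks above only move free functions such as _tan_bw and _frac_bw into static invoke members on per-op structs; the gradient math itself is unchanged. A minimal standalone sketch of the same pattern, using plain double in place of ttnn::Tensor (all names here are hypothetical and not part of the ttnn API):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Hypothetical scalar stand-in for the per-op struct pattern used above.
    struct SketchTanBackward {
        // d/dx tan(x) = 1 + tan(x)^2, so grad_in = grad_out * (1 + tan(x)^2).
        static std::vector<double> invoke(double grad, double input) {
            const double t = std::tan(input);
            return {grad * (1.0 + t * t)};
        }
    };

    // frac(x) = x - trunc(x) has slope 1 almost everywhere, so the incoming
    // gradient passes through untouched -- which is why the frac backward
    // above simply emplaces `grad` itself.
    struct SketchFracBackward {
        static std::vector<double> invoke(double grad, double /*input*/) { return {grad}; }
    };

    int main() {
        std::printf("tan_bw(1.0, 0.5)  = %f\n", SketchTanBackward::invoke(1.0, 0.5)[0]);
        std::printf("frac_bw(1.0, 2.7) = %f\n", SketchFracBackward::invoke(1.0, 2.7)[0]);
    }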
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index 403252fa325..430e5bdb28b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -17,34 +17,9 @@ enum class UnaryBackwardOpType {
     ADD_BW,
     EQ_BW,
     GT_BW,
-    ACOS_BW,
-    ATAN_BW,
-    RAD2DEG_BW,
     SUB_BW,
-    FRAC_BW,
-    TRUNC_BW,
-    LOG_SIGMOID_BW,
-    FILL_ZERO_BW,
-    I0_BW,
-    TAN_BW,
-    RELU6_BW,
-    SELU_BW,
-    SQUARE_BW,
 };
 
-std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _atan_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _rad2deg_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _frac_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _trunc_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _log_sigmoid_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _fill_zero_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _i0_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _tan_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-
 std::vector<Tensor> _sub_bw( const Tensor& grad, const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _gt_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config);
 
@@ -57,90 +32,6 @@ Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_
 template <UnaryBackwardOpType OpType>
 struct OpHandler;
 
-template <>
-struct OpHandler<UnaryBackwardOpType::ACOS_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _acos_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::ATAN_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _atan_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::RAD2DEG_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _rad2deg_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::FRAC_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _frac_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::TAN_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _tan_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::TRUNC_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _trunc_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::LOG_SIGMOID_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _log_sigmoid_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::FILL_ZERO_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _fill_zero_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::I0_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _i0_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::RELU6_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _relu6_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::SELU_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _selu_bw(grad, input, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::SQUARE_BW> {
-    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-        return _square_bw(grad, input, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler<UnaryBackwardOpType::GT_BW> {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
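
Note: the deletions above retire the enum-plus-OpHandler dispatch for these twelve ops. For reference, a reduced sketch of the mechanism being removed (types simplified, double standing in for Tensor; the Sketch* names are hypothetical):

    #include <vector>

    // Reduced sketch of the retired dispatch: one enum value plus one
    // OpHandler specialization per op, resolved at compile time.
    enum class SketchOpType { ATAN_BW };

    template <SketchOpType OpType>
    struct SketchOpHandler;  // primary template intentionally left undefined

    template <>
    struct SketchOpHandler<SketchOpType::ATAN_BW> {
        // d/dx atan(x) = 1 / (1 + x^2)
        static std::vector<double> handle(double grad, double input) {
            return {grad / (1.0 + input * input)};
        }
    };

    int main() {
        // The old ExecuteUnaryBackwardOp<OpType>::invoke forwarded here; the
        // cleanup inlines each op into its own struct, so adding an op no
        // longer touches the shared enum and handler header.
        auto g = SketchOpHandler<SketchOpType::ATAN_BW>::handle(1.0, 1.0);
        return g[0] == 0.5 ? 0 : 1;  // 1/(1+1^2) = 0.5
    }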
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index e8482ada8f1..e9c60f2880d 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -149,14 +149,6 @@ struct ExecuteUnaryBackwardHardtanh {
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
-struct ExecuteUnaryBackward {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        float parameter_a,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt);
-};
-
 struct ExecuteUnaryBackwardHardshrink {
     static std::vector<Tensor> invoke(
         const Tensor &grad_tensor_arg,
@@ -205,15 +197,88 @@ struct ExecuteUnaryBackwardLogiteps {
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardOp {
+struct ExecuteUnaryBackwardTan {
     static std::vector<Tensor> invoke(
         const Tensor &grad_tensor_arg,
         const Tensor &input_tensor_arg,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, output_memory_config);
-    }
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardSquare {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardSelu {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardRelu6 {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardI0 {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardFillZero {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardLogSigmoid {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardTrunc {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardFrac {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardRad2deg {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardAtan {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
+struct ExecuteUnaryBackwardAcos {
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
 struct ExecuteUnaryBackwardErfc {
@@ -403,18 +468,6 @@ struct ExecuteUnaryBackwardRdiv {
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardStringDefault {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        string parameter_a,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, output_memory_config);
-    }
-};
-
 struct ExecuteUnaryBackwardRepeat {
     static std::vector<Tensor> invoke(
         const Tensor &grad_tensor_arg,
@@ -571,66 +624,18 @@ struct ExecuteUnaryBackwardGelu{
 } // operations::unary
 
-constexpr auto acos_bw = ttnn::register_operation<
-    "ttnn::acos_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::ACOS_BW>>();
-
-constexpr auto atan_bw = ttnn::register_operation<
-    "ttnn::atan_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::ATAN_BW>>();
-
-constexpr auto rad2deg_bw = ttnn::register_operation<
-    "ttnn::rad2deg_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::RAD2DEG_BW>>();
-
-constexpr auto frac_bw = ttnn::register_operation<
-    "ttnn::frac_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::FRAC_BW>>();
-
-constexpr auto trunc_bw = ttnn::register_operation<
-    "ttnn::trunc_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::TRUNC_BW>>();
-
-constexpr auto log_sigmoid_bw = ttnn::register_operation<
-    "ttnn::log_sigmoid_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::LOG_SIGMOID_BW>>();
-
-constexpr auto fill_zero_bw = ttnn::register_operation<
-    "ttnn::fill_zero_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::FILL_ZERO_BW>>();
-
-constexpr auto i0_bw = ttnn::register_operation<
-    "ttnn::i0_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::I0_BW>>();
-
-constexpr auto relu6_bw = ttnn::register_operation<
-    "ttnn::relu6_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::RELU6_BW>>();
-
-constexpr auto selu_bw = ttnn::register_operation<
-    "ttnn::selu_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::SELU_BW>>();
-
-constexpr auto square_bw = ttnn::register_operation<
-    "ttnn::square_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::SQUARE_BW>>();
-
-constexpr auto tan_bw = ttnn::register_operation<
-    "ttnn::tan_bw",
-    operations::unary_backward::ExecuteUnaryBackwardOp<
-        operations::unary_backward::UnaryBackwardOpType::TAN_BW>>();
-
+constexpr auto acos_bw = ttnn::register_operation<"ttnn::acos_bw", operations::unary_backward::ExecuteUnaryBackwardAcos>();
+constexpr auto atan_bw = ttnn::register_operation<"ttnn::atan_bw", operations::unary_backward::ExecuteUnaryBackwardAtan>();
+constexpr auto rad2deg_bw = ttnn::register_operation<"ttnn::rad2deg_bw", operations::unary_backward::ExecuteUnaryBackwardRad2deg>();
+constexpr auto frac_bw = ttnn::register_operation<"ttnn::frac_bw", operations::unary_backward::ExecuteUnaryBackwardFrac>();
+constexpr auto trunc_bw = ttnn::register_operation<"ttnn::trunc_bw", operations::unary_backward::ExecuteUnaryBackwardTrunc>();
+constexpr auto log_sigmoid_bw = ttnn::register_operation<"ttnn::log_sigmoid_bw", operations::unary_backward::ExecuteUnaryBackwardLogSigmoid>();
+constexpr auto fill_zero_bw = ttnn::register_operation<"ttnn::fill_zero_bw", operations::unary_backward::ExecuteUnaryBackwardFillZero>();
+constexpr auto i0_bw = ttnn::register_operation<"ttnn::i0_bw", operations::unary_backward::ExecuteUnaryBackwardI0>();
+constexpr auto relu6_bw = ttnn::register_operation<"ttnn::relu6_bw", operations::unary_backward::ExecuteUnaryBackwardRelu6>();
+constexpr auto selu_bw = ttnn::register_operation<"ttnn::selu_bw", operations::unary_backward::ExecuteUnaryBackwardSelu>();
+constexpr auto square_bw = ttnn::register_operation<"ttnn::square_bw", operations::unary_backward::ExecuteUnaryBackwardSquare>();
+constexpr auto tan_bw = ttnn::register_operation<"ttnn::tan_bw", operations::unary_backward::ExecuteUnaryBackwardTan>();
 constexpr auto sigmoid_bw = ttnn::register_operation<"ttnn::sigmoid_bw", operations::unary_backward::ExecuteUnaryBackwardSigmoid>();
 constexpr auto ceil_bw = ttnn::register_operation<"ttnn::ceil_bw", operations::unary_backward::ExecuteUnaryBackwardCeil>();
 constexpr auto softsign_bw = ttnn::register_operation<"ttnn::softsign_bw", operations::unary_backward::ExecuteUnaryBackwardSoftsign>();
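
Note: call sites are unaffected by this cleanup; each op is still registered under the same ttnn::*_bw name. A call-site sketch, assuming the ttnn build environment and include path from the surrounding tree (not compiled here):

    // Sketch only: mirrors the invoke() signatures declared above.
    #include <vector>
    #include "ttnn/operations/eltwise/unary_backward/unary_backward.hpp"

    std::vector<ttnn::Tensor> tan_backward(const ttnn::Tensor& grad, const ttnn::Tensor& input) {
        // memory_config is optional and defaults to std::nullopt, as in invoke().
        return ttnn::tan_bw(grad, input);
    }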