#13242: Cleanup set-5 unary backward ops (#13243)
VirdhatchaniKN authored Sep 30, 2024
1 parent e3dd4b7 commit d8706ff
Showing 3 changed files with 103 additions and 207 deletions.
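The cleanup follows one mechanical pattern, sketched below as self-contained C++. `Tensor` and `MemoryConfig` are stub stand-ins for the real ttnn types, so this is a minimal sketch of the shape of the change, not the ttnn implementation: each `_<op>_bw` free function, its `UnaryBackwardOpType` enum entry, and its `OpHandler` template specialization collapse into one `ExecuteUnaryBackward<Op>` struct exposing a static `invoke` with the same signature.

```cpp
#include <optional>
#include <vector>

// Stub stand-ins for ttnn types; illustration only.
struct Tensor {};
struct MemoryConfig {};

// Before: free function + enum entry + OpHandler specialization per op.
enum class UnaryBackwardOpType { FRAC_BW /* , ... */ };

std::vector<Tensor> _frac_bw(const Tensor& grad, const Tensor& /*input*/,
                             const std::optional<MemoryConfig>& /*output_mem_config*/) {
    return {grad};  // frac(x) has slope 1 almost everywhere, so grad passes through
}

template <UnaryBackwardOpType OpType>
struct OpHandler;

template <>
struct OpHandler<UnaryBackwardOpType::FRAC_BW> {
    static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input,
                                      const std::optional<MemoryConfig>& output_mem_config) {
        return _frac_bw(grad, input, output_mem_config);
    }
};

// After: one struct per op with a static invoke(); the enum entry, the free
// function, and the OpHandler specialization are all deleted.
struct ExecuteUnaryBackwardFrac {
    static std::vector<Tensor> invoke(const Tensor& grad, const Tensor& /*input*/,
                                      const std::optional<MemoryConfig>& /*output_mem_config*/) {
        return {grad};
    }
};

int main() {
    Tensor grad, input;
    auto before = OpHandler<UnaryBackwardOpType::FRAC_BW>::handle(grad, input, std::nullopt);
    auto after = ExecuteUnaryBackwardFrac::invoke(grad, input, std::nullopt);
    (void)before; (void)after;
}
```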
@@ -298,13 +298,13 @@ std::vector<Tensor> _sub_bw(const Tensor& grad, const Tensor& input, float alpha
return grad_tensor;
}

-std::vector<Tensor> _frac_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardFrac::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
grad_tensor.emplace_back(grad);
return grad_tensor;
}

-std::vector<Tensor> _trunc_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardTrunc::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = ttnn::zeros_like(grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config);
grad_tensor.emplace_back(grad_result);
@@ -313,7 +313,7 @@ std::vector<Tensor> _trunc_bw(const Tensor& grad, const Tensor& input, const std

// return: grad_output * (max_deriv - sign * (z / (1 + z)))
// z = exp(-abs(input))
-std::vector<Tensor> _log_sigmoid_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardLogSigmoid::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor max_deriv = ttnn::where(ttnn::ltz(input, output_mem_config), 1, 0, output_mem_config);
Tensor in_sign = ttnn::where(ttnn::ltz(input, output_mem_config), 1, -1, output_mem_config);
@@ -330,14 +330,14 @@ std::vector<Tensor> _log_sigmoid_bw(const Tensor& grad, const Tensor& input, con
return grad_tensor;
}
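The two comment lines above encode d/dx log(sigmoid(x)) = sigmoid(-x), rewritten in a numerically stable form using z = exp(-|x|). A scalar sanity check of that identity in plain C++ (a standalone sketch with a hypothetical helper, independent of ttnn):

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar mirror of the tensor formula above:
// max_deriv - sign * z / (1 + z), with z = exp(-|x|).
double log_sigmoid_grad(double x) {
    double z = std::exp(-std::fabs(x));
    double max_deriv = (x < 0.0) ? 1.0 : 0.0;
    double sign = (x < 0.0) ? 1.0 : -1.0;
    return max_deriv - sign * z / (1.0 + z);
}

int main() {
    for (double x : {-5.0, -0.5, 0.0, 0.5, 5.0}) {
        double expected = 1.0 / (1.0 + std::exp(x));  // sigmoid(-x)
        assert(std::fabs(log_sigmoid_grad(x) - expected) < 1e-12);
    }
}
```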

-std::vector<Tensor> _fill_zero_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardFillZero::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::zeros_like(grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}

-std::vector<Tensor> _i0_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardI0::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
float t_inf = std::numeric_limits<float>::infinity();
Tensor value = ttnn::multiply(
@@ -367,7 +367,7 @@ std::vector<Tensor> _i0_bw(const Tensor& grad, const Tensor& input, const std::o
return grad_tensor;
}

-std::vector<Tensor> _tan_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardTan::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor tan_result = ttnn::tan(input, output_mem_config);
Tensor result =
@@ -506,7 +506,7 @@ std::vector<Tensor> ExecuteUnaryBackwardAcosh::invoke(const Tensor& grad, const

// # - name: acos(Tensor self) -> Tensor
// # self: grad * -((-self * self + 1).rsqrt())
-std::vector<Tensor> _acos_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardAcos::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor neg_in = ttnn::neg(input, output_mem_config);
Tensor in_rsqrt =
@@ -534,7 +534,7 @@ std::vector<Tensor> _acos_bw(const Tensor& grad, const Tensor& input, const std:
return grad_tensor;
}
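The quoted autograd rule for acos is grad * -((-self * self + 1).rsqrt()), i.e. grad * -1/sqrt(1 - x^2). A standalone finite-difference check of that identity (hypothetical helper, not ttnn code):

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar form of the rule above: grad * -(1 - x^2)^(-1/2).
double acos_grad(double grad, double x) {
    return grad * -(1.0 / std::sqrt(-x * x + 1.0));
}

int main() {
    const double h = 1e-6;
    for (double x : {-0.9, -0.25, 0.0, 0.25, 0.9}) {
        double fd = (std::acos(x + h) - std::acos(x - h)) / (2.0 * h);
        assert(std::fabs(acos_grad(1.0, x) - fd) < 1e-5);
    }
}
```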

-std::vector<Tensor> _atan_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardAtan::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
using ttnn::operations::unary::UnaryWithParam;
using ttnn::operations::unary::UnaryOpType;
@@ -547,7 +547,7 @@ std::vector<Tensor> _atan_bw(const Tensor& grad, const Tensor& input, const std:
return grad_tensor;
}
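The hunk elides the fused UnaryWithParam chain, but the identity an atan backward has to satisfy is d/dx atan(x) = 1 / (1 + x^2). A standalone finite-difference check of that rule (the mathematical identity only, not the ttnn kernel):

```cpp
#include <cassert>
#include <cmath>

int main() {
    const double h = 1e-6;
    for (double x : {-3.0, -1.0, 0.0, 1.0, 3.0}) {
        double fd = (std::atan(x + h) - std::atan(x - h)) / (2.0 * h);
        assert(std::fabs(1.0 / (1.0 + x * x) - fd) < 1e-8);  // backward scales this by grad_output
    }
}
```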

-std::vector<Tensor> _rad2deg_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardRad2deg::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
float M_180_PI = 180 / M_PI;
Tensor grad_result = ttnn::multiply(grad, M_180_PI, std::nullopt, output_mem_config);
@@ -584,7 +584,7 @@ std::vector<Tensor> ExecuteUnaryBackwardLogit::invoke(const Tensor& grad, const
}
// square
// result: 2 * input * grad_data
-std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardSquare::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = ttnn::multiply(ttnn::multiply(grad, 2.0f, std::nullopt, output_mem_config), input, std::nullopt, output_mem_config);
grad_tensor.emplace_back(grad_result);
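Per the comment above, the square backward is 2 * input * grad. A standalone scalar check (plain C++, independent of ttnn):

```cpp
#include <cassert>
#include <cmath>

int main() {
    const double h = 1e-6;
    for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
        // Central difference of x^2 equals 2x exactly (up to rounding).
        double fd = ((x + h) * (x + h) - (x - h) * (x - h)) / (2.0 * h);
        assert(std::fabs(2.0 * x - fd) < 1e-7);
    }
}
```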
@@ -715,7 +715,7 @@ std::vector<Tensor> ExecuteUnaryBackwardLog::invoke(const Tensor& grad, const Te
return grad_tensor;
}

-std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardRelu6::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = ttnn::zeros_like(input);
Tensor one_tensor = ttnn::ones_like(input);
@@ -772,7 +772,7 @@ std::vector<std::optional<Tensor>> ExecuteUnaryBackwardSilu::invoke(const Tensor

// Selu
// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input))
-std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
+std::vector<Tensor> ExecuteUnaryBackwardSelu::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config);
Tensor grad_result = where(
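The selu rule quoted above is torch.where(input > 0, grad * lambd, grad * lambd * alpha * exp(input)); the hunk shows lambd = 1.0507 but cuts off before alpha. A standalone check assuming the usual SELU alpha ~ 1.67326 (an assumption, since the constant is not visible in this hunk):

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar SELU; lambd from the hunk, alpha assumed standard.
double selu(double x, double lambd, double alpha) {
    return x > 0.0 ? lambd * x : lambd * alpha * (std::exp(x) - 1.0);
}

int main() {
    const double lambd = 1.0507, alpha = 1.67326, h = 1e-6;
    for (double x : {-2.0, -0.5, 0.5, 2.0}) {
        double fd = (selu(x + h, lambd, alpha) - selu(x - h, lambd, alpha)) / (2.0 * h);
        double grad = x > 0.0 ? lambd : lambd * alpha * std::exp(x);  // the where() branches above
        assert(std::fabs(grad - fd) < 1e-6);
    }
}
```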
@@ -17,34 +17,9 @@ enum class UnaryBackwardOpType {
ADD_BW,
EQ_BW,
GT_BW,
-ACOS_BW,
-ATAN_BW,
-RAD2DEG_BW,
-SUB_BW,
-FRAC_BW,
-TRUNC_BW,
-LOG_SIGMOID_BW,
-FILL_ZERO_BW,
-I0_BW,
-TAN_BW,
-RELU6_BW,
-SELU_BW,
-SQUARE_BW,
};

-std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _atan_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _rad2deg_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _frac_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _trunc_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _log_sigmoid_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _fill_zero_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _i0_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _tan_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);

std::vector<Tensor> _sub_bw( const Tensor& grad, const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _gt_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config);

@@ -57,90 +32,6 @@ Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_
template <UnaryBackwardOpType OpType>
struct OpHandler;

-template <>
-struct OpHandler<UnaryBackwardOpType::ACOS_BW> {
-static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config ) {
-return _acos_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::ATAN_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _atan_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::RAD2DEG_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _rad2deg_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::FRAC_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _frac_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::TAN_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _tan_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::TRUNC_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _trunc_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::LOG_SIGMOID_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _log_sigmoid_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::FILL_ZERO_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _fill_zero_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::I0_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _i0_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::RELU6_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _relu6_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::SELU_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _selu_bw(grad, input, output_mem_config);
-}
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::SQUARE_BW> {
-static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
-return _square_bw(grad, input, output_mem_config);
-}
-};
-
template <>
struct OpHandler<UnaryBackwardOpType::GT_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {