#9874: Remove clamp_max_bw
VirdhatchaniKN committed Jul 18, 2024
1 parent eb29903 commit 291f12d
Showing 8 changed files with 6 additions and 63 deletions.
1 change: 0 additions & 1 deletion docs/source/ttnn/ttnn/api.rst
@@ -169,7 +169,6 @@ Pointwise Unary
ttnn/threshold
ttnn/mul_bw
ttnn/clamp_min_bw
ttnn/clamp_max_bw
ttnn/clamp_bw
ttnn/hardtanh_bw
ttnn/threshold_bw
6 changes: 0 additions & 6 deletions docs/source/ttnn/ttnn/ttnn/clamp_max_bw.rst

This file was deleted.

This file was deleted.

@@ -38,24 +38,20 @@ std::vector<Tensor> _clamp_min_bw(
return grad_tensor;
}

std::vector<Tensor> _clamp_max_bw(
const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor maxT = ttnn::le(input, max, std::nullopt, output_mem_config);
Tensor result = ttnn::multiply(grad, maxT, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}


std::vector<Tensor> _clamp_bw(
const Tensor& grad, const Tensor& input, std::optional<float> min, std::optional<float> max, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
TT_FATAL((max.has_value() || min.has_value()) && "Only one of 'min' or 'max' can be None. Please provided atleast one value");
TT_FATAL((max.has_value() || min.has_value()) && "Only one of 'min' or 'max' can be None. Please provide atleast one value");
if (!max.has_value()) {
return _clamp_min_bw( grad, input, min.value(), output_memory_config);
}else if(!min.has_value()) {
return _clamp_max_bw( grad, input, max.value(), output_memory_config);
Tensor maxT = ttnn::le(input, max.value(), std::nullopt, output_mem_config);
Tensor result = ttnn::multiply(grad, maxT, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
Tensor minT = ttnn::ge(input, min.value(), std::nullopt, output_memory_config);
Tensor maxT = ttnn::le(input, max.value(), std::nullopt, output_memory_config);
@@ -1732,8 +1728,6 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, con
return _mul_bw;
case UnaryBackwardOpType::CLAMP_MIN_BW:
return _clamp_min_bw;
case UnaryBackwardOpType::CLAMP_MAX_BW:
return _clamp_max_bw;
case UnaryBackwardOpType::ADD_BW:
return _add_bw;
case UnaryBackwardOpType::EQ_BW:
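For reference, the consolidated backward path that replaces the deleted _clamp_max_bw helper can be read out of the first hunk of this file. A sketch of the resulting _clamp_bw, with namespaces, includes, and the unchanged tail of the both-bounds branch omitted:

// Sketch assembled from the hunk above; not a verbatim copy of the full function.
std::vector<Tensor> _clamp_bw(
    const Tensor& grad, const Tensor& input,
    std::optional<float> min, std::optional<float> max,
    const std::optional<MemoryConfig>& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    auto output_memory_config = output_mem_config.value_or(input.memory_config());
    TT_FATAL((max.has_value() || min.has_value()) &&
             "Only one of 'min' or 'max' can be None. Please provide atleast one value");
    if (!max.has_value()) {
        // min-only clamp still delegates to the dedicated helper
        return _clamp_min_bw(grad, input, min.value(), output_memory_config);
    } else if (!min.has_value()) {
        // max-only clamp: the gradient passes through wherever input <= max,
        // which is exactly what the removed _clamp_max_bw computed
        Tensor maxT = ttnn::le(input, max.value(), std::nullopt, output_mem_config);
        Tensor result = ttnn::multiply(grad, maxT, std::nullopt, output_mem_config);
        grad_tensor.emplace_back(result);
        return grad_tensor;
    }
    // both bounds given: gradient passes through where min <= input <= max
    Tensor minT = ttnn::ge(input, min.value(), std::nullopt, output_memory_config);
    Tensor maxT = ttnn::le(input, max.value(), std::nullopt, output_memory_config);
    // ... remainder unchanged and not shown in this hunk
}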
@@ -55,7 +55,6 @@ enum class UnaryBackwardOpType {
NEG_BW,
RELU_BW,
LOGIT_BW,
CLAMP_MAX_BW,
HARDSHRINK_BW,
SOFTSHRINK_BW,
LEAKY_RELU_BW,
@@ -359,7 +359,6 @@ constexpr auto rsqrt_bw = ttnn::register_operation<operations::unary_backward::E
constexpr auto neg_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::NEG_BW>>("ttnn::neg_bw");
constexpr auto relu_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RELU_BW>>("ttnn::relu_bw");
constexpr auto logit_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::LOGIT_BW>>("ttnn::logit_bw");
constexpr auto clamp_max_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MAX_BW>>("ttnn::clamp_max_bw");
constexpr auto hardshrink_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::HARDSHRINK_BW>>("ttnn::hardshrink_bw");
constexpr auto softshrink_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::SOFTSHRINK_BW>>("ttnn::softshrink_bw");
constexpr auto leaky_relu_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::LEAKY_RELU_BW>>("ttnn::leaky_relu_bw");
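With the registered ttnn::clamp_max_bw op gone, C++ call sites have to express a max-only clamp through ttnn::clamp_bw. The public call signature is not shown in this diff, so the sketch below assumes it mirrors the internal _clamp_bw helper's parameter order (grad, input, min, max, memory config); check the actual registration before relying on it.

// Hypothetical migration sketch, not taken from this commit.
// Assumes ttnn::clamp_bw takes (grad, input, optional min, optional max, optional memory config).

// Before (removed in this commit):
// auto grads = ttnn::clamp_max_bw(grad_tensor, input_tensor, 4.0f, output_mem_config);

// After: leave min empty to get the max-only backward path shown in the .cpp hunk above.
auto grads = ttnn::clamp_bw(
    grad_tensor, input_tensor,
    /*min=*/std::nullopt, /*max=*/4.0f,
    output_mem_config);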
@@ -901,11 +901,6 @@ void py_module(py::module& module) {
ttnn::floor_bw,
R"doc(Performs backward operations for floor on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

detail::bind_unary_backward(
module,
ttnn::clamp_max_bw,
R"doc(Performs backward operations for clamp max value on :attr:`input_tensor`, :attr:`max` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward(
module,
ttnn::hardshrink_bw,
2 changes: 0 additions & 2 deletions ttnn/ttnn/operations/unary_backward.py
@@ -49,7 +49,6 @@ def unary_bw_prod(torch_op, x, grad_data, *args, **kwargs):
name_to_golden_function = {
"mul_bw": lambda x, grad_data: unary_bw_with_float(torch.mul, x, grad_data),
"clamp_min_bw": lambda x, grad_data: unary_bw_with_float(torch.clamp_min, x, grad_data),
"clamp_max_bw": lambda x, grad_data: unary_bw_with_float(torch.clamp_max, x, grad_data),
"add_bw": lambda x, grad_data: unary_bw_with_float(torch.add, x, grad_data),
"eq_bw": lambda x, grad_data: unary_bw_with_float(torch.eq, x, grad_data),
"gt_bw": lambda x, grad_data: unary_bw_with_float(torch.gt, x, grad_data),
@@ -138,7 +137,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
TTNN_ELTWISE_UNARY_BACKWARD_CPP_FUNCTIONS = [
ttnn.mul_bw,
ttnn.clamp_min_bw,
ttnn.clamp_max_bw,
ttnn.add_bw,
ttnn.eq_bw,
ttnn.gt_bw,
