Skip to content

Commit

Permalink
#9874: Remove clamp_min_bw
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jul 18, 2024
1 parent 291f12d commit bafed35
Show file tree
Hide file tree
Showing 8 changed files with 4 additions and 65 deletions.
1 change: 0 additions & 1 deletion docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ Pointwise Unary
ttnn/tanhshrink
ttnn/threshold
ttnn/mul_bw
ttnn/clamp_min_bw
ttnn/clamp_bw
ttnn/hardtanh_bw
ttnn/threshold_bw
Expand Down
6 changes: 0 additions & 6 deletions docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,16 @@ std::vector<ttnn::Tensor> _mul_bw(
return grad_tensor;
}

std::vector<Tensor> _clamp_min_bw(
const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor minT = ttnn::ge(input, min, std::nullopt, output_mem_config);
Tensor result = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}



std::vector<Tensor> _clamp_bw(
const Tensor& grad, const Tensor& input, std::optional<float> min, std::optional<float> max, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
TT_FATAL((max.has_value() || min.has_value()) && "Only one of 'min' or 'max' can be None. Please provide atleast one value");
if (!max.has_value()) {
return _clamp_min_bw( grad, input, min.value(), output_memory_config);
Tensor minT = ttnn::ge(input, min.value(), std::nullopt, output_mem_config);
Tensor result = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}else if(!min.has_value()) {
Tensor maxT = ttnn::le(input, max.value(), std::nullopt, output_mem_config);
Tensor result = ttnn::multiply(grad, maxT, std::nullopt, output_mem_config);
Expand Down Expand Up @@ -1726,8 +1718,6 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, con
switch (OpType) {
case UnaryBackwardOpType::MUL_BW:
return _mul_bw;
case UnaryBackwardOpType::CLAMP_MIN_BW:
return _clamp_min_bw;
case UnaryBackwardOpType::ADD_BW:
return _add_bw;
case UnaryBackwardOpType::EQ_BW:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace ttnn::operations::unary_backward {
constexpr uint8_t DefaultQueueId = 0;
enum class UnaryBackwardOpType {
MUL_BW,
CLAMP_MIN_BW,
CLAMP_BW,
HARDTANH_BW,
THRESHOLD_BW,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ constexpr auto sqrt_bw = ttnn::register_operation<operations::unary_backward::Ex
constexpr auto prod_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackwardProdBW<operations::unary_backward::UnaryBackwardOpType::PROD_BW>>("ttnn::prod_bw");

constexpr auto mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::MUL_BW>>("ttnn::mul_bw");
constexpr auto clamp_min_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MIN_BW>>("ttnn::clamp_min_bw");
constexpr auto assign_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ASSIGN_BW>>("ttnn::assign_bw");
constexpr auto multigammaln_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::MULTIGAMMALN_BW>>("ttnn::multigammaln_bw");
constexpr auto add_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ADD_BW>>("ttnn::add_bw");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,6 @@ void py_module(py::module& module) {
ttnn::mul_bw,
R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` or attr:`input_tensor_a`, attr:`input_tensor_b`, with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward(
module,
ttnn::clamp_min_bw,
R"doc(Performs backward operations for clamp min value on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward_optional_float_params_with_default(
module,
ttnn::clamp_bw,
Expand Down
2 changes: 0 additions & 2 deletions ttnn/ttnn/operations/unary_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def unary_bw_prod(torch_op, x, grad_data, *args, **kwargs):

name_to_golden_function = {
"mul_bw": lambda x, grad_data: unary_bw_with_float(torch.mul, x, grad_data),
"clamp_min_bw": lambda x, grad_data: unary_bw_with_float(torch.clamp_min, x, grad_data),
"add_bw": lambda x, grad_data: unary_bw_with_float(torch.add, x, grad_data),
"eq_bw": lambda x, grad_data: unary_bw_with_float(torch.eq, x, grad_data),
"gt_bw": lambda x, grad_data: unary_bw_with_float(torch.gt, x, grad_data),
Expand Down Expand Up @@ -136,7 +135,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):

TTNN_ELTWISE_UNARY_BACKWARD_CPP_FUNCTIONS = [
ttnn.mul_bw,
ttnn.clamp_min_bw,
ttnn.add_bw,
ttnn.eq_bw,
ttnn.gt_bw,
Expand Down

0 comments on commit bafed35

Please sign in to comment.