#9874: Merge clamp_min_bw to TTNN
VirdhatchaniKN committed Jul 4, 2024
1 parent e23e6fa commit a193ecc
Showing 11 changed files with 27 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -154,6 +154,7 @@ Pointwise Unary
ttnn/tanhshrink
ttnn/threshold
ttnn/unary_mul_bw
ttnn/clamp_min_bw

Pointwise Binary
================
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -886,8 +886,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.clamp_bw

.. autofunction:: tt_lib.tensor.clamp_min_bw

.. autofunction:: tt_lib.tensor.clamp_max_bw

.. autofunction:: tt_lib.tensor.gelu_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/clamp_min_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.clamp_min_bw:

ttnn.clamp_min_bw
#################

.. autofunction:: ttnn.clamp_min_bw
@@ -5,6 +5,7 @@
import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@@ -23,7 +24,7 @@ def test_bw_clamp_min(input_shapes, min, device):

pyt_y = torch.clamp(in_data, min=min)

tt_output_tensor_on_device = tt_lib.tensor.clamp_min_bw(grad_tensor, input_tensor, min)
tt_output_tensor_on_device = ttnn.clamp_min_bw(grad_tensor, input_tensor, min)

in_data.retain_grad()

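The only functional change in the test is the call site: tt_lib.tensor.clamp_min_bw(grad_tensor, input_tensor, min) becomes ttnn.clamp_min_bw(grad_tensor, input_tensor, min), with the argument order unchanged. A condensed sketch of the full test flow follows; the data_gen_with_range and compare_pcc helper signatures are assumptions inferred from the import above, not shown in this diff.

import torch
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
    data_gen_with_range,
    compare_pcc,
)

def run_bw_clamp_min(input_shapes, min, device):
    # Assumed helper signature: (shape, low, high, device, required_grad) -> (torch tensor, tt tensor)
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 10, device)

    # TTNN op under test (post-migration call site)
    tt_output_tensor_on_device = ttnn.clamp_min_bw(grad_tensor, input_tensor, min)

    # PyTorch golden: clamp with a lower bound, then backprop the incoming grad
    pyt_y = torch.clamp(in_data, min=min)
    in_data.retain_grad()
    pyt_y.backward(gradient=grad_data)

    # PCC comparison between the device output and the PyTorch gradient
    return compare_pcc(tt_output_tensor_on_device, [in_data.grad])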
13 changes: 0 additions & 13 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -555,19 +555,6 @@ std::vector<Tensor> clamp_bw(
return operation::decorate_as_composite(__func__, _clamp_bw)(grad, input, min, max, output_mem_config);
}

std::vector<Tensor> _clamp_min_bw(
const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor minT = gte_unary(input, min, output_mem_config);
Tensor result = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
std::vector<Tensor> clamp_min_bw(
const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _clamp_min_bw)(grad, input, min, output_mem_config);
}

std::vector<Tensor> _clamp_max_bw(
const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
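The deleted _clamp_min_bw encodes the standard lower-clamp gradient: since clamp_min(x, m) = max(x, m), the gradient passes through wherever input >= min (the gte_unary mask) and is zero elsewhere. A quick PyTorch check of that rule, illustrative only and not part of the commit:

import torch

# clamp with a lower bound: d/dx clamp(x, min=m) is 1 where x >= m, 0 where x < m
x = torch.randn(32, 32, requires_grad=True)
grad = torch.randn(32, 32)
m = 0.25

torch.clamp(x, min=m).backward(gradient=grad)

# the same mask-and-multiply the kernel performs with gte_unary + ttnn::multiply
expected = grad * (x >= m)
assert torch.allclose(x.grad, expected)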
6 changes: 0 additions & 6 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -213,12 +213,6 @@ std::vector<Tensor> clamp_bw(
float max,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> clamp_min_bw(
const Tensor& grad,
const Tensor& input,
float min,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> clamp_max_bw(
const Tensor& grad,
const Tensor& input,
17 changes: 0 additions & 17 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -624,23 +624,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("clamp_min_bw", &tt::tt_metal::clamp_min_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("min").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for clamp min of ``input`` tensors with given ``grad``.
Input tensors must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"min", "Minimum Value", "float", , "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("clamp_max_bw", &tt::tt_metal::clamp_max_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("max").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for clamp max of ``input`` tensors with given ``grad``.
@@ -27,6 +27,15 @@ std::vector<ttnn::Tensor> _unary_mul_bw(
return grad_tensor;
}

std::vector<Tensor> _clamp_min_bw(
const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor minT = gte_unary(input, min, output_mem_config);
Tensor result = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}


std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
@@ -40,6 +49,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, con
switch (OpType) {
case UnaryBackwardOpType::UNARY_MUL_BW:
return _unary_mul_bw;
case UnaryBackwardOpType::CLAMP_MIN_BW:
return _clamp_min_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
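_clamp_min_bw itself moves over verbatim; the new part is the dispatch, where float-parameter unary backward ops are resolved from the UnaryBackwardOpType enum to their gradient function. A Python analogue of that enum-to-function table, purely illustrative rather than TTNN source:

from enum import Enum, auto

class UnaryBackwardOpType(Enum):
    UNARY_MUL_BW = auto()
    CLAMP_MIN_BW = auto()

def _unary_mul_bw(grad, input, scalar):
    return [grad * scalar]

def _clamp_min_bw(grad, input, min):
    # gradient flows where input >= min, mirroring gte_unary + multiply
    return [grad * (input >= min)]

_FLOAT_PARAM_OPS = {
    UnaryBackwardOpType.UNARY_MUL_BW: _unary_mul_bw,
    UnaryBackwardOpType.CLAMP_MIN_BW: _clamp_min_bw,
}

def get_function_type1_float(op_type):
    # mirrors the C++ switch: an unknown enum value is a hard error
    if op_type not in _FLOAT_PARAM_OPS:
        raise AssertionError("Undefined op type")
    return _FLOAT_PARAM_OPS[op_type]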
@@ -14,6 +14,7 @@ namespace ttnn::operations::unary_backward {
constexpr uint8_t DefaultQueueId = 0;
enum class UnaryBackwardOpType {
UNARY_MUL_BW,
CLAMP_MIN_BW,
};


@@ -80,5 +80,6 @@ struct ExecuteUnaryBackward {

//type 1
constexpr auto unary_mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_MUL_BW>>("ttnn::unary_mul_bw");
constexpr auto clamp_min_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MIN_BW>>("ttnn::clamp_min_bw");

} // namespace ttnn
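With this registration, the op is callable from Python with the same positional arguments as the old tt_lib binding. A minimal usage sketch follows; the device and tensor setup here are assumptions for illustration, not part of the commit.

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
torch_grad = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
grad_tensor = ttnn.from_torch(torch_grad, layout=ttnn.TILE_LAYOUT, device=device)

# returns a vector holding one tensor: grad masked by (input >= min)
(grad_input,) = ttnn.clamp_min_bw(grad_tensor, input_tensor, 1.0)

ttnn.close_device(device)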
@@ -86,6 +86,11 @@ void py_module(py::module& module) {
ttnn::unary_mul_bw,
R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward(
module,
ttnn::clamp_min_bw,
R"doc(Performs backward operations for clamp min value on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");

}

} // namespace binary_backward
