Skip to content

Commit

Permalink
#10071: Merge selu_bw to TTNN
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Jul 11, 2024
1 parent 3481d61 commit f049411
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ Pointwise Unary
ttnn/relu6_bw
ttnn/abs_bw
ttnn/silu_bw
ttnn/selu_bw

Pointwise Binary
================
Expand Down
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -880,8 +880,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.rpow_bw

.. autofunction:: tt_lib.tensor.selu_bw

.. autofunction:: tt_lib.tensor.square_bw

.. autofunction:: tt_lib.tensor.tanhshrink_bw
Expand Down
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/selu_bw.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _ttnn.selu_bw:

ttnn.selu_bw
#############

.. autofunction:: ttnn.selu_bw
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


Expand All @@ -22,7 +22,7 @@ def test_bw_selu(input_shapes, device):

pyt_y = torch.nn.functional.selu(in_data)

tt_output_tensor_on_device = tt_lib.tensor.selu_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.selu_bw(grad_tensor, input_tensor)

in_data.retain_grad()

Expand Down
21 changes: 0 additions & 21 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,27 +1001,6 @@ std::vector<Tensor> rpow_bw(
return operation::decorate_as_composite(__func__, _rpow_bw)(grad, input, exponent, output_mem_config);
}

// Selu
// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input))
std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config);
Tensor grad_result = where(
ttnn::gtz(input, output_mem_config),
grad_lambd,
ttnn::multiply(ttnn::multiply(grad_lambd, 1.673260f, std::nullopt, output_mem_config),
ttnn::exp(input, false, output_mem_config),
std::nullopt,
output_mem_config),
output_mem_config);
grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
// Public entry point for the SELU backward op.
// Wraps _selu_bw with operation::decorate_as_composite, passing this function's
// name (`__func__`) as the label, then invokes the wrapped callable.
std::vector<Tensor> selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    auto composite_op = operation::decorate_as_composite(__func__, _selu_bw);
    return composite_op(grad, input, output_mem_config);
}


// Autoformat support
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config) {
auto formatted_input_tensor = temp;
Expand Down
5 changes: 0 additions & 5 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,6 @@ std::vector<Tensor> rpow_bw(
float exponent,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// Backward op for SELU: takes the incoming gradient `grad` and the forward
// `input`, and returns the resulting gradient tensor(s) as a vector.
// `output_mem_config` defaults to operation::DEFAULT_OUTPUT_MEMORY_CONFIG.
std::vector<Tensor> selu_bw(
const Tensor& grad,
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> square_bw(
const Tensor& grad,
const Tensor& input,
Expand Down
16 changes: 0 additions & 16 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,22 +808,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

// Python binding for selu_bw on the tensor module. The raw string below is the
// docstring shown to Python users; `grad` and `input` are bound with noconvert().
m_tensor.def("selu_bw", &tt::tt_metal::selu_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for selu of ``input`` tensors with given ``grad``.
Input tensors must have BFLOAT16 data type.
Output tensors will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor selu_bw is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("square_bw", &tt::tt_metal::square_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward square operations on ``input`` tensors with given ``grad``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,19 +422,38 @@ std::vector<Tensor> _abs_bw(const Tensor& grad, const Tensor& input, const Memor
// Backward computation for silu; returns a single-element vector with the gradient.
// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result))
std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
// NOTE(review): grad_sigmoid and add_sub are each started twice below, once via
// sigmoid/add1/sub_unary and once via ttnn::sigmoid/ttnn::add/ttnn::subtract.
// This looks like removed and added diff lines shown interleaved; confirm which
// variant is the live one before relying on this listing.
Tensor grad_sigmoid = ttnn::multiply(grad, sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
Tensor add_sub = add1(
ttnn::multiply(sub_unary(1.0f, sigmoid(input, output_mem_config), output_mem_config),
// grad * sigmoid(input)
Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
// 1 + input * (1 - sigmoid(input)); the constant 1 is materialized with full_like(input, 1.0f)
Tensor add_sub = ttnn::add(
ttnn::multiply(ttnn::subtract(tt::tt_metal::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config),
input,
std::nullopt,
output_mem_config),
1.0f,
std::nullopt,
output_mem_config);
// Product of the two factors gives the result formula stated above.
Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}

// Selu
// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input))
std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config);
Tensor grad_result = where(
ttnn::gtz(input, output_mem_config),
grad_lambd,
ttnn::multiply(ttnn::multiply(grad_lambd, 1.673260f, std::nullopt, output_mem_config),
ttnn::exp(input, false, output_mem_config),
std::nullopt,
output_mem_config),
output_mem_config);
grad_tensor.emplace_back(grad_result);
return grad_tensor;
}

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
case UnaryBackwardOpType::ASSIGN_BW:
Expand Down Expand Up @@ -477,6 +496,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Memo
return _abs_bw;
case UnaryBackwardOpType::SILU_BW:
return _silu_bw;
case UnaryBackwardOpType::SELU_BW:
return _selu_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ enum class UnaryBackwardOpType {
LOG_BW,
RELU6_BW,
ABS_BW,
SILU_BW
SILU_BW,
SELU_BW,
};

struct UnaryBackwardFunction{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ constexpr auto log_bw = ttnn::register_operation<operations::unary_backward::Exe
constexpr auto relu6_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RELU6_BW>>("ttnn::relu6_bw");
constexpr auto abs_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ABS_BW>>("ttnn::abs_bw");
constexpr auto silu_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::SILU_BW>>("ttnn::silu_bw");

// Registers the SELU backward op under the name "ttnn::selu_bw", instantiating
// ExecuteUnaryBackward with UnaryBackwardOpType::SELU_BW.
constexpr auto selu_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::SELU_BW>>("ttnn::selu_bw");

} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@ void py_module(py::module& module) {
ttnn::silu_bw,
R"doc(Performs backward operations for silu on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

// Bind ttnn::selu_bw into `module` through the shared unary-backward binder;
// the raw string is the docstring attached to the binding.
detail::bind_unary_backward(
module,
ttnn::selu_bw,
R"doc(Performs backward operations for selu on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

}

} // namespace binary_backward
Expand Down

0 comments on commit f049411

Please sign in to comment.