Commit 40c5a1f

#9874: Merge log_bw to TTNN

mouliraj-mcw committed Jul 9, 2024
1 parent a32ee62 commit 40c5a1f

Showing 15 changed files with 51 additions and 50 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -175,6 +175,7 @@ Pointwise Unary
ttnn/eq_bw
ttnn/floor_bw
ttnn/round_bw
ttnn/log_bw

Pointwise Binary
================
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -843,8 +843,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.unary_sub_bw

.. autofunction:: tt_lib.tensor.log_bw

.. autofunction:: tt_lib.tensor.abs_bw

.. autofunction:: tt_lib.tensor.complex_abs_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/floor_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.floor_bw:

ttnn.floor_bw
##############

.. autofunction:: ttnn.floor_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/log_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.log_bw:

ttnn.log_bw
############

.. autofunction:: ttnn.log_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/round_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.round_bw:

ttnn.round_bw
##############

.. autofunction:: ttnn.round_bw
@@ -4,7 +4,6 @@

import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range

@@ -4,7 +4,7 @@

import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
data_gen_with_val,
compare_pcc,
@@ -23,7 +23,7 @@
def test_bw_log_0(input_shapes, device):
in_data, input_tensor = data_gen_with_val(input_shapes, device, True, val=0)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1, 1, device)
tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -47,7 +47,7 @@ def test_bw_log_0(input_shapes, device):
def test_bw_log(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor)

in_data.retain_grad()

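The test changes above capture the whole migration pattern: call sites swap tt_lib.tensor.log_bw for ttnn.log_bw with the same argument order. A minimal usage sketch of the new entry point (a sketch only: it assumes a device opened via ttnn.open_device, and the shapes and values are illustrative):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_grad = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)
torch_input = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)

grad_tensor = ttnn.from_torch(torch_grad, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Previously: tt_lib.tensor.log_bw(grad_tensor, input_tensor)
# The op returns a vector of gradient tensors; a unary backward op
# has a single entry, the gradient with respect to the input.
output = ttnn.log_bw(grad_tensor, input_tensor)
in_grad = ttnn.to_torch(output[0])

ttnn.close_device(device)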
@@ -4,7 +4,6 @@

import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range

20 changes: 0 additions & 20 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -343,26 +343,6 @@ std::vector<Tensor> ne_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _ne_bw)(grad, output_mem_config);
}

std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_a = ttnn::multiply(grad, recip(input, output_mem_config), std::nullopt, output_mem_config);
Tensor t_inf = full_like(input, std::numeric_limits<float>::infinity(), output_mem_config);
Tensor t_nan = full_like(input, std::nanf(""), output_mem_config);
grad_tensor.emplace_back(where(
eqz(input, output_mem_config),
where(
eqz(grad, output_mem_config),
t_nan,
ttnn::multiply(t_inf, sign(grad, output_mem_config), std::nullopt, output_mem_config),
output_mem_config),
grad_a,
output_mem_config));
return grad_tensor;
}
std::vector<Tensor> log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _log_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::multiply(grad, sign(input, output_mem_config), std::nullopt, output_mem_config);
5 changes: 0 additions & 5 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -120,11 +120,6 @@ std::vector<Tensor> unary_sub_bw(
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> log_bw(
const Tensor& grad,
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> binary_le_bw(
const Tensor& grad,
const Tensor& input,
16 changes: 0 additions & 16 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -425,22 +425,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("log_bw", &tt::tt_metal::log_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for logarithm of ``input`` tensors with given ``grad``.
Input tensors must have BFLOAT16 data type.
Output tensors will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor add is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("abs_bw", &tt::tt_metal::abs_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for abs of ``input`` tensors with given ``grad``.
@@ -99,13 +99,31 @@ std::vector<Tensor> _floor_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}

std::vector<Tensor> _round_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = tt::tt_metal::zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}

std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_a = ttnn::multiply(grad, ttnn::reciprocal(input, output_mem_config), std::nullopt, output_mem_config);
Tensor t_inf = tt::tt_metal::full_like(input, std::numeric_limits<float>::infinity(), output_mem_config);
Tensor t_nan = tt::tt_metal::full_like(input, std::nanf(""), output_mem_config);
grad_tensor.emplace_back(where(
ttnn::eqz(input, output_mem_config),
where(
ttnn::eqz(grad, output_mem_config),
t_nan,
ttnn::multiply(t_inf, ttnn::sign(grad, output_mem_config), std::nullopt, output_mem_config),
output_mem_config),
grad_a,
output_mem_config));
return grad_tensor;
}

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
case UnaryBackwardOpType::ASSIGN_BW:
@@ -116,6 +134,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){
return _floor_bw;
case UnaryBackwardOpType::ROUND_BW:
return _round_bw;
case UnaryBackwardOpType::LOG_BW:
return _log_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
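For reference, the relocated _log_bw above implements the derivative d(ln x)/dx = 1/x, scaling the incoming gradient by the reciprocal of the input, with explicit handling of input == 0: NaN when the gradient is also zero, otherwise +/-infinity carrying the sign of the gradient. A PyTorch sketch of the same formula (the function name is hypothetical, for illustration):

import torch

def log_bw_reference(grad, input):
    # d/dx ln(x) = 1/x, so the input gradient is grad scaled by 1/input.
    grad_a = grad * torch.reciprocal(input)
    # Mirror the kernel's input == 0 handling: NaN when grad is also
    # zero, otherwise +/-inf taking the sign of grad.
    special = torch.where(
        grad == 0,
        torch.full_like(input, float("nan")),
        torch.sign(grad) * float("inf"),
    )
    return torch.where(input == 0, special, grad_a)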
@@ -22,6 +22,7 @@ enum class UnaryBackwardOpType {
EQ_BW,
FLOOR_BW,
ROUND_BW,
LOG_BW,
};


@@ -74,6 +74,7 @@ constexpr auto add_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ADD_BW>>("ttnn::add_bw");
constexpr auto eq_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::EQ_BW>>("ttnn::eq_bw");
constexpr auto floor_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::FLOOR_BW>>("ttnn::floor_bw");
constexpr auto round_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ROUND_BW>>("ttnn::round_bw");
constexpr auto log_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::LOG_BW>>("ttnn::log_bw");


} // namespace ttnn
@@ -198,12 +198,17 @@ void py_module(py::module& module) {
detail::bind_unary_backward(
module,
ttnn::floor_bw,
R"doc(Performs backward operations for floor on :attr:`input_tensor` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
R"doc(Performs backward operations for floor on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

detail::bind_unary_backward(
module,
ttnn::round_bw,
R"doc(Performs backward operations for round on :attr:`input_tensor` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
R"doc(Performs backward operations for round on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward(
module,
ttnn::log_bw,
R"doc(Performs backward operations for logarithm on :attr:`input_tensor` with given :attr:`grad_tensor`)doc");

}
