Skip to content

Commit

Permalink
#10071: Merge round_bw to TTNN
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Jul 11, 2024
1 parent d9cef87 commit bdcfda6
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ Pointwise Unary
ttnn/celu_bw
ttnn/rpow_bw
ttnn/floor_bw
ttnn/round_bw

Pointwise Binary
================
Expand Down
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.repeat_bw

.. autofunction:: tt_lib.tensor.round_bw

.. autofunction:: tt_lib.tensor.unary_div_no_nan_bw

Loss Functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range


Expand All @@ -21,7 +22,7 @@ def test_bw_round(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -200, 201, device, required_grad=True)
pyt_y = torch.round(in_data)

tt_output_tensor_on_device = tt_lib.tensor.round_bw(grad_tensor)
tt_output_tensor_on_device = ttnn.round_bw(grad_tensor, input_tensor)

in_data.retain_grad()

Expand Down
10 changes: 0 additions & 10 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1748,16 +1748,6 @@ std::vector<Tensor> repeat_bw(
return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
}

std::vector<Tensor> _round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _round_bw)(grad, output_mem_config);
}

std::vector<Tensor> _unary_div_no_nan_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Expand Down
3 changes: 0 additions & 3 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,6 @@ std::vector<Tensor> complex_sub_bw(
std::vector<Tensor> repeat_bw(
const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config);

std::vector<Tensor> round_bw(
const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> unary_div_no_nan_bw(
const Tensor& grad,
const Tensor& input,
Expand Down
15 changes: 0 additions & 15 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1153,21 +1153,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("round_bw", &tt::tt_metal::round_bw,
py::arg("grad").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns an tensor of zeros like ``grad`` tensor
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("unary_div_no_nan_bw", &tt::tt_metal::unary_div_no_nan_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for division with given ``grad`` and ``scalar`` with no nan.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ std::vector<Tensor> _logit_bw(const Tensor& grad, const Tensor& input, const Mem
return grad_tensor;
}


std::vector<Tensor> _hardshrink_bw(
const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Expand Down Expand Up @@ -358,6 +357,20 @@ std::vector<Tensor> _rpow_bw(
}


std::vector<Tensor> _floor_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = tt::tt_metal::zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}

std::vector<Tensor> _round_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = tt::tt_metal::zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
case UnaryBackwardOpType::ASSIGN_BW:
Expand Down Expand Up @@ -390,6 +403,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Memo
return _logit_bw;
case UnaryBackwardOpType::FLOOR_BW:
return _floor_bw;
case UnaryBackwardOpType::ROUND_BW:
return _round_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ enum class UnaryBackwardOpType {
ELU_BW,
CELU_BW,
RPOW_BW,
FLOOR_BW,
ROUND_BW,
};

struct UnaryBackwardFunction{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ constexpr auto elu_bw = ttnn::register_operation<operations::unary_backward::Exe
constexpr auto celu_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CELU_BW>>("ttnn::celu_bw");
constexpr auto rpow_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::RPOW_BW>>("ttnn::rpow_bw");
constexpr auto floor_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::FLOOR_BW>>("ttnn::floor_bw");
constexpr auto round_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ROUND_BW>>("ttnn::round_bw");


} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,12 @@ void py_module(py::module& module) {
module,
ttnn::rpow_bw,
R"doc(Performs backward operations for rpow on :attr:`input_tensor`, :attr:`exponent` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward(
module,
ttnn::round_bw,
R"doc(Performs backward operations for round on :attr:`input_tensor` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");

}

} // namespace binary_backward
Expand Down

0 comments on commit bdcfda6

Please sign in to comment.