Skip to content

Commit

Permalink
#9874: Merge unary_mul_bw to TTNN
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN authored and Aswinmcw committed Jul 6, 2024
1 parent 9281a1c commit dff3292
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ Pointwise Unary
ttnn/triu
ttnn/tanhshrink
ttnn/threshold
ttnn/unary_mul_bw

Pointwise Binary
================
Expand Down
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -823,8 +823,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.conj_bw

.. autofunction:: tt_lib.tensor.unary_mul_bw

.. autofunction:: tt_lib.tensor.unary_add_bw

.. autofunction:: tt_lib.tensor.unary_assign_bw
Expand Down
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/unary_mul_bw.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _ttnn.unary_mul_bw:

ttnn.unary_mul_bw
#################

.. autofunction:: ttnn.unary_mul_bw
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import pytest
import tt_lib
import ttnn
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


Expand All @@ -21,7 +22,7 @@ def test_bw_unary_mul(input_shapes, scalar, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)

tt_output_tensor_on_device = tt_lib.tensor.unary_mul_bw(grad_tensor, input_tensor, scalar=scalar)
tt_output_tensor_on_device = ttnn.unary_mul_bw(grad_tensor, input_tensor, scalar)

in_data.retain_grad()

Expand Down
12 changes: 0 additions & 12 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,6 @@ namespace tt {

namespace tt_metal {

std::vector<Tensor> _unary_mul_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = mul_unary(grad, scalar, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
std::vector<Tensor> unary_mul_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _unary_mul_bw)(grad, input, scalar, output_mem_config);
}

// unary_pow:
// grad_input = grad * exponent * torch.pow(input, exponent - 1)
std::vector<std::optional<Tensor>> _unary_pow_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad) {
Expand Down
17 changes: 0 additions & 17 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("unary_mul_bw", &tt::tt_metal::unary_mul_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for multiplication with given ``grad`` and ``scalar``
Input tensors must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"scalar", "Scalar value", "float", "default to 1.0f", "No"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");


m_tensor.def("exp_bw",
[](const Tensor& grad,
Expand Down
4 changes: 4 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "ttnn/operations/examples/examples_pybind.hpp"
#include "ttnn/operations/reduction/reduction_pybind.hpp"
#include "ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp"
#include "ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp"
#include "ttnn/operations/data_movement/data_movement_pybind.hpp"
#include "ttnn/operations/embedding/embedding_ops_pybind.hpp"

Expand All @@ -50,6 +51,9 @@ void py_module(py::module& module) {

auto m_ternary_backward = module.def_submodule("ternary_backward", "ternary_backward operations");
ternary_backward::py_module(m_ternary_backward);

auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations");
unary_backward::py_module(m_unary_backward);

auto m_ternary = module.def_submodule("ternary", "ternary operations");
ternary::py_module(m_ternary);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp"

#include "third_party/magic_enum/magic_enum.hpp"
#include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"

namespace ttnn::operations::unary_backward {

namespace utils {


std::vector<ttnn::Tensor> _unary_mul_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = mul_unary(grad, scalar, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}


std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){
switch (OpType) {
default:
TT_ASSERT(false && "Undefined op type");
return 0;
}
}

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType){
switch (OpType) {
case UnaryBackwardOpType::UNARY_MUL_BW:
return _unary_mul_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
}
}

}

} // namespace ttnn::operations::unary
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <functional>
#include <optional>

#include "third_party/magic_enum/magic_enum.hpp"

namespace ttnn::operations::unary_backward {

constexpr uint8_t DefaultQueueId = 0;
enum class UnaryBackwardOpType {
UNARY_MUL_BW,
};


} // namespace ttnn::operations::unary
84 changes: 84 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "device/unary_backward_op.cpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement.hpp"

namespace ttnn {

namespace operations::unary_backward {

template <UnaryBackwardOpType unary_backward_op_type>
struct ExecuteUnaryBackward {

static inline const std::array<TensorSchema, 2> input_tensor_schemas() {
return {
ttnn::TensorSchema{
2,
4,
{ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16},
{ttnn::TILE_LAYOUT},
true,
false,
false,
false},
ttnn::TensorSchema{
2,
4,
{ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16},
{ttnn::TILE_LAYOUT},
true,
false,
false,
false}};
}

static inline std::vector<ttnn::Tensor> create_async_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
const auto& input_tensor = input_tensors.at(0);
return {Tensor(operation::get_workers_for_op_output({input_tensor}))};
}

//Type 1: 2 inputs, 1 grad tensor
template <typename... Args>
static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b);
}

static std::vector<ttnn::Tensor> execute_on_worker_thread(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {

auto op_type = utils::get_function_type1(unary_backward_op_type);
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_arg, output_memory_config);
}

//Type 1: Type 1 with 1 float
template <typename... Args>

static std::vector<ttnn::Tensor> execute_on_worker_thread(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
float alpha,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {

auto op_type = utils::get_function_type1_w_float(unary_backward_op_type);
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_arg, alpha, output_memory_config);
}

};

} // operations::unary

//type 1
constexpr auto unary_mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_MUL_BW>>("ttnn::unary_mul_bw");

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/eltwise/unary_backward/unary_backward.hpp"
#include "ttnn/types.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace unary_backward {

namespace detail {

template <typename unary_backward_operation_t>
void bind_unary_backward(py::module& module, const unary_backward_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor: ttnn.Tensor *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{2}
Args:
* :attr:`grad_tensor`
* :attr:`input_tensor`
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, input)
)doc",
operation.name(),
operation.python_fully_qualified_name(),
description);

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor,
const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> {
auto output_memory_config = memory_config.value_or(input_tensor.memory_config());
return self(grad_tensor, input_tensor, output_memory_config);
},
py::arg("grad_tensor"),
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt},


ttnn::pybind_overload_t{
[](const unary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor,
const float alpha,
const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> {
return self(grad_tensor, input_tensor, alpha, memory_config);
},
py::arg("grad_tensor"),
py::arg("input_tensor"),
py::arg("alpha"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});

}

} // namespace detail


void py_module(py::module& module) {
detail::bind_unary_backward(
module,
ttnn::unary_mul_bw,
R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");

}

} // namespace binary_backward
} // namespace operations
} // namespace ttnn

0 comments on commit dff3292

Please sign in to comment.