Skip to content

Commit

Permalink
#9874: Fold unary mul_bw into binary mul_bw as a scalar overload
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jul 4, 2024
1 parent 45de9fe commit 35ed601
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 19 deletions.
1 change: 0 additions & 1 deletion docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ Pointwise Unary
ttnn/triu
ttnn/tanhshrink
ttnn/threshold
ttnn/unary_mul_bw
ttnn/clamp_min_bw

Pointwise Binary
Expand Down
6 changes: 0 additions & 6 deletions docs/source/ttnn/ttnn/ttnn/unary_mul_bw.rst

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_bw_unary_mul(input_shapes, scalar, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)

tt_output_tensor_on_device = ttnn.unary_mul_bw(grad_tensor, input_tensor, scalar)
tt_output_tensor_on_device = ttnn.mul_bw(grad_tensor, input_tensor, scalar)

in_data.retain_grad()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ struct ExecuteBinaryBackward {
return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, input_tensor_c_arg, memory_config);
}

// Unary overload: validator for the scalar-alpha variant of a binary
// backward op. Returns the (grad, input) tensor pair for validation;
// any trailing arguments (alpha, memory_config, ...) are accepted but
// ignored so one validator signature serves every overload.
template <typename... Args>
static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor, Args &&...args) {
return std::forward_as_tuple(grad_tensor, input_tensor);
}
// Unary overload: runs the scalar-alpha (unary-style) backward kernel
// for this binary backward op. When no memory_config is supplied, the
// input tensor's own memory config is used for the output.
static std::vector<ttnn::Tensor> execute_on_worker_thread(
    const Tensor &grad_tensor_arg,
    const Tensor &input_tensor_arg,
    float alpha,
    const std::optional<MemoryConfig> &memory_config = std::nullopt) {
    const auto resolved_mem_config = memory_config.value_or(input_tensor_arg.memory_config());
    const auto backward_fn = utils::get_unary_type1_overload_function(binary_backward_op_type);
    return backward_fn(grad_tensor_arg, input_tensor_arg, alpha, resolved_mem_config);
}

//Type 1: Type 1 with 1 float
template <typename... Args>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,22 @@ Keyword args:
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_a_grad") = std::nullopt,
py::arg("input_b_grad") = std::nullopt});
py::arg("input_b_grad") = std::nullopt},

//unary overload
ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor,
const float alpha,
const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> {
return self(grad_tensor, input_tensor, alpha, memory_config);
},
py::arg("grad_tensor"),
py::arg("input_tensor"),
py::arg("alpha"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

} // namespace detail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp"
#include "ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp"

#include "third_party/magic_enum/magic_enum.hpp"
#include "tt_eager/tt_dnn/op_library/backward/backward_ops.cpp"
Expand Down Expand Up @@ -752,6 +753,16 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tens
}
}

// Resolves a binary backward op type to its unary-style (scalar alpha)
// implementation: f(grad, input, alpha, memory_config) -> grad tensors.
// Only MUL_BW currently has such an overload; any other op type is a
// programming error.
std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_unary_type1_overload_function(BinaryBackwardOpType OpType){
    switch (OpType) {
        case BinaryBackwardOpType::MUL_BW:
            return ttnn::operations::unary_backward::utils::_mul_bw;
        default:
            TT_ASSERT(false && "Undefined op type");
            // Explicit empty std::function instead of the obscure
            // null-pointer-constant `0`; if asserts are compiled out,
            // invoking it still throws std::bad_function_call.
            return nullptr;
    }
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace ttnn::operations::unary_backward {
namespace utils {


std::vector<ttnn::Tensor> _unary_mul_bw(
std::vector<ttnn::Tensor> _mul_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = mul_unary(grad, scalar, output_mem_config);
Expand Down Expand Up @@ -47,8 +47,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Memo

std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType){
switch (OpType) {
case UnaryBackwardOpType::UNARY_MUL_BW:
return _unary_mul_bw;
// case UnaryBackwardOpType::UNARY_MUL_BW:
// return _unary_mul_bw;
case UnaryBackwardOpType::CLAMP_MIN_BW:
return _clamp_min_bw;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace ttnn::operations::unary_backward {

constexpr uint8_t DefaultQueueId = 0;
// Backward ops that still have a dedicated unary implementation.
// UNARY_MUL_BW is gone: scalar mul backward is now exposed as the
// unary (scalar alpha) overload of ttnn::mul_bw in binary_backward.
enum class UnaryBackwardOpType {
    CLAMP_MIN_BW,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#pragma once

#include "device/unary_backward_op.cpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement.hpp"

Expand Down Expand Up @@ -79,7 +78,7 @@ struct ExecuteUnaryBackward {
} // operations::unary

//type 1
// The unary_mul_bw registration was removed: scalar mul backward is now
// served by the unary (scalar alpha) overload of ttnn::mul_bw, which is
// registered by binary_backward.
constexpr auto clamp_min_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MIN_BW>>("ttnn::clamp_min_bw");

} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ Keyword args:


void py_module(py::module& module) {
detail::bind_unary_backward(
module,
ttnn::unary_mul_bw,
R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");
// detail::bind_unary_backward(
// module,
// ttnn::unary_mul_bw,
// R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");
// }

detail::bind_unary_backward(
module,
Expand All @@ -93,6 +94,7 @@ void py_module(py::module& module) {

}


} // namespace binary_backward
} // namespace operations
} // namespace ttnn

0 comments on commit 35ed601

Please sign in to comment.