#9874: Overload unary and binary ops
Aswinmcw committed Jul 8, 2024
1 parent f8acc2e commit 2ec4d0e
Showing 7 changed files with 54 additions and 40 deletions.
@@ -22,7 +22,7 @@ def test_bw_unary_mul(input_shapes, scalar, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)

-    tt_output_tensor_on_device = ttnn.unary_mul_bw(grad_tensor, input_tensor, scalar)
+    tt_output_tensor_on_device = ttnn.mul_bw(grad_tensor, input_tensor, scalar)

     in_data.retain_grad()
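With the rename in place, one `ttnn.mul_bw` symbol serves both call shapes. A minimal usage sketch, assuming the `device`, `input_shapes`, and `data_gen_with_range` helper from the test above; the tensor-tensor form follows the pybind overload added further down:

    import ttnn

    # Scalar second operand -> unary backward path (as in test_bw_unary_mul).
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
    scalar_grads = ttnn.mul_bw(grad_tensor, input_tensor, 2.5)

    # Tensor second operand -> binary backward path (same name, overloaded).
    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    tensor_grads = ttnn.mul_bw(grad_tensor, input_tensor, other_tensor)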
@@ -180,6 +180,4 @@ constexpr auto addalpha_bw = ttnn::register_operation<operations::binary_backwar
 //type 3
 constexpr auto add_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::ADD_BW>>("ttnn::add_bw");
 constexpr auto binary_eq_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::BINARY_EQ_BW>>("ttnn::binary_eq_bw");
-constexpr auto mul_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::MUL_BW>>("ttnn::mul_bw");
-
 } // namespace ttnn
@@ -346,11 +346,6 @@ void py_module(py::module& module) {
         ttnn::div_bw,
         R"doc(Performs backward operations for divide of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");

-    detail::bind_binary_backward(
-        module,
-        ttnn::mul_bw,
-        R"doc(Performs backward operations for multiply on :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");
-
 }

 } // namespace binary_backward
@@ -47,7 +47,7 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Memo

 std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType){
     switch (OpType) {
-        case UnaryBackwardOpType::UNARY_MUL_BW:
+        case UnaryBackwardOpType::MUL_BW:
             return _unary_mul_bw;
         case UnaryBackwardOpType::CLAMP_MIN_BW:
             return _clamp_min_bw;
@@ -13,7 +13,7 @@ namespace ttnn::operations::unary_backward {

 constexpr uint8_t DefaultQueueId = 0;
 enum class UnaryBackwardOpType {
-    UNARY_MUL_BW,
+    MUL_BW,
     CLAMP_MIN_BW,
 };

@@ -16,39 +16,13 @@ namespace operations::unary_backward {
 template <UnaryBackwardOpType unary_backward_op_type>
 struct ExecuteUnaryBackward {

-    static inline const std::array<TensorSchema, 2> input_tensor_schemas() {
-        return {
-            ttnn::TensorSchema{
-                2,
-                4,
-                {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16},
-                {ttnn::TILE_LAYOUT},
-                true,
-                false,
-                false,
-                false},
-            ttnn::TensorSchema{
-                2,
-                4,
-                {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16},
-                {ttnn::TILE_LAYOUT},
-                true,
-                false,
-                false,
-                false}};
-    }
-
     static inline std::vector<ttnn::Tensor> create_async_output_tensors(
         const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
         const auto& input_tensor = input_tensors.at(0);
         return {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     }

     //Type 1: 2 inputs, 1 grad tensor
-    template <typename... Args>
-    static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
-        return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b);
-    }
-
     static std::vector<ttnn::Tensor> execute_on_worker_thread(
         const Tensor &grad_tensor_arg,
@@ -61,7 +35,6 @@ struct ExecuteUnaryBackward {
     }

     //Type 1: Type 1 with 1 float
-    template <typename... Args>

     static std::vector<ttnn::Tensor> execute_on_worker_thread(
         const Tensor &grad_tensor_arg,
@@ -79,7 +52,7 @@
 } // operations::unary

 //type 1
-constexpr auto unary_mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_MUL_BW>>("ttnn::unary_mul_bw");
+constexpr auto mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::MUL_BW>>("ttnn::mul_bw");
 constexpr auto clamp_min_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::CLAMP_MIN_BW>>("ttnn::clamp_min_bw");

 } // namespace ttnn
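For orientation, the unary path registered here keeps the scalar-multiply derivative: for y = x * s, dL/dx = dL/dy * s, so the returned gradient is simply grad * scalar. A small reference sketch against torch, mirroring how the tests compare to golden results (shapes here are illustrative):

    import torch

    # Reference semantics of the scalar overload of mul_bw:
    # y = x * s  =>  dL/dx = dL/dy * s.
    x = torch.randn(2, 3, requires_grad=True)
    s = 2.5
    grad = torch.randn(2, 3)                 # incoming dL/dy
    (x * s).backward(grad)
    assert torch.allclose(x.grad, grad * s)  # unary mul_bw returns grad * scalar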
@@ -9,6 +9,7 @@

 #include "ttnn/cpp/pybind11/decorators.hpp"
 #include "ttnn/operations/eltwise/unary_backward/unary_backward.hpp"
+#include "ttnn/operations/eltwise/binary_backward/binary_backward.hpp"
 #include "ttnn/types.hpp"

 namespace py = pybind11;
@@ -39,14 +40,61 @@ Keyword args:
         >>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
         >>> output = {1}(grad_tensor, input)
     )doc",
-        operation.name(),
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
         description);

     bind_registered_operation(
         module,
         operation,
         doc,
+        ttnn::pybind_overload_t{
+            [](const unary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor_a,
+               const ttnn::Tensor& input_tensor_b,
+               const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> {
+                auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
+                // Tensor-tensor overload: route through the binary backward implementation.
+                using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<binary_backward::BinaryBackwardOpType::MUL_BW>;
+                return BinaryBackwardOp::execute_on_worker_thread(grad_tensor, input_tensor_a, output_memory_config, input_tensor_b);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt},

+        ttnn::pybind_overload_t{
+            [](const unary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor_a,
+               const ttnn::Tensor& input_tensor_b,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::vector<bool>& are_required_outputs,
+               const std::optional<ttnn::Tensor>& input_a_grad,
+               const std::optional<ttnn::Tensor>& input_b_grad,
+               const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
+                // Tensor-tensor overload with selectable outputs, preallocated
+                // gradient tensors, and an explicit command queue.
+                using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<binary_backward::BinaryBackwardOpType::MUL_BW>;
+                return BinaryBackwardOp::execute_on_main_thread(queue_id, grad_tensor, input_tensor_a, input_tensor_b, memory_config, are_required_outputs, input_a_grad, input_b_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt,
+            py::arg("are_required_outputs") = std::vector<bool>{true, true},
+            py::arg("input_a_grad") = std::nullopt,
+            py::arg("input_b_grad") = std::nullopt,
+            py::arg("queue_id") = 0},

         ttnn::pybind_overload_t{
             [](const unary_backward_operation_t& self,
                const ttnn::Tensor& grad_tensor,
@@ -83,7 +131,7 @@
 void py_module(py::module& module) {
     detail::bind_unary_backward(
         module,
-        ttnn::unary_mul_bw,
+        ttnn::mul_bw,
         R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");

     detail::bind_unary_backward(
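The second overload bound above also exposes the queue id, selective outputs, and preallocated gradient tensors. A hedged usage sketch with the keyword names from that binding; `grad_tensor`, `input_tensor_a`, `input_tensor_b`, and `prealloc_a_grad` stand in for tensors prepared as in the test file:

    # Compute only input_a's gradient, writing it into a preallocated tensor
    # (prealloc_a_grad is a hypothetical ttnn tensor of matching shape).
    grads = ttnn.mul_bw(
        grad_tensor,
        input_tensor_a,
        input_tensor_b,
        are_required_outputs=[True, False],
        input_a_grad=prealloc_a_grad,
        queue_id=0,
    )
    a_grad = grads[0]  # grads[1] was not requested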
