#13643: Extend binary-ng math support to match all primitive binary ops
yan-zaretskiy committed Dec 16, 2024
1 parent ba516bc commit 25c17eb
Showing 17 changed files with 766 additions and 194 deletions.
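
This commit adds the remaining primitive binary ops to the binary-ng path and registers them under ttnn.experimental. A minimal usage sketch, adapted from the test added in this commit — the device setup and the choice of op are illustrative, not part of the diff:

    import torch
    import ttnn

    # Device handling is illustrative; any open ttnn device works.
    device = ttnn.open_device(device_id=0)

    a_pt = torch.rand(torch.Size([1, 1, 31, 32])).bfloat16()
    b_pt = torch.rand(torch.Size([5, 3, 32, 32])).bfloat16()

    a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    # Every newly registered op follows the same call pattern as the tests below.
    out_tt = ttnn.experimental.logaddexp(a_tt, b_tt, queue_id=0)

    # Reference result via the golden function lookup used by the tests.
    golden_fn = ttnn.get_golden_function(ttnn.experimental.logaddexp)
    out_pt = golden_fn(a_pt, b_pt)

    out_back = ttnn.to_torch(out_tt)
    ttnn.close_device(device)
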
81 changes: 76 additions & 5 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -5,8 +5,10 @@
import torch
import pytest
import ttnn
import random
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc

from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
compare_pcc,
)


@pytest.mark.parametrize(
@@ -17,21 +19,90 @@
(torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
),
)
def test_binary_scalar_ops(input_shapes, device):
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.gte,
ttnn.experimental.gt,
ttnn.experimental.lte,
ttnn.experimental.lt,
ttnn.experimental.eq,
ttnn.experimental.ne,
ttnn.experimental.logical_and,
ttnn.experimental.logical_or,
ttnn.experimental.logical_xor,
ttnn.experimental.ldexp,
ttnn.experimental.logaddexp,
ttnn.experimental.logaddexp2,
ttnn.experimental.squared_difference,
ttnn.experimental.add,
ttnn.experimental.sub,
ttnn.experimental.mul,
ttnn.experimental.div,
ttnn.experimental.bias_gelu,
],
)
def test_binary_scalar_ops(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
a_pt = torch.rand(a_shape).bfloat16()
b_pt = torch.rand(b_shape).bfloat16()

a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
cq_id = 0
out_tt = ttnn.experimental.add(a_tt, b_tt, queue_id=cq_id)
out_pt = a_pt + b_pt
out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
golden_fn = ttnn.get_golden_function(ttnn_fn)
out_pt = golden_fn(a_pt, b_pt)

comp_pass = compare_pcc([out_tt], [out_pt])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 31, 32]), torch.Size([5, 3, 32, 32])),
(torch.Size([5, 2, 64, 1]), torch.Size([1, 3, 1, 128])),
(torch.Size([5, 1, 1, 64]), torch.Size([2, 3, 128, 1])),
),
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.gte,
ttnn.experimental.gt,
ttnn.experimental.lte,
ttnn.experimental.lt,
ttnn.experimental.eq,
ttnn.experimental.ne,
ttnn.experimental.logical_and,
ttnn.experimental.logical_or,
ttnn.experimental.logical_xor,
ttnn.experimental.ldexp,
ttnn.experimental.logaddexp,
ttnn.experimental.logaddexp2,
ttnn.experimental.squared_difference,
ttnn.experimental.add,
ttnn.experimental.sub,
ttnn.experimental.mul,
ttnn.experimental.div,
ttnn.experimental.bias_gelu,
],
)
def test_binary_scalar_ops_invalid_bcast(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
a_pt = torch.rand(a_shape).bfloat16()
b_pt = torch.rand(b_shape).bfloat16()

a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)

with pytest.raises(RuntimeError) as e:
cq_id = 0
_ = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
assert "Broadcasting rule violation" in str(e.value)


@pytest.mark.parametrize(
"shapes",
[
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -121,6 +121,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
17 changes: 17 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -52,5 +52,22 @@ Tensor BinaryNg<binary_op_type>::invoke(
}

template struct BinaryNg<BinaryOpType::ADD>;
template struct BinaryNg<BinaryOpType::SUB>;
template struct BinaryNg<BinaryOpType::MUL>;
template struct BinaryNg<BinaryOpType::DIV>;
template struct BinaryNg<BinaryOpType::GT>;
template struct BinaryNg<BinaryOpType::LT>;
template struct BinaryNg<BinaryOpType::LTE>;
template struct BinaryNg<BinaryOpType::GTE>;
template struct BinaryNg<BinaryOpType::EQ>;
template struct BinaryNg<BinaryOpType::NE>;
template struct BinaryNg<BinaryOpType::SQUARED_DIFFERENCE>;
template struct BinaryNg<BinaryOpType::BIAS_GELU>;
template struct BinaryNg<BinaryOpType::LOGICAL_AND>;
template struct BinaryNg<BinaryOpType::LOGICAL_OR>;
template struct BinaryNg<BinaryOpType::LOGICAL_XOR>;
template struct BinaryNg<BinaryOpType::LDEXP>;
template struct BinaryNg<BinaryOpType::LOGADDEXP>;
template struct BinaryNg<BinaryOpType::LOGADDEXP2>;

} // namespace ttnn::operations::binary_ng
68 changes: 68 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
@@ -49,4 +49,72 @@ namespace ttnn::experimental {
constexpr auto add = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::add",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::ADD>>();

constexpr auto sub = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::sub",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::SUB>>();

constexpr auto mul = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::mul",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::MUL>>();

constexpr auto div = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::div",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::DIV>>();

constexpr auto eq = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::eq",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::EQ>>();

constexpr auto ne = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::ne",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::NE>>();

constexpr auto gt = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::gt",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::GT>>();

constexpr auto gte = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::gte",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::GTE>>();

constexpr auto lt = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::lt",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LT>>();

constexpr auto lte = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::lte",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LTE>>();

constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::squared_difference",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::SQUARED_DIFFERENCE>>();

constexpr auto bias_gelu = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::bias_gelu",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::BIAS_GELU>>();

constexpr auto logical_and = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::logical_and",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LOGICAL_AND>>();

constexpr auto logical_or = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::logical_or",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LOGICAL_OR>>();

constexpr auto logical_xor = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::logical_xor",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LOGICAL_XOR>>();

constexpr auto ldexp = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::ldexp",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LDEXP>>();

constexpr auto logaddexp = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::logaddexp",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LOGADDEXP>>();

constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::logaddexp2",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LOGADDEXP2>>();
}
35 changes: 27 additions & 8 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
@@ -9,17 +9,16 @@

namespace ttnn::operations::binary_ng {
namespace detail {
void bind_binary_ng_operation(py::module& module) {
using OperationType = decltype(ttnn::experimental::add);

template <typename T>
void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring) {
bind_registered_operation(
module,
ttnn::experimental::add,
"Binary Add Ng Operation",
op,
docstring,

// tensor and scalar
ttnn::pybind_overload_t{
[](const OperationType& self,
[](const T& self,
const ttnn::Tensor& input_tensor_a,
const float scalar,
const std::optional<const DataType>& dtype,
@@ -38,7 +37,7 @@ void bind_binary_ng_operation(py::module& module) {

// tensor and tensor
ttnn::pybind_overload_t{
[](const OperationType& self,
[](const T& self,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<const DataType>& dtype,
@@ -57,5 +56,25 @@
}
} // namespace detail

void py_module(py::module& module) { detail::bind_binary_ng_operation(module); }
void py_module(py::module& module) {
detail::bind_binary_ng_operation(module, ttnn::experimental::add, "Binary Add Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::sub, "Binary Sub Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::mul, "Binary Mul Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::div, "Binary Div Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::gt, "Binary Greater Than Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::lt, "Binary Less Than Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::lte, "Binary Less Than or Equal To Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::gte, "Binary Greater Than or Equal To Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::eq, "Binary Equal Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::ne, "Binary Not Equal Operation");
detail::bind_binary_ng_operation(
module, ttnn::experimental::squared_difference, "Binary Squared Difference Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::bias_gelu, "Binary Bias GELU Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::logical_and, "Binary Logical And Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::logical_or, "Binary Logical Or Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::logical_xor, "Binary Logical Xor Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::ldexp, "Binary Ldexp Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp, "Binary Logaddexp Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp2, "Binary Logaddexp2 Operation");
}
} // namespace ttnn::operations::eltwise::binary_ng
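
Because each op is bound with both a tensor-and-scalar and a tensor-and-tensor overload, every registered name also accepts a plain float as its second operand. A hedged sketch — it assumes the scalar overload takes the same queue_id keyword the tests use for the tensor overload, and that a_tt and b_tt are tile-layout device tensors as in the test file above:

    # Tensor-and-scalar overload of one of the newly bound ops.
    out_tt = ttnn.experimental.add(a_tt, 2.0, queue_id=0)

    # Tensor-and-tensor overload, same call shape as the tests in this commit.
    out_tt = ttnn.experimental.squared_difference(a_tt, b_tt, queue_id=0)
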
@@ -5,12 +5,14 @@
#pragma once

#include "pybind11/pybind_fwd.hpp"
#include <string>

namespace py = pybind11;

namespace ttnn::operations::binary_ng {
namespace detail {
void bind_binary_ng_operation(py::module& module);
template <typename T>
void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring);
}

void py_module(py::module& module);
@@ -90,18 +90,18 @@ void BinaryNgDeviceOperation::validate_on_program_cache_hit(
const auto input_shape_b =
tensor_args.input_tensor_b.has_value() ? tensor_args.input_tensor_b->get_logical_shape() : ttnn::Shape{1, 1};

constexpr int max_rank = 4;
if (input_shape_a.rank() > 0 && input_shape_b.rank() > 0) {
for (int i = 1; i <= max_rank; i++) {
auto a_dim = i <= input_shape_a.rank() ? input_shape_a[-i] : 1;
auto b_dim = i <= input_shape_b.rank() ? input_shape_b[-i] : 1;
TT_FATAL(
a_dim == b_dim || a_dim == 1 || b_dim == 1,
"Broadcasting rule violation for rank {}, dim a: {}, dim b: {}",
i,
a_dim,
b_dim);
}
const int rank_a = input_shape_a.rank();
const int rank_b = input_shape_b.rank();
const int larger_rank = std::max(rank_a, rank_b);
for (int i = -1; i >= -larger_rank; --i) {
auto a_dim = (i >= -rank_a) ? input_shape_a[i] : 1;
auto b_dim = (i >= -rank_b) ? input_shape_b[i] : 1;
TT_FATAL(
a_dim == b_dim || a_dim == 1 || b_dim == 1,
"Broadcasting rule violation for rank {}, dim a: {}, dim b: {}",
i,
a_dim,
b_dim);
}
}
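
The rewritten check walks dimensions from the innermost outward over whichever shape has the larger rank, treating missing leading dimensions as 1 — i.e. NumPy-style broadcasting. A standalone Python sketch of the same rule (the helper name is illustrative; the example shapes come from the new tests, not from the ttnn API):

    def broadcast_compatible(shape_a, shape_b):
        # Compare dimensions right-to-left; a missing leading dim counts as 1.
        larger_rank = max(len(shape_a), len(shape_b))
        for i in range(1, larger_rank + 1):
            a_dim = shape_a[-i] if i <= len(shape_a) else 1
            b_dim = shape_b[-i] if i <= len(shape_b) else 1
            if not (a_dim == b_dim or a_dim == 1 or b_dim == 1):
                return False
        return True

    assert broadcast_compatible([5, 1, 1, 64], [1, 3, 128, 1])      # valid broadcast
    assert not broadcast_compatible([5, 1, 1, 64], [2, 3, 128, 1])  # rule violation, as in the new test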
