#16144: common utils and test scalar case
KalaivaniMCW committed Jan 10, 2025
1 parent 42f5619 commit 1fc8c3f
Showing 28 changed files with 1,199 additions and 259 deletions.
249 changes: 249 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -255,3 +255,252 @@ def test_01_volume_tensors(device, a, b, c_golden, memory_config):

assert c.tolist() == c_golden


@pytest.mark.parametrize(
"input_shapes",
(
# (torch.Size([5, 3, 32, 32]), torch.Size([5, 3, 32, 32])),
# (torch.Size([1, 3, 64, 64]), torch.Size([5, 3, 64, 64])), # batch bcast
# (torch.Size([2, 1, 1, 64]), torch.Size([2, 1, 32, 1])), # rowA colB bcast
# (torch.Size([2, 1, 1, 64]), torch.Size([2, 1, 128, 1])), # rowA colB bcast
(torch.Size([2, 1, 2, 2]), torch.Size([2, 1, 2, 2])),  # same shapes, no bcast
# (torch.Size([5, 3, 32, 64]), torch.Size([5, 3, 32, 64])),
# (torch.Size([5, 3, 64, 32]), torch.Size([5, 3, 64, 32])),
# (torch.Size([5,3,1,1]), torch.Size([5,3,1,1])), # (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
# (torch.Size([5, 3, 64, 32]), torch.Size([5, 3, 1, 32])), # (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
),
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.sub,
],
)
def test_binary_ng(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
# a_pt = torch.rand(a_shape).bfloat16()
# b_pt = torch.rand(b_shape).bfloat16()
a_pt = torch.ones(a_shape, dtype=torch.bfloat16) * 1
# b_pt = torch.ones(b_shape, dtype=torch.bfloat16) * 7
b_pt = 0.1111111

a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
# b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b_tt = 0.1111111
cq_id = 0
out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
golden_fn = ttnn.get_golden_function(ttnn_fn)
out_pt = golden_fn(a_pt, b_pt)
torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
print(ttnn.to_torch(out_tt))
print(out_pt)
print(a_pt - b_pt)
# comp_pass = compare_pcc([out_tt], [out_pt])
comp_pass = ttnn.pearson_correlation_coefficient(out_pt, out_tt)
assert comp_pass >= 0.99988


@pytest.mark.parametrize(
"input_shapes",
[
[[2, 1, 2, 2], [2, 1, 2, 2]],
[[2, 1, 1, 1], [2, 1, 2, 2]],
[[1, 1, 4, 1], [1, 1, 1, 4]],
[[1, 1, 4, 4], [1, 1, 1, 4]],
[[1, 1, 4, 4], [1, 1, 1, 1]],
# [[5, 1, 64, 1], [1, 3, 1, 128]],
],
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.sub,
],
)
def test_binary_ng_fp32(input_shapes, ttnn_fn, device):
[a_shape, b_shape] = input_shapes
x_torch = torch.ones(a_shape, dtype=torch.float32) * 2
y_torch = torch.ones(b_shape, dtype=torch.float32) * 0.00030171126
# y_torch = 0.00030171126
golden_fn = ttnn.get_golden_function(ttnn_fn)
z_torch = torch.sub(torch.square(x_torch), -y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
# y_tt = 0.00030171126
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
# ttnn.set_printoptions(profile="full")
# print("tt ", z_tt_sub)
tt_out = ttnn.to_torch(z_tt_sub)

torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
print("torch", z_torch)
print("tt ", tt_out)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status


@pytest.mark.parametrize(
"input_shapes",
[
[[2, 1, 2, 2], [2, 1, 2, 2]],
[[2, 1, 1, 1], [2, 1, 2, 2]],
[[1, 1, 4, 1], [1, 1, 1, 4]],
[[1, 1, 4, 4], [1, 1, 1, 4]],
[[1, 1, 4, 4], [1, 1, 1, 1]],
# [[5, 1, 64, 1], [1, 3, 1, 128]],
],
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.sub,
],
)
def test_binary_ng_fp32sc(input_shapes, ttnn_fn, device):
[a_shape, b_shape] = input_shapes
x_torch = torch.ones(a_shape, dtype=torch.float32) * 2
y_torch = 0.00030171126
golden_fn = ttnn.get_golden_function(ttnn_fn)
z_torch = torch.sub(torch.square(x_torch), -y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = 0.00030171126
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
# ttnn.set_printoptions(profile="full")
# print("tt ", z_tt_sub)
tt_out = ttnn.to_torch(z_tt_sub)

torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
print("torch", z_torch)
print("tt ", tt_out)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status


@pytest.mark.parametrize(
"input_shapes",
((torch.Size([2, 1, 2, 2]), torch.Size([2, 1, 2, 2])),),
# ((torch.Size([2, 1, 1, 2]), torch.Size([2, 1, 2, 2])),),
# ((torch.Size([2, 1, 1, 32]), torch.Size([2, 1, 32, 1])),),
# ((torch.Size([2, 1, 2, 2]), torch.Size([2, 1, 1, 1])),),
# (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.add,
],
)
def test_binary_ng_int32(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
x_torch = torch.ones(a_shape, dtype=torch.int32) * 2
y_torch = torch.ones(b_shape, dtype=torch.int32) * -10
# y_torch = -10
golden_fn = ttnn.get_golden_function(ttnn_fn)
z_torch = golden_fn(x_torch, y_torch)
# z_torch = torch.add(torch.square(x_torch), y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
# y_tt = -10
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_sub = ttnn_fn(
x_tt,
y_tt,
) # lhs_activations=[ttnn.UnaryOpType.SQUARE]
tt_out = ttnn.to_torch(z_tt_sub)

torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
print("torch", z_torch)
print("tt ", tt_out)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status


@pytest.mark.parametrize(
"input_shapes",
[
[[2, 1, 2, 2], [2, 1, 2, 2]],
# [[2, 1, 1, 1], [2, 1, 2, 2]],
# [[1, 1, 4, 1], [1, 1, 1, 4]],
# [[1, 1, 4, 4], [1, 1, 1, 4]],
# [[1, 1, 1, 1], [1, 1, 4, 4]],
# [[5, 1, 64, 1], [1, 3, 1, 128]],
],
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.sub,
],
)
def test_binary_ng_fp32_activ(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
x_torch = torch.ones(a_shape, dtype=torch.float32) * 2
# y_torch = torch.ones(b_shape, dtype=torch.float32) * 0.00030171126
y_torch = 0.00030171122
golden_fn = ttnn.get_golden_function(ttnn_fn)
# z_torch = golden_fn(x_torch, y_torch)
z_torch = torch.sub(torch.square(x_torch), -y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
# y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = 0.00030171122
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
# z_tt_sub = ttnn_fn(x_tt, y_tt,)
z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
# ttnn.set_printoptions(profile="full")
# print("tt ", z_tt_sub)
tt_out = ttnn.to_torch(z_tt_sub)

torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
print("torch", z_torch)
print("tt ", tt_out)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status


@pytest.mark.parametrize(
"input_shapes",
[
[[2, 1, 2, 2], [2, 1, 2, 2]],
# [[2, 1, 1, 1], [2, 1, 2, 2]],
# [[1, 1, 4, 1], [1, 1, 1, 4]],
# [[1, 1, 4, 4], [1, 1, 1, 4]],
# [[1, 1, 1, 1], [1, 1, 4, 4]],
# [[5, 1, 64, 1], [1, 3, 1, 128]],
],
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.sub,
],
)
def test_binary_ng_bf16_activ(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
x_torch = torch.ones(a_shape, dtype=torch.bfloat16) * 2
# y_torch = torch.ones(b_shape, dtype=torch.bfloat16) * 0.00030171126
y_torch = 0.00030171122
golden_fn = ttnn.get_golden_function(ttnn_fn)
# z_torch = golden_fn(x_torch, y_torch)
z_torch = torch.sub(torch.square(x_torch), -y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# y_tt = ttnn.from_torch(y_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = 0.00030171122
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# z_tt_sub = ttnn_fn(x_tt, y_tt,)
z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
# ttnn.set_printoptions(profile="full")
# print("tt ", z_tt_sub)
tt_out = ttnn.to_torch(z_tt_sub)

torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
print("torch", z_torch)
print("tt ", tt_out)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status
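
Note: the tensor-scalar case exercised above can be reduced to a short standalone check. The sketch below is illustrative only and not part of this commit; it assumes a device is available via ttnn.open_device and reuses the shape, scalar, and PCC threshold from test_binary_ng.

import torch
import ttnn

# Minimal sketch of the tensor-scalar path (assumes device 0 is available).
device = ttnn.open_device(device_id=0)

a_pt = torch.ones((2, 1, 2, 2), dtype=torch.bfloat16)
a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)

scalar = 0.1111111
# The second operand is a plain Python float; no second device tensor is created.
out_tt = ttnn.experimental.sub(a_tt, scalar, queue_id=0)
out_pt = ttnn.get_golden_function(ttnn.experimental.sub)(a_pt, scalar)

assert ttnn.pearson_correlation_coefficient(out_pt, ttnn.to_torch(out_tt)) >= 0.99988
ttnn.close_device(device)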
15 changes: 12 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -19,8 +19,13 @@ Tensor BinaryNg<binary_op_type>::invoke(
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
auto input_b = typecast_to(DataType::BFLOAT16, input_tensor_b);
Tensor input_a = input_tensor_a;
Tensor input_b = input_tensor_b;
if (input_tensor_a.get_dtype() == DataType::BFLOAT8_B || input_tensor_b.get_dtype() == DataType::BFLOAT8_B ||
input_tensor_a.get_dtype() == DataType::BFLOAT4_B || input_tensor_b.get_dtype() == DataType::BFLOAT4_B) {
input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
input_b = typecast_to(DataType::BFLOAT16, input_tensor_b);
}

return ttnn::prim::binary_ng(
queue_id,
@@ -68,7 +73,10 @@ Tensor BinaryNg<binary_op_type>::invoke(
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
Tensor input_a = input_tensor_a;
if (input_tensor_a.get_dtype() == DataType::BFLOAT8_B || input_tensor_a.get_dtype() == DataType::BFLOAT4_B) {
input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
}

return ttnn::prim::binary_ng(
queue_id,
@@ -109,6 +117,7 @@ template struct BinaryNg<BinaryOpType::ADD>;
template struct BinaryNg<BinaryOpType::SUB>;
template struct BinaryNg<BinaryOpType::MUL>;
template struct BinaryNg<BinaryOpType::DIV>;
template struct BinaryNg<BinaryOpType::RSUB>;
template struct BinaryNg<BinaryOpType::GT>;
template struct BinaryNg<BinaryOpType::LT>;
template struct BinaryNg<BinaryOpType::LTE>;
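
The dtype guard added in this file replaces an unconditional BFLOAT16 upcast: only BFLOAT8_B/BFLOAT4_B inputs are typecast, so float32 (and int32) operands reach the primitive in their original dtype. A hedged Python-side illustration of why that matters, mirroring the fp32 tests added in this commit (device setup is assumed and not part of the diff):

import torch
import ttnn

# Sketch only (assumes device 0): with the dtype guard, float32 operands are no longer
# squeezed through bfloat16, so results can match torch at float32 precision.
device = ttnn.open_device(device_id=0)

x = torch.full((2, 1, 2, 2), 2.0, dtype=torch.float32)
y = torch.full((2, 1, 2, 2), 0.00030171126, dtype=torch.float32)

x_tt = ttnn.from_torch(x, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

out = ttnn.to_torch(ttnn.experimental.sub(x_tt, y_tt))
# 0.00030171126 is not exactly representable in bfloat16; the tight tolerance below
# would fail if the inputs were still routed through a BFLOAT16 typecast.
assert torch.allclose(out, x - y, atol=1e-10, rtol=1e-5)

ttnn.close_device(device)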
4 changes: 4 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
@@ -72,6 +72,10 @@ constexpr auto sub = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::sub",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::SUB>>();

constexpr auto rsub = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::rsub",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::RSUB>>();

constexpr auto mul = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::mul",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::MUL>>();
@@ -91,6 +91,7 @@ void py_module(py::module& module) {
detail::bind_binary_ng_operation(module, ttnn::experimental::sub, "Binary Sub Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::mul, "Binary Mul Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::div, "Binary Div Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::rsub, "Binary Rsub Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::gt, "Binary Greater Than Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::lt, "Binary Less Than Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::lte, "Binary Less Than or Equal To Operation");
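
With the RSUB template instantiation, the registration in binary_ng.hpp, and the pybind entry above, ttnn.experimental.rsub should be callable from Python like the other binary_ng ops. A hedged usage sketch (not part of this commit), assuming rsub follows the usual reverse-subtraction convention of computing b - a:

import torch
import ttnn

# Illustrative only: exercises the newly bound ttnn.experimental.rsub (assumes device 0).
device = ttnn.open_device(device_id=0)

a = torch.full((1, 1, 32, 32), 3.0, dtype=torch.bfloat16)
b = torch.full((1, 1, 32, 32), 10.0, dtype=torch.bfloat16)

a_tt = ttnn.from_torch(a, device=device, layout=ttnn.TILE_LAYOUT)
b_tt = ttnn.from_torch(b, device=device, layout=ttnn.TILE_LAYOUT)

out = ttnn.to_torch(ttnn.experimental.rsub(a_tt, b_tt))
expected = b - a  # assumption: rsub computes operand_b - operand_a
assert torch.allclose(out.float(), expected.float())

ttnn.close_device(device)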