#15780: gcd, lcm in float32
KalaivaniMCW committed Dec 11, 2024
1 parent 9b74440 commit 22e3d31
Showing 5 changed files with 133 additions and 16 deletions.
69 changes: 59 additions & 10 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -960,20 +960,32 @@ def test_nei_ttnn(input_shapes, scalar, device):
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 64, 64])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@skip_for_grayskull("#ToDo: GS implementation needs to be done for remainder")
def test_binary_gcd_ttnn(input_shapes, device):
-    in_data1, input_tensor1 = data_gen_with_range_int(input_shapes, -1024, 1024, device)
-    in_data2, input_tensor2 = data_gen_with_range_int(input_shapes, -1024, 1024, device)
+    torch.manual_seed(213919)
+    in_data1 = torch.randint(-1000, 1000, input_shapes, dtype=torch.int32)
+    in_data2 = torch.randint(-1024, 1024, input_shapes, dtype=torch.int32)
+    input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor2 = ttnn.from_torch(in_data2, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

    output_tensor = ttnn.gcd(input_tensor1, input_tensor2)
    golden_function = ttnn.get_golden_function(ttnn.gcd)
    golden_tensor = golden_function(in_data1, in_data2)
    output_tensor = ttnn.to_torch(output_tensor)

-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    pcc = ttnn.pearson_correlation_coefficient(golden_tensor, output_tensor)
+    assert pcc >= 0.99
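A Pearson correlation of at least 0.99 is a fairly loose match criterion for integer-valued results. As a minimal sketch of what the comparison is assumed to compute (assuming ttnn.pearson_correlation_coefficient implements the standard Pearson formula over the flattened tensors; pearson_sketch is a hypothetical stand-in):

import torch

def pearson_sketch(x: torch.Tensor, y: torch.Tensor) -> float:
    # Dot product of the centered vectors divided by the product of their norms.
    x, y = x.flatten().float(), y.flatten().float()
    xc, yc = x - x.mean(), y - y.mean()
    return float((xc * yc).sum() / (xc.norm() * yc.norm()))

An exact-equality check such as torch.equal would be stricter; PCC tolerates small elementwise deviations from the device-side remainder chain.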


@pytest.mark.parametrize(
@@ -987,16 +999,53 @@ def test_binary_gcd_ttnn(input_shapes, device):
@skip_for_grayskull("#ToDo: GS implementation needs to be done for remainder")
def test_binary_lcm_ttnn(input_shapes, device):
    torch.manual_seed(213919)
-    in_data1 = torch.randint(-100, 100, input_shapes, dtype=torch.int32)
-    in_data2 = torch.randint(-80, 180, input_shapes, dtype=torch.int32)
-    input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
-    input_tensor2 = ttnn.from_torch(in_data2, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+    in_data1 = torch.randint(1, 1000, input_shapes, dtype=torch.int32)
+    in_data2 = torch.randint(1, 1024, input_shapes, dtype=torch.int32)
+    input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor2 = ttnn.from_torch(in_data2, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.lcm(input_tensor1, input_tensor2)
    golden_function = ttnn.get_golden_function(ttnn.lcm)
    golden_tensor = golden_function(in_data1, in_data2)
    output_tensor = ttnn.to_torch(output_tensor)
-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    pcc = ttnn.pearson_correlation_coefficient(golden_tensor, output_tensor)
+    assert pcc >= 0.99

@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@skip_for_grayskull("#ToDo: GS implementation needs to be done for remainder")
# When both inputs are 0, torch returns 0 while TT returns nan, so zeros are kept out of the inputs.
def test_binary_lcm_ttnn_neg(input_shapes, device):
    torch.manual_seed(213919)
    in_data1 = torch.randint(-1000, -1, input_shapes, dtype=torch.int32)
    in_data2 = torch.randint(-1024, -1, input_shapes, dtype=torch.int32)

    input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor2 = ttnn.from_torch(in_data2, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.lcm(input_tensor1, input_tensor2)
    golden_function = ttnn.get_golden_function(ttnn.lcm)
    golden_tensor = golden_function(in_data1, in_data2)
    output_tensor = ttnn.to_torch(output_tensor)
    pcc = ttnn.pearson_correlation_coefficient(golden_tensor, output_tensor)
    assert pcc >= 0.99
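The lcm golden function presumably reduces to torch.lcm on the integer inputs; a common implementation route (not shown in this diff) is the gcd identity lcm(a, b) = |a * b| / gcd(a, b). A torch-only sanity sketch of that identity:

import torch

a = torch.tensor([10, 12, -9], dtype=torch.int32)
b = torch.tensor([15, 18, 6], dtype=torch.int32)

# torch.gcd and torch.lcm return non-negative results for integer tensors
lcm_via_gcd = torch.abs(a * b) // torch.gcd(a, b)
assert torch.equal(lcm_via_gcd, torch.lcm(a, b))  # tensor([30, 36, 18], dtype=torch.int32)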


@pytest.mark.parametrize(
8 changes: 5 additions & 3 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py
@@ -120,7 +120,7 @@ def test_mul_fp32(device, ttnn_function):
assert status


-@pytest.mark.skip(reason="This test will be enabled after #15780 is resolved")
@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
@@ -131,8 +131,8 @@ def test_mul_fp32(device, ttnn_function):
# Torch: num/0 = inf and 0/0 = nan; TT: num/0 = inf and 0/0 = nan in the fp32 tile path.
# In the chained (mul * recip) div op, TT gives num/0 = inf but 0/0 = 0 (see the sketch after this test).
def test_div_fp32(device, ttnn_function):
-    x_torch = torch.tensor([[1.00030171126, -3, 16, -5, 14, -12, 0, 0, 1]], dtype=torch.float32)
-    y_torch = torch.tensor([[2, 3, -4, -5, 0, 0, 0, 1, 0]], dtype=torch.float32)
+    x_torch = torch.tensor([[1.00030171126, -3, 16, -5, 14, -12, 0, 0, 1, 15]], dtype=torch.float32)
+    y_torch = torch.tensor([[2, 3, -4, -5, 0, 0, 0, 1, 0, 10]], dtype=torch.float32)
golden_fn = ttnn.get_golden_function(ttnn_function)
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
@@ -141,6 +141,8 @@ def test_div_fp32(device, ttnn_function):
z_tt_div = ttnn.divide(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_div)

print("torch out in ttnn", ttnn.to_torch(z_tt))
print("tt out in torch", tt_out)
status = ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
assert status
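The zero-division cases called out in the comments above are easy to reproduce on the torch side; a minimal sketch of the IEEE-754 fp32 behavior the tile path is expected to match (the 0/0 = 0 result of the chained mul-by-reciprocal op is a device-side observation and is not reproducible in torch):

import torch

x = torch.tensor([14.0, 0.0, 1.0], dtype=torch.float32)
y = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)

# IEEE-754: nonzero/0 -> inf (signed like the dividend), 0/0 -> nan
print(x / y)  # tensor([inf, nan, inf])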

57 changes: 57 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

import pytest
from models.utility_functions import skip_for_grayskull


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.remainder,
],
)
def test_remainder_fp32(device, ttnn_function):
    x_torch = torch.tensor([[15]], dtype=torch.float32)
    y_torch = torch.tensor([[10]], dtype=torch.float32)
    golden_fn = ttnn.get_golden_function(ttnn_function)
    z_torch = golden_fn(x_torch, y_torch)
    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    z_tt_rem = ttnn.remainder(x_tt, y_tt)
    tt_out = ttnn.to_torch(z_tt_rem)

    status = ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
    assert status
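For context on what the golden function is assumed to return here: torch.remainder follows the sign of the divisor (Python-style modulo), while torch.fmod follows the sign of the dividend. A torch-only illustration with the same magnitudes as the test above:

import torch

x = torch.tensor([15.0, -3.0], dtype=torch.float32)
y = torch.tensor([10.0, 10.0], dtype=torch.float32)

print(torch.remainder(x, y))  # tensor([5., 7.])  - sign follows the divisor
print(torch.fmod(x, y))       # tensor([5., -3.]) - sign follows the dividend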


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.abs,
],
)
def test_abs_fp32(device, ttnn_function):
    x_torch = torch.tensor([[0, -1, 1, 1.99]], dtype=torch.float32)
    golden_fn = ttnn.get_golden_function(ttnn_function)
    z_torch = golden_fn(x_torch)
    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    z_tt_abs = ttnn.abs(x_tt)
    tt_out = ttnn.to_torch(z_tt_abs)

    status = ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
    assert status
@@ -223,8 +223,9 @@ Tensor ExecuteDiv::invoke(
"Incorrect rounding mode (expected None, 'trunc', or 'floor')");
output_tensor = output_tensor.value_or(ttnn::empty_like(input_a));
auto arch = input_a.device()->arch();
-    if (arch == tt::ARCH::WORMHOLE_B0) {
+    if (arch != tt::ARCH::GRAYSKULL) {
DataType input_dtype = input_a.get_dtype();

Tensor a = typecast(queue_id, input_a, DataType::FLOAT32);
Tensor b = typecast(queue_id, input_b, DataType::FLOAT32);
Tensor result = ttnn::divide(queue_id, a, b);
@@ -235,6 +236,10 @@ Tensor ExecuteDiv::invoke(
result = ttnn::floor(queue_id, result);
}

if (input_dtype == DataType::FLOAT32 && input_b.get_dtype() == DataType::FLOAT32) {
return result;
}

if (accurate_mode == false) { // If input_b is non-zero tensor
return typecast(queue_id, result, input_dtype, std::nullopt, output_tensor);
}
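Restating the dispatch above as a rough Python sketch (hypothetical helper, torch-only; the accurate_mode branch is omitted): compute in fp32, apply the optional rounding mode, and, per the new fast path in this commit, skip the cast back when both inputs were already fp32:

import torch

def div_invoke_sketch(a: torch.Tensor, b: torch.Tensor, round_mode=None) -> torch.Tensor:
    # Mirror of the non-Grayskull branch: always compute in fp32.
    result = a.float() / b.float()
    if round_mode == "trunc":
        result = result.trunc()
    elif round_mode == "floor":
        result = result.floor()
    if a.dtype == torch.float32 and b.dtype == torch.float32:
        return result  # new fast path: no typecast back to the input dtype
    return result.to(a.dtype)  # otherwise restore the input dtype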
@@ -503,15 +508,19 @@ Tensor ExecuteGCD::invoke(
Tensor min = ttnn::where(a_gt_b, input_b_abs, input_a_abs);
Tensor max = ttnn::where(a_gt_b, input_a_abs, input_b_abs);
a_gt_b.deallocate();

    // https://en.wikipedia.org/wiki/Lam%C3%A9%27s_theorem
    // Lame's theorem gives 186 as the theoretical maximum number of iterations for values in the floating point
    // range, but in practice, when evaluating the gcd of consecutive Fibonacci numbers coerced to floating point,
    // at most 14 iterations are needed because the remainder converges to 0 much more quickly. In addition, the
    // limited precision of the bfloat16 format restricts supported input to the range [-1024, 1024].

constexpr std::size_t max_iterations = 14;
for (std::size_t iteration = 0; iteration < max_iterations; ++iteration) {
Tensor isz = ttnn::eqz(min);
-        Tensor rem = ttnn::remainder(max, ttnn::where(isz, isz, min));
+        Tensor non_zero_min =
+            ttnn::where(isz, isz, min);  // where isz == 1 the value is 1, otherwise min; zeros in min become 1
+        Tensor rem = ttnn::remainder(max, non_zero_min);
max = ttnn::where(isz, max, min);
min = rem;
}
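A minimal scalar sketch of the branch-free loop above (plain Python ints standing in for tiles; gcd_fixed_iters is hypothetical), plus the Fibonacci worst case that motivates max_iterations = 14:

def gcd_fixed_iters(a: int, b: int, max_iterations: int = 14) -> int:
    lo, hi = sorted((abs(a), abs(b)))
    for _ in range(max_iterations):
        isz = lo == 0
        non_zero_lo = 1 if isz else lo  # mirrors ttnn::where(isz, isz, min)
        rem = hi % non_zero_lo          # remainder is always defined
        hi = hi if isz else lo          # mirrors ttnn::where(isz, max, min)
        lo = rem                        # once lo reaches 0, hi stays fixed at the gcd
    return hi

# Consecutive Fibonacci numbers are Euclid's worst case; 610 and 987 are the
# largest such pair within the bfloat16-limited range and need all 14 steps.
assert gcd_fixed_iters(987, 610) == 1
assert gcd_fixed_iters(1024, -768) == 256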
@@ -19,7 +19,7 @@ namespace utils {
case BinaryOpType::ADD: return ((a == DataType::FLOAT32 && b == DataType::FLOAT32) || (a == DataType::INT32 && b == DataType::INT32));
case BinaryOpType::SUB:
case BinaryOpType::MUL:
-        // case BinaryOpType::DIV_FAST: will be enabled after #15780 is resolved
+        case BinaryOpType::DIV_FAST:
case BinaryOpType::RSUB:
case BinaryOpType::LOGADDEXP:
case BinaryOpType::LOGADDEXP2:
