#10033: Add forward support for gcd and lcm (#10241)
* #10033: Add forward support for gcd

* #10033: Add forward support for lcm

* #10033: Update gcd and lcm
mouliraj-mcw authored Sep 26, 2024
1 parent d1a7449 commit 28e3825
Showing 7 changed files with 141 additions and 4 deletions.
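For orientation, a minimal host-side sketch of the two new ops, mirroring the unit tests added below. The open `device` handle is assumed here and is not part of this diff:

import torch
import ttnn

# Integer inputs are generated on host and moved to device as float32 tiles,
# following the data_gen_with_range_int helper added in this commit.
a = torch.randint(-1024, 1024, (1, 1, 32, 32), dtype=torch.int32)
b = torch.randint(-1024, 1024, (1, 1, 32, 32), dtype=torch.int32)
tt_a = ttnn.Tensor(a, ttnn.float32).to(ttnn.TILE_LAYOUT).to(device)
tt_b = ttnn.Tensor(b, ttnn.float32).to(ttnn.TILE_LAYOUT).to(device)

tt_gcd = ttnn.gcd(tt_a, tt_b)  # elementwise greatest common divisor
tt_lcm = ttnn.lcm(tt_a, tt_b)  # elementwise least common multiple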
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -302,6 +302,8 @@ Pointwise Binary
ttnn.floor_div
ttnn.remainder
ttnn.fmod
ttnn.gcd
ttnn.lcm
ttnn.logical_and_
ttnn.logical_or_
ttnn.logical_xor_
13 changes: 13 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/utility_funcs.py
@@ -29,6 +29,19 @@ def data_gen_with_range(input_shapes, low, high, device, required_grad=False, is
    return pt_tensor, tt_tensor


def data_gen_with_range_int(input_shapes, low, high, device, required_grad=False, is_row_major=False):
    assert high > low, "Incorrect range provided"
    torch.manual_seed(213919)
    pt_tensor = torch.randint(low, high, input_shapes, dtype=torch.int32, requires_grad=required_grad)

    # The integer data is carried to device as float32, which represents every
    # value in the supported [-1024, 1024] range exactly.
    if is_row_major:
        tt_tensor = ttnn.Tensor(pt_tensor, ttnn.float32).to(ttnn.ROW_MAJOR_LAYOUT).to(device)
    else:
        tt_tensor = ttnn.Tensor(pt_tensor, ttnn.float32).to(ttnn.TILE_LAYOUT).to(device)

    return pt_tensor, tt_tensor
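A quick usage note for the helper above, as exercised by the gcd/lcm tests in this commit (the pytest `device` fixture is assumed):

# pt is an int32 torch tensor; tt carries the same values on device in float32.
pt, tt = data_gen_with_range_int(torch.Size([1, 1, 32, 32]), -1024, 1024, device)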


def data_gen_with_val(input_shapes, device, required_grad=False, val=1, is_row_major=False):
    pt_tensor = (torch.ones(input_shapes, requires_grad=required_grad) * val).bfloat16()
    if is_row_major:
48 changes: 46 additions & 2 deletions tests/ttnn/unit_tests/operations/test_binary_composite.py
@@ -6,8 +6,13 @@
import pytest
import random
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc, compare_equal
from models.utility_functions import is_grayskull, skip_for_grayskull, skip_for_wormhole_b0
from tests.ttnn.unit_tests.operations.backward.utility_funcs import (
    data_gen_with_range,
    data_gen_with_range_int,
    compare_pcc,
    compare_equal,
)
from models.utility_functions import is_grayskull, skip_for_grayskull


@pytest.mark.parametrize(
@@ -845,4 +850,43 @@ def test_nei_ttnn(input_shapes, scalar, device):
    golden_tensor = golden_function(in_data, scalar)

    comp_pass = compare_equal([input_tensor], [golden_tensor])


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@skip_for_grayskull("#TODO: GS implementation of remainder, which gcd is built on, is needed")
def test_binary_gcd_ttnn(input_shapes, device):
    in_data1, input_tensor1 = data_gen_with_range_int(input_shapes, -1024, 1024, device)
    in_data2, input_tensor2 = data_gen_with_range_int(input_shapes, -1024, 1024, device)
    output_tensor = ttnn.gcd(input_tensor1, input_tensor2)
    golden_function = ttnn.get_golden_function(ttnn.gcd)
    golden_tensor = golden_function(in_data1, in_data2)

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@skip_for_grayskull("#TODO: GS implementation of remainder, which lcm is built on, is needed")
def test_binary_lcm_ttnn(input_shapes, device):
    in_data1, input_tensor1 = data_gen_with_range_int(input_shapes, -1024, 1024, device)
    in_data2, input_tensor2 = data_gen_with_range_int(input_shapes, -1024, 1024, device)
    output_tensor = ttnn.lcm(input_tensor1, input_tensor2)
    golden_function = ttnn.get_golden_function(ttnn.lcm)
    golden_tensor = golden_function(in_data1, in_data2)

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass
18 changes: 17 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -201,7 +201,17 @@ struct ExecuteBinaryRemainder
        const std::optional<MemoryConfig>& memory_config = std::nullopt);
};

} // namespace binary
#define DEFINE_BINARY_COMPOSITE(op_name)                                          \
    struct Execute##op_name {                                                     \
        static Tensor invoke(                                                     \
            const Tensor& input_tensor_a,                                         \
            const Tensor& input_tensor_b,                                         \
            const std::optional<MemoryConfig>& memory_config = std::nullopt);     \
    };
DEFINE_BINARY_COMPOSITE(LCM)
DEFINE_BINARY_COMPOSITE(GCD)

} // namespace binary
} // namespace operations

constexpr auto hypot = ttnn::register_operation_with_auto_launch_op<
@@ -264,5 +274,11 @@ constexpr auto outer = ttnn::register_operation_with_auto_launch_op<
constexpr auto polyval = ttnn::register_operation_with_auto_launch_op<
    "ttnn::polyval",
    operations::binary::ExecuteBinaryCompositeOpsPolyval<operations::binary::BinaryCompositeOpType::POLYVAL>>();
constexpr auto gcd = ttnn::register_operation_with_auto_launch_op<
    "ttnn::gcd",
    operations::binary::ExecuteGCD>();
constexpr auto lcm = ttnn::register_operation_with_auto_launch_op<
    "ttnn::lcm",
    operations::binary::ExecuteLCM>();

} // namespace ttnn
16 changes: 16 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -879,6 +879,22 @@ void py_module(py::module& module) {
        ttnn::logical_and_,
        R"doc(Compute inplace logical AND of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

    detail::bind_binary_composite(
        module,
        ttnn::gcd,
        R"doc(Computes the elementwise greatest common divisor of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns a tensor with the same layout as :attr:`input_tensor_a`.
        Supported input range: [-1024, 1024].)doc",
        R"doc(\mathrm{output\_tensor}_i = \text{gcd}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right))doc");

    detail::bind_binary_composite(
        module,
        ttnn::lcm,
        R"doc(Computes the elementwise least common multiple of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns a tensor with the same layout as :attr:`input_tensor_a`.
        Supported input range: [-1024, 1024].)doc",
        R"doc(\mathrm{output\_tensor}_i = \text{lcm}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right))doc");

    detail::bind_binary_composite_with_alpha(
        module,
        ttnn::addalpha,
@@ -272,7 +272,6 @@ Tensor ExecuteBinaryRemainder::invoke(const Tensor& input_a, const Tensor& input
    return typecast(result, input_dtype);
}


Tensor ExecuteBinaryRemainder::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
    return ttnn::unary_remainder(input, scalar);
}
@@ -400,4 +399,33 @@ Tensor _polyval(const Tensor& input_a, const std::vector<float>& coeffs, const s
    return final_tensor;
}

Tensor ExecuteGCD::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
    Tensor input_a_abs = ttnn::abs(input_a);
    Tensor input_b_abs = ttnn::abs(input_b);
    Tensor a_gt_b = ttnn::gt(input_a_abs, input_b_abs);
    Tensor min = ttnn::where(a_gt_b, input_b_abs, input_a_abs);
    Tensor max = ttnn::where(a_gt_b, input_a_abs, input_b_abs);
    a_gt_b.deallocate();
    // https://en.wikipedia.org/wiki/Lam%C3%A9%27s_theorem
    // Lame's theorem gives 186 as the theoretical maximum number of iterations for inputs in the
    // floating-point range, but in practice, evaluating gcd of consecutive Fibonacci numbers coerced
    // to floating point reaches at most 14 iterations, because the remainder converges to 0 much more
    // quickly. In addition, the limited precision of the bfloat16 format restricts supported inputs
    // to the range [-1024, 1024].
    constexpr std::size_t max_iterations = 14;
    for (std::size_t iteration = 0; iteration < max_iterations; ++iteration) {
        Tensor isz = ttnn::eqz(min);
        // Replace zero denominators with 1 (the value of isz there) so the remainder stays well
        // defined; finished lanes then hold max in place.
        Tensor rem = ttnn::remainder(max, ttnn::where(isz, isz, min));
        max = ttnn::where(isz, max, min);
        min = rem;
    }
    return max;
}
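As a cross-check of the loop above, a hedged torch-only sketch of the same fixed-iteration Euclidean algorithm (names here are illustrative, not part of the commit). Every element runs all 14 iterations; lanes whose remainder has already reached zero are held in place by the where-selects, exactly as in the tensor version:

import torch

def gcd_fixed_iterations(a: torch.Tensor, b: torch.Tensor, max_iterations: int = 14) -> torch.Tensor:
    lo = torch.minimum(a.abs(), b.abs()).float()
    hi = torch.maximum(a.abs(), b.abs()).float()
    for _ in range(max_iterations):
        is_zero = lo == 0
        # Substitute 1 where lo is zero so the remainder stays well defined;
        # finished lanes keep hi (the gcd) and lo stays 0.
        rem = torch.remainder(hi, torch.where(is_zero, torch.ones_like(lo), lo))
        hi = torch.where(is_zero, hi, lo)
        lo = rem
    return hi

# gcd_fixed_iterations(torch.tensor([987, 12]), torch.tensor([610, 18])) -> tensor([1., 6.])
# 987 and 610 are consecutive Fibonacci numbers: the iteration-count worst case in the supported range.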

Tensor ExecuteLCM::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
    Tensor val = ttnn::multiply(input_a, input_b, std::nullopt, output_mem_config);
    Tensor tmp_result = ttnn::gcd(input_a, input_b);
    Tensor result = ttnn::div(val, tmp_result, false, "None", output_mem_config);
    return ttnn::abs(result);
}

} // namespace ttnn::operations::binary
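The LCM composite relies on the identity lcm(a, b) = |a * b / gcd(a, b)|. A torch-only reference sketch for comparison (integer inputs assumed; like the composite, it does not special-case gcd == 0, i.e. both inputs zero):

import torch

def lcm_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same steps as ExecuteLCM::invoke: multiply, divide by the gcd, take the absolute value.
    return torch.abs(a * b / torch.gcd(a, b))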
18 changes: 18 additions & 0 deletions ttnn/ttnn/operations/binary.py
@@ -437,4 +437,22 @@ def _golden_function_ne_(input_tensor_a, input_tensor_b, *args, **kwargs):
ttnn.attach_golden_function(ttnn.ne_, golden_function=_golden_function_ne_)


def _golden_function_gcd(input_tensor_a, input_tensor_b, *args, **kwargs):
    import torch

    return torch.gcd(input_tensor_a, input_tensor_b)


ttnn.attach_golden_function(ttnn.gcd, golden_function=_golden_function_gcd)


def _golden_function_lcm(input_tensor_a, input_tensor_b, *args, **kwargs):
    import torch

    return torch.lcm(input_tensor_a, input_tensor_b)


ttnn.attach_golden_function(ttnn.lcm, golden_function=_golden_function_lcm)
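A quick host-side sanity check of the attached golden functions (pure torch; no device needed):

import torch
import ttnn

golden_gcd = ttnn.get_golden_function(ttnn.gcd)
golden_lcm = ttnn.get_golden_function(ttnn.lcm)
assert torch.equal(golden_gcd(torch.tensor([12, 18]), torch.tensor([8, 24])), torch.tensor([4, 6]))
assert torch.equal(golden_lcm(torch.tensor([4, 6]), torch.tensor([6, 10])), torch.tensor([12, 30]))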


__all__ = []
