From 98a5a851bf7f92a51f1276c812f156ab7c03b570 Mon Sep 17 00:00:00 2001
From: VirdhatchaniKN
Date: Wed, 10 Apr 2024 08:02:04 +0000
Subject: [PATCH] #7280: Add support for binary division

---
 docs/source/ttnn/ttnn/api.rst                 |  1 +
 docs/source/ttnn/ttnn/ttnn/binary_div.rst     |  6 ++
 .../python_api_testing/sweep_tests/op_map.py  |  4 +
 .../pytests/tt_dnn/test_composite.py          |  2 +
 .../sweep_tests/pytorch_ops.py                |  5 ++
 .../sweep_tests/tt_lib_ops.py                 |  1 +
 tests/ttnn/sweep_tests/sweeps/binary_div.py   | 81 +++++++++++++++++++
 .../unit_tests/operations/test_math_binary.py |  6 ++
 .../op_library/composite/composite_ops.cpp    | 20 +++++
 .../op_library/composite/composite_ops.hpp    |  5 ++
 .../tt_lib_bindings_tensor_composite_ops.cpp  | 14 ++++
 ttnn/ttnn/__init__.py                         |  1 +
 ttnn/ttnn/operations/math.py                  |  2 +
 13 files changed, 148 insertions(+)
 create mode 100644 docs/source/ttnn/ttnn/ttnn/binary_div.rst
 create mode 100644 tests/ttnn/sweep_tests/sweeps/binary_div.py

diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index f01e1dccf2ac..8fff1699d665 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -189,6 +189,7 @@ Pointwise Binary
    ttnn/nextafter
    ttnn/maximum
    ttnn/minimum
+   ttnn/binary_div
 
 Pointwise Ternary
 =================
diff --git a/docs/source/ttnn/ttnn/ttnn/binary_div.rst b/docs/source/ttnn/ttnn/ttnn/binary_div.rst
new file mode 100644
index 000000000000..cb3a6d88f299
--- /dev/null
+++ b/docs/source/ttnn/ttnn/ttnn/binary_div.rst
@@ -0,0 +1,6 @@
+.. _ttnn.binary_div:
+
+ttnn.binary_div
+###############
+
+.. autofunction:: ttnn.binary_div
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py
index 29c474531d20..a74b25453e7a 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -298,6 +298,10 @@
         "tt_lib_op": tt_lib_ops.eltwise_elu,
         "pytorch_op": pytorch_ops.elu,
     },
+    "eltwise-binary_div": {
+        "tt_lib_op": tt_lib_ops.eltwise_binary_div,
+        "pytorch_op": pytorch_ops.binary_div,
+    },
     "eltwise-square": {
         "tt_lib_op": tt_lib_ops.eltwise_square,
         "pytorch_op": pytorch_ops.square,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py
index b6fe77e640d3..202442b841e1 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py
@@ -112,6 +112,7 @@ def custom_compare(*args, **kwargs):
         "polygamma",
         "nextafter",
         "scatter",
+        "binary_div",
     ),
     shapes,
 )
@@ -190,6 +191,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
         "isclose",
         "assign_binary",
         "nextafter",
+        "binary_div",
     ]:
         num_inputs = 2
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
index 4234dfa799cd..116ee1b5daa9 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -680,6 +680,11 @@ def silu(x, *args, **kwargs):
     return torch.nn.functional.silu(x)
 
 
+def binary_div(x, y, *args, **kwargs):
+    result = torch.div(x, y)
+    return result
+
+
 def div_unary(x, *args, scalar, **kwargs):
     result = torch.div(x, scalar)
     return result
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
index 26b7b26983cd..2f8be2669e58 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -2231,6 +2231,7 @@ def binary_op(
 eltwise_mul = make_binary_op(ttl.tensor.mul)
 eltwise_bias_gelu = make_binary_op(ttl.tensor.bias_gelu)
 eltwise_squared_difference = make_binary_op(ttl.tensor.squared_difference)
+eltwise_binary_div = make_binary_op(ttl.tensor.binary_div)
 eltwise_hypot = make_binary_op(ttl.tensor.hypot)
 eltwise_scatter = make_binary_op(ttl.tensor.scatter)
 eltwise_atan2 = make_binary_op(ttl.tensor.atan2)
diff --git a/tests/ttnn/sweep_tests/sweeps/binary_div.py b/tests/ttnn/sweep_tests/sweeps/binary_div.py
new file mode 100644
index 000000000000..bb6952773d3b
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/binary_div.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "height": [32, 384, 1024],
+    "width": [32, 1024, 4096],
+    "input_a_dtype": [ttnn.bfloat16],
+    "input_b_dtype": [ttnn.bfloat16],
+    "input_a_layout": [ttnn.TILE_LAYOUT],
+    "input_b_layout": [ttnn.TILE_LAYOUT],
+    "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_) -> Tuple[bool, Optional[str]]:
+    return False, None
+
+
+def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
+    return False, None
+
+
+def torch_binary_div(x, y, *args, **kwargs):
+    return torch.div(x, y)
+
+
+def run(
+    batch_sizes,
+    height,
+    width,
+    input_a_dtype,
+    input_b_dtype,
+    input_a_layout,
+    input_b_layout,
+    input_b_memory_config,
+    input_a_memory_config,
+    output_memory_config,
+    *,
+    device,
+) -> Tuple[bool, Optional[str]]:
+    input_shape = (*batch_sizes, height, width)
+
+    low = -100
+    high = 100
+
+    torch_input_tensor_a = torch_random(input_shape, low, high, dtype=torch.bfloat16)
+    torch_input_tensor_b = torch_random(input_shape, low, high, dtype=torch.bfloat16)
+    torch_output_tensor = torch_binary_div(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a,
+        dtype=input_a_dtype,
+        device=device,
+        layout=input_a_layout,
+        memory_config=input_a_memory_config,
+    )
+    input_tensor_b = ttnn.from_torch(
+        torch_input_tensor_b,
+        dtype=input_b_dtype,
+        device=device,
+        layout=input_b_layout,
+        memory_config=input_b_memory_config,
+    )
+    output_tensor = ttnn.binary_div(input_tensor_a, input_tensor_b, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor)
diff --git a/tests/ttnn/unit_tests/operations/test_math_binary.py b/tests/ttnn/unit_tests/operations/test_math_binary.py
index b62016586015..41e2f0b45650 100644
--- a/tests/ttnn/unit_tests/operations/test_math_binary.py
+++ b/tests/ttnn/unit_tests/operations/test_math_binary.py
@@ -96,6 +96,12 @@ def test_squared_difference(device, h, w):
     run_math_binary_test_range(device, h, w, ttnn.squared_difference, torch_squared_difference, -100, 100)
 
 
+@pytest.mark.parametrize("h", [64])
+@pytest.mark.parametrize("w", [128])
+def test_binary_div(device, h, w):
+    run_math_binary_test_range(device, h, w, ttnn.binary_div, torch.div, -100, 100)
+
+
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 def test_hypot(device, h, w):
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
index c87a6b1f2e62..cd1f44a23081 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -834,6 +834,26 @@ Tensor addcdiv(
     return operation::decorate_as_composite(__func__, _addcdiv)(input_a, input_b, input_c, value, output_mem_config);
 }
 
+Tensor _binary_div(
+    const Tensor& input_a,
+    const Tensor& input_b,
+    const MemoryConfig& output_mem_config) {
+    Tensor result = mul(input_a, recip(input_b, output_mem_config), std::nullopt, output_mem_config);
+    Tensor t_inf = full_like(input_a, std::numeric_limits<float>::infinity(), output_mem_config);
+    Tensor t_nan = full_like(input_a, std::nanf(""), output_mem_config);
+    return where(eqz(input_b, output_mem_config),
+                 where(eqz(input_a, output_mem_config),
+                       t_nan,
+                       mul(t_inf, sign(input_a, output_mem_config), std::nullopt, output_mem_config), output_mem_config),
+                 result, output_mem_config);
+}
+Tensor binary_div(
+    const Tensor& input_a,
+    const Tensor& input_b,
+    const MemoryConfig& output_mem_config) {
+    return operation::decorate_as_composite(__func__, _binary_div)(input_a, input_b, output_mem_config);
+}
+
 // logit(input, eps) = log(input / (1 - input))
 Tensor _logit(const Tensor& input_a, float eps, const MemoryConfig& output_mem_config) {
     Tensor t_eps = full_like(input_a, eps, output_mem_config);
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
index ffd734044ccf..3282a6ec7ca2 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -160,6 +160,11 @@ Tensor addcdiv(
     float value,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
+Tensor binary_div(
+    const Tensor& input_a,
+    const Tensor& input_b,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
 // xlogy(x,y) = x*log(y)
 Tensor xlogy(
     const Tensor& input_a,
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index 78d3551bd3ab..958287521437 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -775,7 +775,21 @@ namespace tt::tt_metal::detail{
             "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
     )doc");
 
+    m_tensor.def("binary_div", &binary_div,
+        py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+        Performs the element-wise division of ``input_a`` by ``input_b``.
+
+        Input tensors must have BFLOAT16 data type.
+
+        Output tensor will have BFLOAT16 data type.
+
+        .. csv-table::
+            :header: "Argument", "Description", "Data type", "Valid range", "Required"
+
+            "input_a", "Numerator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "input_b", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+    )doc");
+
     m_tensor.def("mac", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&>(&mac),
         py::arg("input").noconvert(), py::arg("tensor1").noconvert(), py::arg("tensor2").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 9b0d94699665..220fa4c4c4b8 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -385,6 +385,7 @@ def manage_config_attribute(name, value):
     square,
     tril,
     triu,
+    binary_div,
 )
 
 from ttnn.operations.normalization import (
diff --git a/ttnn/ttnn/operations/math.py b/ttnn/ttnn/operations/math.py
index 1fb546237275..c6003745ebf7 100644
--- a/ttnn/ttnn/operations/math.py
+++ b/ttnn/ttnn/operations/math.py
@@ -159,6 +159,7 @@ def _golden_function_binary(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Te
         "atan2": torch.atan2,
         "hypot": torch.hypot,
         "squared_difference": torch_squared_difference,
+        "binary_div": torch.div,
     }
     torch_function = name_to_golden_function_function[name]
     return torch_function(input_tensor_a, input_tensor_b)
@@ -237,6 +238,7 @@ def math_binary_function(
     ("atan2", ttl.tensor.atan2, "atan2"),
     ("hypot", ttl.tensor.hypot, "hypotenuse"),
     ("squared_difference", ttl.tensor.squared_difference, "squared_difference (input_a - input_b)^2"),
+    ("binary_div", ttl.tensor.binary_div, "division (input_a / input_b)"),
 ]
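
A quick way to exercise the new op end to end when reviewing locally. This is a sketch, not part of the patch: the ttnn.open_device/ttnn.close_device setup and the loose tolerance check are illustrative assumptions, while the ttnn.binary_div call, the TILE_LAYOUT bfloat16 inputs, the [-100, 100] value range, and torch.div as the golden function all come from the tests above.

    # Sketch: run ttnn.binary_div against its torch.div golden on one tile.
    # Assumes a Tenstorrent device at id 0 and the ttnn device-management API.
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Tile-sized (32x32) bfloat16 inputs in [-100, 100], matching the sweep's range.
    torch_a = torch.rand((32, 32), dtype=torch.bfloat16) * 200 - 100
    torch_b = torch.rand((32, 32), dtype=torch.bfloat16) * 200 - 100

    a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)

    # Element-wise a / b. The composite mirrors torch.div at zero denominators:
    # 0/0 -> nan, x/0 -> +/-inf carrying the sign of x (see _binary_div above).
    output = ttnn.to_torch(ttnn.binary_div(a, b))
    golden = torch.div(torch_a, torch_b)

    # The op is computed as a * recip(b) in bfloat16, so compare loosely.
    print(torch.allclose(output, golden, rtol=1e-1, atol=1e-1))

    ttnn.close_device(device)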