#7280: Add support for binary division
VirdhatchaniKN committed Apr 10, 2024
1 parent 702adca commit 98a5a85
Showing 13 changed files with 148 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -189,6 +189,7 @@ Pointwise Binary
ttnn/nextafter
ttnn/maximum
ttnn/minimum
ttnn/binary_div

Pointwise Ternary
=================
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/binary_div.rst
@@ -0,0 +1,6 @@
.. _ttnn.binary_div:

ttnn.binary_div
###############

.. autofunction:: ttnn.binary_div
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -298,6 +298,10 @@
"tt_lib_op": tt_lib_ops.eltwise_elu,
"pytorch_op": pytorch_ops.elu,
},
"eltwise-binary_div": {
"tt_lib_op": tt_lib_ops.eltwise_binary_div,
"pytorch_op": pytorch_ops.binary_div,
},
"eltwise-square": {
"tt_lib_op": tt_lib_ops.eltwise_square,
"pytorch_op": pytorch_ops.square,
@@ -112,6 +112,7 @@ def custom_compare(*args, **kwargs):
"polygamma",
"nextafter",
"scatter",
"binary_div",
),
shapes,
)
@@ -190,6 +191,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
"isclose",
"assign_binary",
"nextafter",
"binary_div",
]:
num_inputs = 2

5 changes: 5 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -680,6 +680,11 @@ def silu(x, *args, **kwargs):
    return torch.nn.functional.silu(x)


def binary_div(x, y, *args, **kwargs):
    result = torch.div(x, y)
    return result


def div_unary(x, *args, scalar, **kwargs):
    result = torch.div(x, scalar)
    return result
@@ -2231,6 +2231,7 @@ def binary_op(
eltwise_mul = make_binary_op(ttl.tensor.mul)
eltwise_bias_gelu = make_binary_op(ttl.tensor.bias_gelu)
eltwise_squared_difference = make_binary_op(ttl.tensor.squared_difference)
eltwise_binary_div = make_binary_op(ttl.tensor.binary_div)
eltwise_hypot = make_binary_op(ttl.tensor.hypot)
eltwise_scatter = make_binary_op(ttl.tensor.scatter)
eltwise_atan2 = make_binary_op(ttl.tensor.atan2)
81 changes: 81 additions & 0 deletions tests/ttnn/sweep_tests/sweeps/binary_div.py
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc
from models.utility_functions import torch_random


parameters = {
"batch_sizes": [(1,)],
"height": [32, 384, 1024],
"width": [32, 1024, 4096],
"input_a_dtype": [ttnn.bfloat16],
"input_b_dtype": [ttnn.bfloat16],
"input_a_layout": [ttnn.TILE_LAYOUT],
"input_b_layout": [ttnn.TILE_LAYOUT],
"input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
    return False, None


def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
    return False, None


def torch_binary_div(x, y, *args, **kwargs):
    return torch.div(x, y)


def run(
    batch_sizes,
    height,
    width,
    input_a_dtype,
    input_b_dtype,
    input_a_layout,
    input_b_layout,
    input_b_memory_config,
    input_a_memory_config,
    output_memory_config,
    *,
    device,
) -> Tuple[bool, Optional[str]]:
    input_shape = (*batch_sizes, height, width)

    low = -100
    high = 100

    torch_input_tensor_a = torch_random(input_shape, low, high, dtype=torch.bfloat16)
    torch_input_tensor_b = torch_random(input_shape, low, high, dtype=torch.bfloat16)
    torch_output_tensor = torch_binary_div(torch_input_tensor_a, torch_input_tensor_b)

    input_tensor_a = ttnn.from_torch(
        torch_input_tensor_a,
        dtype=input_a_dtype,
        device=device,
        layout=input_a_layout,
        memory_config=input_a_memory_config,
    )
    input_tensor_b = ttnn.from_torch(
        torch_input_tensor_b,
        dtype=input_b_dtype,
        device=device,
        layout=input_b_layout,
        memory_config=input_b_memory_config,
    )
    output_tensor = ttnn.binary_div(input_tensor_a, input_tensor_b, memory_config=output_memory_config)
    output_tensor = ttnn.to_torch(output_tensor)

    return check_with_pcc(torch_output_tensor, output_tensor)
6 changes: 6 additions & 0 deletions tests/ttnn/unit_tests/operations/test_math_binary.py
@@ -96,6 +96,12 @@ def test_squared_difference(device, h, w):
    run_math_binary_test_range(device, h, w, ttnn.squared_difference, torch_squared_difference, -100, 100)


@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_binary_div(device, h, w):
    run_math_binary_test_range(device, h, w, ttnn.binary_div, torch.div, -100, 100)
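The range helper used above is defined elsewhere in this test module; as a rough standalone sketch of what the new test exercises (the device open/close calls and the tensor names here are assumptions, not part of this diff):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Random bfloat16 inputs in the same [-100, 100] range the test uses.
torch_a = torch.rand((64, 128), dtype=torch.bfloat16) * 200 - 100
torch_b = torch.rand((64, 128), dtype=torch.bfloat16) * 200 - 100
expected = torch.div(torch_a, torch_b)

a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)
actual = ttnn.to_torch(ttnn.binary_div(a, b))  # compare against `expected`, e.g. by PCC

ttnn.close_device(device)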


@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_hypot(device, h, w):
20 changes: 20 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -834,6 +834,26 @@ Tensor addcdiv(
    return operation::decorate_as_composite(__func__, _addcdiv)(input_a, input_b, input_c, value, output_mem_config);
}

Tensor _binary_div(
    const Tensor& input_a,
    const Tensor& input_b,
    const MemoryConfig& output_mem_config) {
    Tensor result = mul(input_a, recip(input_b, output_mem_config), std::nullopt, output_mem_config);
    Tensor t_inf = full_like(input_a, std::numeric_limits<float>::infinity(), output_mem_config);
    Tensor t_nan = full_like(input_a, std::nanf(""), output_mem_config);
    return where(eqz(input_b, output_mem_config),
        where(eqz(input_a, output_mem_config),
            t_nan,
            mul(t_inf, sign(input_a, output_mem_config), std::nullopt, output_mem_config), output_mem_config),
        result, output_mem_config);
}
Tensor binary_div(
    const Tensor& input_a,
    const Tensor& input_b,
    const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _binary_div)(input_a, input_b, output_mem_config);
}
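
A minimal torch sketch of the same semantics, for reference: the quotient is formed as a * recip(b), then zero-denominator elements are patched so that 0/0 yields NaN and x/0 yields ±inf carrying the numerator's sign. The helper name is illustrative only, not part of the commit.

import torch

def binary_div_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Element-wise quotient via multiplication by the reciprocal.
    result = a * torch.reciprocal(b)
    # Patch zero-denominator lanes to mirror the nested where/eqz above.
    signed_inf = torch.sign(a) * float("inf")
    nan_fill = torch.full_like(a, float("nan"))
    return torch.where(b == 0, torch.where(a == 0, nan_fill, signed_inf), result)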

// logit(input, eps)=log(input / 1 - input)
Tensor _logit(const Tensor& input_a, float eps, const MemoryConfig& output_mem_config) {
    Tensor t_eps = full_like(input_a, eps, output_mem_config);
5 changes: 5 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -160,6 +160,11 @@ Tensor addcdiv(
    float value,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

Tensor binary_div(
    const Tensor& input_a,
    const Tensor& input_b,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// xlogy(x,y))=x*log(y)
Tensor xlogy(
    const Tensor& input_a,
14 changes: 14 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -775,7 +775,21 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("binary_div", &binary_div,
py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs the element-wise division of ``input_a`` by ``input_b``.
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input_a", "Numerator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input_b", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("mac", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&>(&mac),
py::arg("input").noconvert(), py::arg("tensor1").noconvert(), py::arg("tensor2").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
@@ -385,6 +385,7 @@ def manage_config_attribute(name, value):
    square,
    tril,
    triu,
    binary_div,
)

from ttnn.operations.normalization import (
2 changes: 2 additions & 0 deletions ttnn/ttnn/operations/math.py
@@ -159,6 +159,7 @@ def _golden_function_binary(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Te
"atan2": torch.atan2,
"hypot": torch.hypot,
"squared_difference": torch_squared_difference,
"binary_div": torch.div,
}
torch_function = name_to_golden_function_function[name]
return torch_function(input_tensor_a, input_tensor_b)
@@ -237,6 +238,7 @@ def math_binary_function(
("atan2", ttl.tensor.atan2, "atan2"),
("hypot", ttl.tensor.hypot, "hypotenuse"),
("squared_difference", ttl.tensor.squared_difference, "squared_difference (input_a - input_b)^2"),
("binary_div", ttl.tensor.binary_div, "division (input_a / input_b)"),
]


