#8681: Add rfloor_div op
mcw-anasuya authored and mouliraj-mcw committed Jul 12, 2024
1 parent c1c52e3 commit dcdaad3
Showing 8 changed files with 122 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -466,6 +466,8 @@ Tensor elementwise operations

.. autofunction:: tt_lib.tensor.floor_div

.. autofunction:: tt_lib.tensor.rfloor_div

Tensor relational operations
============================
.. autofunction:: tt_lib.tensor.gtz
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -648,6 +648,10 @@
"tt_op": tt_lib_ops.eltwise_unary_floor_div,
"pytorch_op": pytorch_ops.unary_floor_div,
},
"eltwise-_rfloor_div": {
"tt_op": tt_lib_ops.eltwise_rfloor_div,
"pytorch_op": pytorch_ops.rfloor_div,
},
"eltwise-round": {
"tt_op": tt_lib_ops.eltwise_round,
"pytorch_op": pytorch_ops.round,
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import random
from functools import partial
import tt_lib as ttl
from tests.tt_eager.python_api_testing.sweep_tests import (
    comparison_funcs,
    generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
    run_single_pytorch_test,
)
from models.utility_functions import skip_for_grayskull

mem_configs = [
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
    "input_shapes",
    [
        [[1, 1, 32, 32], [1, 1, 32, 32]],
        [[1, 1, 320, 384], [1, 1, 320, 384]],
        [[1, 3, 320, 384], [1, 3, 320, 384]],
    ],
)
@pytest.mark.parametrize(
    "value",
    [-5.1, 0.0, 10.9],
)
@pytest.mark.parametrize(
    "dst_mem_config",
    mem_configs,
)
@skip_for_grayskull("#ToDo: GS implementation needs to be done for floor")
class TestRfloor_div:
    def test_run_rfloor_div(
        self,
        input_shapes,
        value,
        dst_mem_config,
        device,
    ):
        datagen_func = [
            generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16)
        ]
        test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
        test_args.update({"value": value})
        test_args.update({"output_mem_config": dst_mem_config})
        comparison_func = comparison_funcs.comp_pcc
        run_single_pytorch_test(
            "eltwise-rfloor_div",
            input_shapes,
            datagen_func,
            comparison_func,
            device,
            test_args,
        )
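For orientation, a minimal stand-alone sketch of the reference path this sweep exercises is shown below. It only mirrors the golden computation (the same ±100 bfloat16 input range and the scalar-over-tensor floor division); the actual test routes through `run_single_pytorch_test`, which builds the device tensor, invokes the device op, and compares against this reference with a PCC check.

```python
# Minimal sketch of the golden computation behind this sweep (torch only;
# the device path and the PCC comparison live in the sweep harness).
import torch

value = -5.1
x = torch.empty(1, 1, 32, 32).uniform_(-100, 100).to(torch.bfloat16)

# Scalar-over-tensor floor division, matching the pytorch_ops.rfloor_div
# reference added in this commit. The scalar is wrapped in a 0-dim tensor
# here purely so the snippet also runs on older torch builds.
golden = torch.floor_divide(torch.tensor(value), x)
print(golden.shape, golden.dtype)
```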
6 changes: 6 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -671,6 +671,12 @@ def unary_floor_div(x, *args, **kwargs):
    return result


def rfloor_div(x, *args, **kwargs):
    value = kwargs.pop("value")
    result = torch.floor_divide(value, x)
    return result
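A quick worked example of the reference semantics, assuming torch >= 1.13 (where `floor_divide` rounds toward negative infinity rather than truncating); the scalar is wrapped in a 0-dim tensor only to keep the snippet self-contained and portable:

```python
import torch

x = torch.tensor([2.0, -3.0, 0.5])
value = 7.0

# rfloor_div(x, value=7.0) above boils down to this elementwise computation:
out = torch.floor_divide(torch.tensor(value), x)
print(out)  # tensor([ 3., -3., 14.])
```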


def round(x, *args, **kwargs):
    decimals = kwargs.pop("decimals")
    result = torch.round(x, decimals=decimals)
18 changes: 18 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -1114,6 +1114,24 @@ def eltwise_unary_floor_div(
    return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_rfloor_div(
    x,
    *args,
    value,
    device,
    dtype,
    layout,
    input_mem_config,
    output_mem_config,
    **kwargs,
):
    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    t1 = ttl.tensor.rfloor_div(value, t0, output_mem_config=output_mem_config)

    return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_round(
    x,
8 changes: 8 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1119,6 +1119,14 @@ Tensor floor_div(const Tensor& input_a, float value, const MemoryConfig& output_
    return operation::decorate_as_composite(__func__, _floor_div_overload)(input_a, value, output_mem_config);
}

Tensor _rfloor_div(float value, const Tensor& input, const MemoryConfig& output_mem_config) {
    Tensor result = div_unary(value, input);
    return floor(result, output_mem_config);
}
Tensor rfloor_div(float value, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _rfloor_div)(value, input, output_mem_config);
}
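In effect the composite is floor(value / input): `div_unary` divides the scalar by every element and `floor` rounds the quotient toward negative infinity. A rough torch model of that decomposition is sketched below for illustration only; on device the arithmetic runs in bfloat16, so results can differ slightly from fp32 torch.

```python
import torch

def rfloor_div_reference(value: float, tensor: torch.Tensor) -> torch.Tensor:
    # div_unary(value, input) -> value / input, then floor(...) on the quotient.
    return torch.floor(value / tensor)

x = torch.tensor([[4.0, -4.0], [0.5, 3.0]])
print(rfloor_div_reference(10.0, x))  # tensor([[ 2., -3.], [20.,  3.]])
```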

Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
    Tensor div_result = div(input_a, input_b);
    return where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
5 changes: 5 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -213,6 +213,11 @@ Tensor floor_div(
    float value,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

Tensor rfloor_div(
    float value,
    const Tensor& input,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// xlogy(x,y) = x*log(y)
Tensor xlogy(
    const Tensor& input_a,
16 changes: 16 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -1379,6 +1379,22 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("rfloor_div", py::overload_cast<float, const Tensor&, const MemoryConfig&>(&rfloor_div),
py::arg("value").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc(
Performs the element-wise floor division of a scalar ``value`` by a tensor ``input``. Support provided only for Wormhole_B0.
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"value", "Numerator value", "float", "", "Yes"
"input", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("mac", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&>(&mac),
py::arg("input").noconvert(), py::arg("tensor1").noconvert(), py::arg("tensor2").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns tensor with the multiply and accumulation of all of elements of the input tensors ``input, tensor1, tensor2``.
