Skip to content

Commit

Permalink
#7201: Add support for unary ne
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Apr 23, 2024
1 parent 5c5c4b6 commit 8ed0b4e
Show file tree
Hide file tree
Showing 19 changed files with 253 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Compute APIs
gtz_tile
gez_tile
nez_tile
unary_ne_tile

cb_wait_front
cb_pop_front
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
unary_ne_tile
--------------

.. doxygenfunction:: unary_ne_tile_init()
.. doxygenfunction:: unary_ne_tile(uint32_t idst, uint32_t param0)
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ Tensor relational operations

.. autofunction:: tt_lib.tensor.ne

.. autofunction:: tt_lib.tensor.unary_ne

Tensor ternary operations
=========================
.. autofunction:: tt_lib.tensor.where
Expand Down
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@
"tt_lib_op": tt_lib_ops.eltwise_heaviside,
"pytorch_op": pytorch_ops.heaviside,
},
"eltwise-unary_ne": {
"tt_lib_op": tt_lib_ops.eltwise_unary_ne,
"pytorch_op": pytorch_ops.unary_ne,
},
"eltwise-erf": {
"tt_lib_op": tt_lib_ops.eltwise_erf,
"pytorch_op": pytorch_ops.erf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,39 @@ def test_run_eltwise_heaviside(
test_args,
)

@pytest.mark.parametrize("unary_comp", ["unary_ne"])
@pytest.mark.parametrize("scalar", [0.5, 1.0, -1.0, 0.0])
def test_run_eltwise_unary_comp(
self,
unary_comp,
input_shapes,
scalar,
device,
function_level_defaults,
input_mem_config,
output_mem_config,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update({"scalar": scalar})
test_args.update(
{
"input_mem_config": [input_mem_config],
"output_mem_config": output_mem_config,
}
)
comparison_func = comparison_funcs.comp_equal
run_single_pytorch_test(
f"eltwise-{unary_comp}",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)

@pytest.mark.parametrize("unary_kind", ["add_unary", "sub_unary", "mul_unary", "div_unary"])
@pytest.mark.parametrize("scalar", [-2.0, 1.0, 2.0, 8.0])
def test_run_eltwise_binop_to_unary_ops(
Expand Down
6 changes: 6 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,12 @@ def heaviside(x, *args, **kwargs):
return result


def unary_ne(x, *args, **kwargs):
value = kwargs.pop("scalar")
result = torch.ne(x, value)
return result


def erf(x, *args, **kwargs):
return torch.erf(x)

Expand Down
18 changes: 18 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,24 @@ def eltwise_heaviside(
return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_unary_ne(
x,
*args,
scalar,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.unary_ne(t0, scalar, output_mem_config=output_mem_config)

return tt2torch_tensor(t1)


@setup_host_and_device
def repeat_interleave(
x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ std::pair<string, string> get_op_init_and_func_parameterized(UnaryOpType op_type
case UnaryOpType::ADD_UNARY_SFPU: op_init_and_name = {"binop_with_scalar_tile_init();", fmt::format("add_unary_tile({}, {}u);", idst, Converter::to_hex(param0))}; break;
case UnaryOpType::MUL_UNARY_SFPU: op_init_and_name = {"binop_with_scalar_tile_init();", fmt::format("mul_unary_tile({}, {}u);", idst, Converter::to_hex(param0))}; break;
case UnaryOpType::DIV_UNARY_SFPU: op_init_and_name = {"binop_with_scalar_tile_init();", fmt::format("div_unary_tile({}, {}u);", idst, Converter::to_hex(1.0f/param0))}; break;
case UnaryOpType::UNARY_NE: op_init_and_name = {"unary_ne_tile_init();", fmt::format("unary_ne_tile({}, {}u);", idst, Converter::to_hex(param0))}; break;
default:
TT_ASSERT( false && "unexpected parameterized type");
};
Expand Down
7 changes: 5 additions & 2 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ enum class UnaryOpType {
SUB_UNARY_SFPU = 56,
MUL_UNARY_SFPU = 57,
DIV_UNARY_SFPU = 58,
IDENTITY_UINT32 = 59
IDENTITY_UINT32 = 59,
UNARY_NE = 60
};

template <typename T>
Expand All @@ -100,7 +101,8 @@ bool is_parametrized_type(T val) {
case UnaryOpType::ADD_UNARY_SFPU:
case UnaryOpType::SUB_UNARY_SFPU:
case UnaryOpType::MUL_UNARY_SFPU:
case UnaryOpType::DIV_UNARY_SFPU: return true;
case UnaryOpType::DIV_UNARY_SFPU:
case UnaryOpType::UNARY_NE: return true;
default: return false;
}
return false;
Expand Down Expand Up @@ -308,6 +310,7 @@ constexpr auto leaky_relu = make_eltwise_unary_with_param<UnaryOpType::LEAKY_REL
constexpr auto prelu = leaky_relu;
constexpr auto elu = make_eltwise_unary_with_param<UnaryOpType::ELU>{};
constexpr auto heaviside = make_eltwise_unary_with_param<UnaryOpType::HEAVISIDE>{};
constexpr auto unary_ne = make_eltwise_unary_with_param<UnaryOpType::UNARY_NE>{};
constexpr auto rsub = make_eltwise_unary_with_param<UnaryOpType::RSUB>{};
constexpr auto silu = make_eltwise_unary<UnaryOpType::SILU>{};
constexpr auto identity = make_eltwise_unary<UnaryOpType::IDENTITY>{};
Expand Down
7 changes: 7 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ namespace tt::tt_metal::detail {
HEAVISIDE(x) = 0 if x < 0 , 1 if x > 0 , else value.)doc",
R"doc("value", "float", "")doc"

);
detail::bind_unary_op_with_param(
m_tensor, "unary_ne", unary_ne,
py::arg("value"),
R"doc(Perform an eltwise-unary not-equal (``{0} != {1}``) on input tensor.)doc",
R"doc("value", "float", "")doc"

);
detail::bind_unary_op_with_param(
m_tensor, "rdiv", rdiv,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llk_math_eltwise_unary_sfpu_silu.h"
#include "llk_math_eltwise_unary_sfpu_topk.h"
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"

namespace ckernel {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 4>
inline void calculate_unary_ne(uint value)
{
// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat s = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++) {
vFloat v = dst_reg[0];
v_if (v == s) {
v = 0.0f;
}v_else {
v = 1.0f;
}
v_endif;

dst_reg[0] = v;

dst_reg++;
}
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_unary_sfpu_init.h"
#include "llk_math_eltwise_unary_sfpu_1_param.h"
#include "ckernel_sfpu_unary_comp.h"

namespace ckernel {

// New LLK SFPU APIs

//Unary Not equal
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_unary_ne_init() {
llk_math_eltwise_unary_sfpu_init<APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_unary_ne(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_1_param<APPROXIMATE>
(ckernel::sfpu::calculate_unary_ne<APPROXIMATE>,
ckernel::sfpu::calculate_unary_ne<APPROXIMATE>,
dst_index, vector_mode, param0);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,6 @@ enum SfpuType {
silu,
mask,
negative,
unary_ne,
unused,
};
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
#include "llk_math_eltwise_unary_sfpu_square.h"
#include "llk_math_eltwise_unary_sfpu_tanh.h"
#include "llk_math_eltwise_unary_sfpu_topk.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_unary_ne(uint value)
{
// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat s = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++) {
vFloat v = dst_reg[0];
v_if (v == s) {
v = 0.0f;
}v_else {
v = 1.0f;
}
v_endif;

dst_reg[0] = v;

dst_reg++;
}
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_unary_sfpu_init.h"
#include "llk_math_eltwise_unary_sfpu_1_param.h"
#include "ckernel_sfpu_unary_comp.h"

namespace ckernel {

// New LLK SFPU APIs

//Unary Not equal
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_unary_ne_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::unary_ne, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_unary_ne(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_1_param<APPROXIMATE>
(ckernel::sfpu::calculate_unary_ne<APPROXIMATE>,
ckernel::sfpu::calculate_unary_ne<APPROXIMATE>,
dst_index, vector_mode, param0);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,6 @@ enum SfpuType {
topk_local_sort,
topk_merge,
topk_rebuild,
unary_ne,
unused,
};
25 changes: 25 additions & 0 deletions tt_metal/include/compute_kernel_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,31 @@ ALWI void heaviside_tile_init() {
MATH(( llk_math_eltwise_unary_sfpu_heaviside_init<APPROX>() ));
}

//unary ne : if x !=value --> 1.0, else 0.0
/**
* Performs element-wise computation of: result = 1 if x!=value , where x is each element of a tile
* in DST register at index tile_index. The value is provided as const param0 The DST register buffer must be in
* acquired state via *acquire_dst* call. This call is blocking and is only
* available on the compute engine.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | param0 | The value to be compared with the input tensor | uint32_t | | True |
*/
ALWI void unary_ne_tile(uint32_t idst,uint32_t param0) {
MATH(( llk_math_eltwise_unary_sfpu_unary_ne<APPROX>(idst,param0) ));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void unary_ne_tile_init() {
MATH(( llk_math_eltwise_unary_sfpu_unary_ne_init<APPROX>() ));
}

//expm1 : (exp(x) - 1)
/**
* Performs element-wise computation of exp(x) - 1, v where x is each element of a tile
Expand Down

0 comments on commit 8ed0b4e

Please sign in to comment.