Skip to content

Commit

Permalink
#7108: Accuracy enhancement for sigmoid using ttnn
Browse files Browse the repository at this point in the history
  • Loading branch information
ruthreshx committed Apr 16, 2024
1 parent a059360 commit 1a32fed
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ Pointwise Unary
ttnn/relu6
ttnn/rsqrt
ttnn/sigmoid
ttnn/sigmoid_accurate
ttnn/sign
ttnn/silu
ttnn/sin
Expand Down
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ Tensor elementwise operations

.. autofunction:: tt_lib.tensor.sigmoid

.. autofunction:: tt_lib.tensor.sigmoid_accurate

.. autofunction:: tt_lib.tensor.hardsigmoid

.. autofunction:: tt_lib.tensor.swish
Expand Down
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/sigmoid_accurate.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _ttnn.sigmoid_accurate:

ttnn.sigmoid_accurate
#####################

.. autofunction:: ttnn.sigmoid_accurate
8 changes: 8 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,14 @@
"tt_lib_op": ttnn_ops.eltwise_sigmoid,
"pytorch_op": pytorch_ops.sigmoid,
},
"eltwise-sigmoid_accurate": {
"tt_lib_op": tt_lib_ops.eltwise_sigmoid_accurate,
"pytorch_op": pytorch_ops.sigmoid,
},
"ttnn-eltwise-sigmoid_accurate": {
"tt_lib_op": ttnn_ops.eltwise_sigmoid_accurate,
"pytorch_op": pytorch_ops.sigmoid,
},
"eltwise-log_sigmoid": {
"tt_lib_op": tt_lib_ops.eltwise_log_sigmoid,
"pytorch_op": pytorch_ops.log_sigmoid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,7 @@ def unary_op(
eltwise_neg = make_unary_op(ttl.tensor.neg)
eltwise_recip = make_unary_op(ttl.tensor.recip)
eltwise_sigmoid = make_unary_op(ttl.tensor.sigmoid)
eltwise_sigmoid_accurate = make_unary_op(ttl.tensor.sigmoid_accurate)
eltwise_log_sigmoid = make_unary_op(ttl.tensor.log_sigmoid)
eltwise_log = make_unary_op(ttl.tensor.log)
eltwise_log2 = make_unary_op(ttl.tensor.log2)
Expand Down
16 changes: 16 additions & 0 deletions tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,22 @@ def eltwise_silu(
return ttnn_tensor_to_torch(t1)


def eltwise_sigmoid_accurate(
x,
*args,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttnn.sigmoid_accurate(t0, memory_config=memory_config_to_ttnn(output_mem_config))

return ttnn_tensor_to_torch(t1)


def eltwise_sin(
x,
*args,
Expand Down
6 changes: 6 additions & 0 deletions tests/ttnn/unit_tests/operations/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def test_hardtanh(device, h, w):
run_activation_unary_test(device, h, w, ttnn.hardtanh, F.hardtanh)


@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_sigmoid_accurate(device, h, w):
run_activation_unary_test(device, h, w, ttnn.sigmoid_accurate, torch.sigmoid)


@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_hardswish(device, h, w):
Expand Down
9 changes: 9 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,15 @@ inline Tensor log_sigmoid(
{UnaryWithParam{.op_type = UnaryOpType::SIGMOID}, UnaryWithParam{.op_type = UnaryOpType::LOG}},
output_mem_config);
}

inline Tensor sigmoid_accurate(
const Tensor& input_tensor, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
return run_eltwise_unary(
input_tensor,
{UnaryWithParam{.op_type = UnaryOpType::NEG}, UnaryWithParam{.op_type = UnaryOpType::EXP, .param = 1.0f}, UnaryWithParam{.op_type = UnaryOpType::ADD_UNARY_SFPU, .param = 1.0f}, UnaryWithParam{.op_type = UnaryOpType::RECIP}},
output_mem_config);
}

inline Tensor unary_chain(
const Tensor& input_tensor,
std::vector<UnaryWithParam> ops_chain,
Expand Down
1 change: 1 addition & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace tt::tt_metal::detail {
py::overload_cast<const Tensor &, const MemoryConfig &>(sqrt),
R"doc(Returns tensor with the square-root of elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "sigmoid", sigmoid, R"doc(Applies the sigmoid function to the elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "sigmoid_accurate", sigmoid_accurate, R"doc(Applies the sigmoid_accurate function to the elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "log", log, R"doc(Returns tensor with the natural logarithm of elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "tanh", tanh, R"doc(Returns tensor with the hyperbolic tangent of elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "log2", log2, R"doc(Returns tensor with the base 2 logarithm of elements of the input tensor ``{0}``.)doc");
Expand Down
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def manage_config_attribute(name, value):
prelu,
relu6,
sigmoid,
sigmoid_accurate,
sign,
softshrink,
softsign,
Expand Down
2 changes: 2 additions & 0 deletions ttnn/ttnn/operations/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
"mish": lambda _x: F.mish(_x.to(torch.float)),
"relu6": F.relu6,
"sigmoid": torch.sigmoid,
"sigmoid_accurate": torch.sigmoid,
"sign": torch.sign,
"softsign": F.softsign,
"swish": F.hardswish,
Expand Down Expand Up @@ -356,6 +357,7 @@ def activation_function(
("mish", ttl.tensor.mish, "mish"),
("relu6", ttl.tensor.relu6, "relu6"),
("sigmoid", ttl.tensor.sigmoid, "sigmoid"),
("sigmoid_accurate", ttl.tensor.sigmoid_accurate, "sigmoid_accurate"),
("sign", ttl.tensor.sign, "sign"),
("softsign", ttl.tensor.softsign, "softsign"),
("swish", ttl.tensor.swish, "swish"),
Expand Down

0 comments on commit 1a32fed

Please sign in to comment.