Skip to content

Commit

Permalink
fix: simplify stanh implementation in paddle frontend and avoid nan v…
Browse files Browse the repository at this point in the history
…alues from exp in the test(ivy-llc#28493)
  • Loading branch information
ZJay07 authored May 2, 2024
1 parent 14a6c2f commit a609e40
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
9 changes: 2 additions & 7 deletions ivy/functional/frontends/paddle/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,13 +629,8 @@ def square(x, name=None):
@with_supported_dtypes({"2.6.0 and below": ("float32", "float64")}, "paddle")
@to_ivy_arrays_and_back
def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
# TODO this function will be simplified as soon as the ivy.stanh(x,a,b) is added
exp_ax = ivy.exp(ivy.multiply(scale_a, x))
exp_minus_ax = ivy.exp(ivy.multiply(-scale_a, x))
numerator = ivy.subtract(exp_ax, exp_minus_ax)
denominator = ivy.add(exp_ax, exp_minus_ax)
ret = ivy.multiply(scale_b, ivy.divide(numerator, denominator))
return ret
ret = ivy.stanh(x, alpha=scale_b, beta=scale_a, out=name)
return ivy.asarray(ret, dtype=x.dtype)


@with_unsupported_dtypes({"2.6.0 and below": ("float16", "bfloat16")}, "paddle")
Expand Down
16 changes: 14 additions & 2 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2495,8 +2495,20 @@ def test_paddle_square(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
),
scale_a=st.floats(1e-5, 1e5),
scale_b=st.floats(1e-5, 1e5),
scale_a=st.floats(
min_value=-5,
max_value=5,
allow_nan=False,
allow_subnormal=False,
allow_infinity=False,
),
scale_b=st.floats(
min_value=-5,
max_value=5,
allow_nan=False,
allow_subnormal=False,
allow_infinity=False,
),
)
def test_paddle_stanh(
*,
Expand Down

0 comments on commit a609e40

Please sign in to comment.