Skip to content

Commit

Permalink
Rewrite trans_derrf to intended behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
Blunde1 committed Sep 26, 2023
1 parent bcff96d commit b2c4c02
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
7 changes: 7 additions & 0 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ def trans_errf(x: float, arg: List[float]) -> float:
"""
_min, _max, _skew, _width = arg[0], arg[1], arg[2], arg[3]
y = 0.5 * (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0))))
if np.isnan(y):
raise ValueError(
(
"Output is nan, likely from triplet (x, skewness, width) "
"leading to low/high-probability in normal CDF."
)
)
return _min + y * (_max - _min)

@staticmethod
Expand Down
5 changes: 0 additions & 5 deletions tests/unit_tests/config/test_gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,6 @@ def test_gen_kw_params_parsing(tmpdir, params, error):
("MYNAME ERRF 1 2 0.1 0.1", 0.3, 1.99996832875816688002),
("MYNAME ERRF 1 2 0.1 0.1", 0.7, 1.99999999999999933387),
("MYNAME ERRF 1 2 0.1 0.1", 1.0, 2.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", -1.0, 1.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 0.0, 1.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 0.3, 2.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 0.7, 2.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 1.0, 2.00000000000000000000),
],
)
def test_gen_kw_trans_func(tmpdir, params, xinput, expected):
Expand Down
23 changes: 22 additions & 1 deletion tests/unit_tests/config/test_transfer_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy.stats import norm
from hypothesis import given
from hypothesis import strategies as st

Expand Down Expand Up @@ -114,14 +115,34 @@ def test_that_derrf_is_within_bounds(x, arg):
st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=2),
valid_derrf_parameters(),
)
def test_that_derrf_creates_at_least_steps_number_of_distinct_values(xlist, arg):
def test_that_derrf_creates_at_least_steps_or_less_distinct_values(xlist, arg):
"""derrf cannot create more than steps distinct values"""
res = []
for x in xlist:
res.append(TransferFunction.trans_derrf(x, arg))
assert len(set(res)) <= arg[0]


@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters())
def test_that_derrf_corresponds_scaled_binned_normal_cdf(x, arg):
"""Check correspondance to normal cdf with -mu=_skew and sd=_width"""
_steps, _min, _max, _skew, _width = arg
q_values = np.linspace(start=0, stop=1, num=_steps)
q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:]
p = norm.cdf(x, loc=-_skew, scale=_width)
expected = q_values[-1] # last
for i in range(_steps - 1):
if p < q_checks[i]:
expected = q_values[i]
break
# scale and ensure ok numerics
expected = _min + expected * (_max - _min)
if expected > _max or expected < _min:
np.clip(expected, _min, _max)
result = TransferFunction.trans_derrf(x, arg)
assert np.isclose(result, expected)


@given(
st.tuples(
st.floats(allow_nan=False, allow_infinity=False),
Expand Down

0 comments on commit b2c4c02

Please sign in to comment.