diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py
index f77a5d3e0fd..895bb670109 100644
--- a/src/ert/config/gen_kw_config.py
+++ b/src/ert/config/gen_kw_config.py
@@ -427,6 +427,13 @@ def trans_errf(x: float, arg: List[float]) -> float:
         """
         _min, _max, _skew, _width = arg[0], arg[1], arg[2], arg[3]
         y = 0.5 * (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0))))
+        if np.isnan(y):
+            raise ValueError(
+                (
+                    "Output is nan, likely from triplet (x, skewness, width) "
+                    "leading to low/high-probability in normal CDF."
+                )
+            )
         return _min + y * (_max - _min)
 
     @staticmethod
@@ -439,15 +446,34 @@ def trans_raw(x: float, _: List[float]) -> float:
 
     @staticmethod
     def trans_derrf(x: float, arg: List[float]) -> float:
-        '''Observe that the argument of the shift should be \"+\"'''
-        _steps, _min, _max, _skew, _width = int(arg[0]), arg[1], arg[2], arg[3], arg[4]
-        y = math.floor(
-            _steps
-            * 0.5
-            * (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0))))
-            / (_steps - 1)
+        """
+        Bin the result of `trans_errf` with `min=0` and `max=1` to closest of `steps`
+        linearly spaced values on [0,1]. Finally map [0,1] to [min, max].
+        """
+        _steps, _min, _max, _skew, _width = (
+            int(arg[0]),
+            arg[1],
+            arg[2],
+            arg[3],
+            arg[4],
         )
-        return _min + y * (_max - _min)
+        q_values = np.linspace(start=0, stop=1, num=_steps)
+        q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:]
+        y = TransferFunction.trans_errf(x, [0, 1, _skew, _width])
+        bin_index = np.digitize(y, q_checks, right=True)
+        y_binned = q_values[bin_index]
+        result = _min + y_binned * (_max - _min)
+        if result > _max or result < _min:
+            warnings.warn(
+                "trans_derrf suffered from catastrophic loss of precision, clamping to min,max",
+                stacklevel=1,
+            )
+            return np.clip(result, _min, _max)
+        if np.isnan(result):
+            raise ValueError(
+                "trans_derrf returns nan, check that input arguments are reasonable"
+            )
+        return result
 
     @staticmethod
     def trans_unif(x: float, arg: List[float]) -> float:
diff --git a/tests/unit_tests/config/test_gen_kw_config.py b/tests/unit_tests/config/test_gen_kw_config.py
index 964960e6b06..c0e0f3f31e5 100644
--- a/tests/unit_tests/config/test_gen_kw_config.py
+++ b/tests/unit_tests/config/test_gen_kw_config.py
@@ -358,11 +358,6 @@ def test_gen_kw_params_parsing(tmpdir, params, error):
         ("MYNAME ERRF 1 2 0.1 0.1", 0.3, 1.99996832875816688002),
         ("MYNAME ERRF 1 2 0.1 0.1", 0.7, 1.99999999999999933387),
         ("MYNAME ERRF 1 2 0.1 0.1", 1.0, 2.00000000000000000000),
-        ("MYNAME DERRF 10 1 2 0.1 0.1", -1.0, 1.00000000000000000000),
-        ("MYNAME DERRF 10 1 2 0.1 0.1", 0.0, 1.00000000000000000000),
-        ("MYNAME DERRF 10 1 2 0.1 0.1", 0.3, 2.00000000000000000000),
-        ("MYNAME DERRF 10 1 2 0.1 0.1", 0.7, 2.00000000000000000000),
-        ("MYNAME DERRF 10 1 2 0.1 0.1", 1.0, 2.00000000000000000000),
     ],
 )
 def test_gen_kw_trans_func(tmpdir, params, xinput, expected):
diff --git a/tests/unit_tests/config/test_transfer_functions.py b/tests/unit_tests/config/test_transfer_functions.py
index 9a045d04da5..c366111f83c 100644
--- a/tests/unit_tests/config/test_transfer_functions.py
+++ b/tests/unit_tests/config/test_transfer_functions.py
@@ -1,6 +1,7 @@
 import numpy as np
 from hypothesis import given
 from hypothesis import strategies as st
+from scipy.stats import norm
 
 from ert.config import TransferFunction
 
@@ -79,3 +80,80 @@ def test_that_truncated_normal_stretches(x, arg):
         return
     result = TransferFunction.trans_truncated_normal(x, arg)
     assert np.isclose(result, expected)
+
+
+def valid_derrf_parameters():
+    """All elements in R, min<max, and width>0"""
+    steps = st.integers(min_value=2, max_value=1000)
+    min_max = (
+        st.tuples(
+            st.floats(
+                min_value=-1e6, max_value=1e6, allow_nan=False, allow_infinity=False
+            ),
+            st.floats(
+                min_value=-1e6, max_value=1e6, allow_nan=False, allow_infinity=False
+            ),
+        )
+        .map(sorted)
+        .filter(lambda x: x[0] < x[1])  # filter out edge case of equality
+    )
+    skew = st.floats(allow_nan=False, allow_infinity=False)
+    width = st.floats(
+        min_value=0.01, max_value=1e6, allow_nan=False, allow_infinity=False
+    )
+    return min_max.flatmap(
+        lambda min_max: st.tuples(
+            steps, st.just(min_max[0]), st.just(min_max[1]), skew, width
+        )
+    )
+
+
+@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters())
+def test_that_derrf_is_within_bounds(x, arg):
+    """The result should always be between (or equal) min and max"""
+    result = TransferFunction.trans_derrf(x, arg)
+    assert arg[1] <= result <= arg[2]
+
+
+@given(
+    st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=2),
+    valid_derrf_parameters(),
+)
+def test_that_derrf_creates_at_least_steps_or_less_distinct_values(xlist, arg):
+    """derrf cannot create more than steps distinct values"""
+    res = [TransferFunction.trans_derrf(x, arg) for x in xlist]
+    assert len(set(res)) <= arg[0]
+
+
+@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters())
+def test_that_derrf_corresponds_scaled_binned_normal_cdf(x, arg):
+    """Check correspondence to normal cdf with -mu=_skew and sd=_width"""
+    _steps, _min, _max, _skew, _width = arg
+    q_values = np.linspace(start=0, stop=1, num=_steps)
+    q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:]
+    p = norm.cdf(x, loc=-_skew, scale=_width)
+    bin_index = np.digitize(p, q_checks, right=True)
+    expected = q_values[bin_index]
+    # scale and ensure ok numerics
+    expected = _min + expected * (_max - _min)
+    if expected > _max or expected < _min:
+        expected = np.clip(expected, _min, _max)
+    result = TransferFunction.trans_derrf(x, arg)
+    assert np.isclose(result, expected)
+
+
+@given(
+    st.tuples(
+        st.floats(allow_nan=False, allow_infinity=False),
+        st.floats(allow_nan=False, allow_infinity=False),
+    )
+    .map(sorted)
+    .filter(lambda x: x[0] < x[1]),
+    valid_derrf_parameters(),
+)
+def test_that_derrf_is_non_strictly_monotone(x_tuple, arg):
+    """`derrf` is a non-strict monotone function"""
+    x1, x2 = x_tuple
+    assert TransferFunction.trans_derrf(x1, arg) <= TransferFunction.trans_derrf(
+        x2, arg
+    )