From bcff96ded3777acaccb285a942ff5e54b2362ce5 Mon Sep 17 00:00:00 2001 From: Blunde1 Date: Tue, 26 Sep 2023 14:32:26 +0200 Subject: [PATCH] Rewrite trans_derrf to intended behaviour and robust numerics --- src/ert/config/gen_kw_config.py | 34 ++++++++--- .../config/test_transfer_functions.py | 58 +++++++++++++++++++ 2 files changed, 84 insertions(+), 8 deletions(-) diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py index f77a5d3e0fd..32f4c9bacad 100644 --- a/src/ert/config/gen_kw_config.py +++ b/src/ert/config/gen_kw_config.py @@ -439,15 +439,33 @@ def trans_raw(x: float, _: List[float]) -> float: @staticmethod def trans_derrf(x: float, arg: List[float]) -> float: - '''Observe that the argument of the shift should be \"+\"''' - _steps, _min, _max, _skew, _width = int(arg[0]), arg[1], arg[2], arg[3], arg[4] - y = math.floor( - _steps - * 0.5 - * (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0)))) - / (_steps - 1) + """ + Bin the result of `trans_errf` with `min=0` and `max=1` to closest of `nbins` + linearly spaced values on [0,1]. Finally map [0,1] to [min, max]. + """ + _steps, _min, _max, _skew, _width = ( + int(arg[0]), + arg[1], + arg[2], + arg[3], + arg[4], ) - return _min + y * (_max - _min) + q_values = np.linspace(start=0, stop=1, num=_steps) + q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:] + y = TransferFunction.trans_errf(x, [0, 1, _skew, _width]) + y_binned = q_checks[-1] + for i in range(_steps - 1): + if y < q_checks[i]: + y_binned = q_values[i] + break + result = _min + y_binned * (_max - _min) + if result > _max or result < _min: + warnings.warn( + "trans_derff suffered from catastrophic loss of precision, clamping to min,max", + stacklevel=1, + ) + return np.clip(result, _min, _max) + return result @staticmethod def trans_unif(x: float, arg: List[float]) -> float: diff --git a/tests/unit_tests/config/test_transfer_functions.py b/tests/unit_tests/config/test_transfer_functions.py index 9a045d04da5..54d38be6d00 100644 --- a/tests/unit_tests/config/test_transfer_functions.py +++ b/tests/unit_tests/config/test_transfer_functions.py @@ -79,3 +79,61 @@ def test_that_truncated_normal_stretches(x, arg): return result = TransferFunction.trans_truncated_normal(x, arg) assert np.isclose(result, expected) + + +def valid_derrf_parameters(): + """All elements in R, min0""" + steps = st.integers(min_value=2, max_value=1000) + min_max = ( + st.tuples( + st.floats(allow_nan=False, allow_infinity=False), + st.floats(allow_nan=False, allow_infinity=False), + ) + .map(sorted) + .filter(lambda x: x[0] < x[1]) # filter out edge case of equality + ) + skew = st.floats(allow_nan=False, allow_infinity=False) + width = st.floats( + min_value=0.01, max_value=1e6, allow_nan=False, allow_infinity=False + ) + return min_max.flatmap( + lambda min_max: st.tuples( + steps, st.just(min_max[0]), st.just(min_max[1]), skew, width + ) + ) + + +@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters()) +def test_that_derrf_is_within_bounds(x, arg): + """The result shold always be between (or equal) min and max""" + result = TransferFunction.trans_derrf(x, arg) + assert arg[1] <= result <= arg[2] + + +@given( + st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=2), + valid_derrf_parameters(), +) +def test_that_derrf_creates_at_least_steps_number_of_distinct_values(xlist, arg): + """derrf cannot create more than steps distinct values""" + res = [] + for x in xlist: + res.append(TransferFunction.trans_derrf(x, arg)) + assert len(set(res)) <= arg[0] + + +@given( + st.tuples( + st.floats(allow_nan=False, allow_infinity=False), + st.floats(allow_nan=False, allow_infinity=False), + ) + .map(sorted) + .filter(lambda x: x[0] < x[1]), + valid_derrf_parameters(), +) +def test_that_derrf_is_non_strictly_monotone(x_tuple, arg): + """`derrf` is a non-strict monotone function""" + x1, x2 = x_tuple + assert TransferFunction.trans_derrf(x1, arg) <= TransferFunction.trans_derrf( + x2, arg + )