Skip to content

Commit

Permalink
Rewrite trans_derrf to intended behaviour and robust numerics
Browse files Browse the repository at this point in the history
  • Loading branch information
Blunde1 committed Sep 26, 2023
1 parent e9dfa15 commit bcff96d
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 8 deletions.
34 changes: 26 additions & 8 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,33 @@ def trans_raw(x: float, _: List[float]) -> float:

@staticmethod
def trans_derrf(x: float, arg: List[float]) -> float:
'''Observe that the argument of the shift should be \"+\"'''
_steps, _min, _max, _skew, _width = int(arg[0]), arg[1], arg[2], arg[3], arg[4]
y = math.floor(
_steps
* 0.5
* (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0))))
/ (_steps - 1)
"""
Bin the result of `trans_errf` with `min=0` and `max=1` to closest of `nbins`
linearly spaced values on [0,1]. Finally map [0,1] to [min, max].
"""
_steps, _min, _max, _skew, _width = (
int(arg[0]),
arg[1],
arg[2],
arg[3],
arg[4],
)
return _min + y * (_max - _min)
q_values = np.linspace(start=0, stop=1, num=_steps)
q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:]
y = TransferFunction.trans_errf(x, [0, 1, _skew, _width])
y_binned = q_checks[-1]
for i in range(_steps - 1):
if y < q_checks[i]:
y_binned = q_values[i]
break
result = _min + y_binned * (_max - _min)
if result > _max or result < _min:
warnings.warn(
"trans_derff suffered from catastrophic loss of precision, clamping to min,max",
stacklevel=1,
)
return np.clip(result, _min, _max)
return result

@staticmethod
def trans_unif(x: float, arg: List[float]) -> float:
Expand Down
58 changes: 58 additions & 0 deletions tests/unit_tests/config/test_transfer_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,61 @@ def test_that_truncated_normal_stretches(x, arg):
return
result = TransferFunction.trans_truncated_normal(x, arg)
assert np.isclose(result, expected)


def valid_derrf_parameters():
"""All elements in R, min<max, and width>0"""
steps = st.integers(min_value=2, max_value=1000)
min_max = (
st.tuples(
st.floats(allow_nan=False, allow_infinity=False),
st.floats(allow_nan=False, allow_infinity=False),
)
.map(sorted)
.filter(lambda x: x[0] < x[1]) # filter out edge case of equality
)
skew = st.floats(allow_nan=False, allow_infinity=False)
width = st.floats(
min_value=0.01, max_value=1e6, allow_nan=False, allow_infinity=False
)
return min_max.flatmap(
lambda min_max: st.tuples(
steps, st.just(min_max[0]), st.just(min_max[1]), skew, width
)
)


@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters())
def test_that_derrf_is_within_bounds(x, arg):
"""The result shold always be between (or equal) min and max"""
result = TransferFunction.trans_derrf(x, arg)
assert arg[1] <= result <= arg[2]


@given(
st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=2),
valid_derrf_parameters(),
)
def test_that_derrf_creates_at_least_steps_number_of_distinct_values(xlist, arg):
"""derrf cannot create more than steps distinct values"""
res = []
for x in xlist:
res.append(TransferFunction.trans_derrf(x, arg))
assert len(set(res)) <= arg[0]


@given(
st.tuples(
st.floats(allow_nan=False, allow_infinity=False),
st.floats(allow_nan=False, allow_infinity=False),
)
.map(sorted)
.filter(lambda x: x[0] < x[1]),
valid_derrf_parameters(),
)
def test_that_derrf_is_non_strictly_monotone(x_tuple, arg):
"""`derrf` is a non-strict monotone function"""
x1, x2 = x_tuple
assert TransferFunction.trans_derrf(x1, arg) <= TransferFunction.trans_derrf(
x2, arg
)

0 comments on commit bcff96d

Please sign in to comment.