From c06e4f7915bf889a64a84200e2d90cf23044bcae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 18 Oct 2024 10:29:51 +0200 Subject: [PATCH] BUG: (slow, steady and) correct wins the race: prefer numpy over bottleneck for nanfunctions on float32 arrays --- astropy/stats/nanfunctions.py | 38 +++++++++++++++++++--- astropy/stats/sigma_clipping.py | 3 ++ astropy/stats/tests/test_sigma_clipping.py | 22 ++++++++++++- docs/changes/stats/17204.bugfix.rst | 4 +++ 4 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 docs/changes/stats/17204.bugfix.rst diff --git a/astropy/stats/nanfunctions.py b/astropy/stats/nanfunctions.py index ab80c0d458e0..098aba060cf1 100644 --- a/astropy/stats/nanfunctions.py +++ b/astropy/stats/nanfunctions.py @@ -84,10 +84,40 @@ def _apply_bottleneck( else: return result - nansum = functools.partial(_apply_bottleneck, bottleneck.nansum) - nanmean = functools.partial(_apply_bottleneck, bottleneck.nanmean) - nanmedian = functools.partial(_apply_bottleneck, bottleneck.nanmedian) - nanstd = functools.partial(_apply_bottleneck, bottleneck.nanstd) + bn_funcs = dict( + nansum=functools.partial(_apply_bottleneck, bottleneck.nansum), + nanmean=functools.partial(_apply_bottleneck, bottleneck.nanmean), + nanmedian=functools.partial(_apply_bottleneck, bottleneck.nanmedian), + nanstd=functools.partial(_apply_bottleneck, bottleneck.nanstd), + ) + + np_funcs = dict( + nansum=np.nansum, + nanmean=np.nanmean, + nanmedian=np.nanmedian, + nanstd=np.nanstd, + ) + + def _dtype_dispatch(func_name): + # dispatch to bottleneck or numpy depending on the input array dtype + # this is done to workaround known accuracy bugs in bottleneck + # affecting float32 calculations + # see https://github.com/pydata/bottleneck/issues/379 + # see https://github.com/pydata/bottleneck/issues/462 + # see https://github.com/astropy/astropy/issues/17185 + # see https://github.com/astropy/astropy/issues/11492 + def wrapped(*args, **kwargs): + if args[0].dtype.str[1:] == "f8": + return bn_funcs[func_name](*args, **kwargs) + else: + return np_funcs[func_name](*args, **kwargs) + + return wrapped + + nansum = _dtype_dispatch("nansum") + nanmean = _dtype_dispatch("nanmean") + nanmedian = _dtype_dispatch("nanmedian") + nanstd = _dtype_dispatch("nanstd") else: nansum = np.nansum diff --git a/astropy/stats/sigma_clipping.py b/astropy/stats/sigma_clipping.py index cd2075e56c68..0acd0b0f607d 100644 --- a/astropy/stats/sigma_clipping.py +++ b/astropy/stats/sigma_clipping.py @@ -121,6 +121,7 @@ class SigmaClip: specified as a string. If one of the options is set to a string while the other has a custom callable, you may in some cases see better performance if you have the `bottleneck`_ package installed. + To preserve accuracy, bottleneck is only used for float64 computations. .. _bottleneck: https://github.com/pydata/bottleneck @@ -825,6 +826,7 @@ def sigma_clip( specified as a string. If one of the options is set to a string while the other has a custom callable, you may in some cases see better performance if you have the `bottleneck`_ package installed. + To preserve accuracy, bottleneck is only used for float64 computations. .. _bottleneck: https://github.com/pydata/bottleneck @@ -973,6 +975,7 @@ def sigma_clipped_stats( specified as a string. If one of the options is set to a string while the other has a custom callable, you may in some cases see better performance if you have the `bottleneck`_ package installed. + To preserve accuracy, bottleneck is only used for float64 computations. .. _bottleneck: https://github.com/pydata/bottleneck diff --git a/astropy/stats/tests/test_sigma_clipping.py b/astropy/stats/tests/test_sigma_clipping.py index 774cd31d9314..dcc52beab650 100644 --- a/astropy/stats/tests/test_sigma_clipping.py +++ b/astropy/stats/tests/test_sigma_clipping.py @@ -8,7 +8,8 @@ from astropy.stats import mad_std from astropy.stats.sigma_clipping import SigmaClip, sigma_clip, sigma_clipped_stats from astropy.table import MaskedColumn -from astropy.utils.compat.optional_deps import HAS_SCIPY +from astropy.utils.compat import COPY_IF_NEEDED +from astropy.utils.compat.optional_deps import HAS_BOTTLENECK, HAS_SCIPY from astropy.utils.exceptions import AstropyUserWarning from astropy.utils.misc import NumpyRNGContext @@ -173,6 +174,25 @@ def test_sigma_clipped_stats_masked_col(): sigma_clipped_stats(col) +@pytest.mark.slow +@pytest.mark.skipif( + not HAS_BOTTLENECK, + reason="test a workaround for upstream bug in bottleneck", +) +@pytest.mark.parametrize("shape", [(1024, 1024), (6388, 9576)]) +def test_sigma_clip_large_float32_arrays(shape): + # see https://github.com/astropy/astropy/issues/17185 + rng = np.random.default_rng(0) + + expected = (0.5, 0.5, 0.288) # mean, median, stddev + + arr = rng.random(size=shape, dtype="f4") + for byteorder in (">", "<"): + data = arr.astype(dtype=f"{byteorder}f4", copy=COPY_IF_NEEDED) + res = sigma_clipped_stats(data, sigma=3, maxiters=5) + assert_allclose(res, expected, rtol=3e-3) + + def test_invalid_sigma_clip(): """Test sigma_clip of data containing invalid values.""" diff --git a/docs/changes/stats/17204.bugfix.rst b/docs/changes/stats/17204.bugfix.rst new file mode 100644 index 000000000000..c6ef08b1d396 --- /dev/null +++ b/docs/changes/stats/17204.bugfix.rst @@ -0,0 +1,4 @@ +Fixed accuracy of sigma clipping for large ``float32`` arrays when +``bottleneck`` is installed. Performance may be impacted for computations +involving arrays with dtype other than ``float64``. This change has no impact +for environments that do not have ``bottleneck`` installed.