From 1fb6c7d99b06874194a3ba1033cad7824ad7e9ff Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sat, 10 Aug 2024 10:10:12 -0300 Subject: [PATCH] Add thinning (#18) * Add thinning * remove commented code * use dims * ignore dims if factor is auto * fix dims --- src/arviz_stats/accessors.py | 12 +++++++++++ src/arviz_stats/base/dataarray.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index dfafb0f..6c6a037 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -51,6 +51,10 @@ def kde(self, dims=None, **kwargs): """Compute the KDE on the DataArray.""" return get_function("kde")(self._obj, dims=dims, **kwargs) + def thin(self, factor="auto", dims=None, **kwargs): + """Perform thinning on the DataArray.""" + return get_function("thin")(self._obj, factor=factor, dims=dims, **kwargs) + @xr.register_dataset_accessor("azstats") class AzStatsDsAccessor(_BaseAccessor): @@ -135,6 +139,10 @@ def ecdf(self, dims=None, **kwargs): # TODO: implement ecdf here so it doesn't depend on numba return self._apply(ecdf, dims=dims, **kwargs).rename(ecdf_axis="plot_axis") + def thin(self, dims=None, factor="auto"): + """Perform thinning for all the variables in the dataset.""" + return self._apply(get_function("thin"), dims=dims, factor=factor) + @register_datatree_accessor("azstats") class AzStatsDtAccessor(_BaseAccessor): @@ -193,3 +201,7 @@ def kde(self, dims=None, group="posterior", **kwargs): def histogram(self, dims=None, group="posterior", **kwargs): """Compute the KDE for all variables in a group of the DataTree.""" return self._apply("histogram", dims=dims, group=group, **kwargs) + + def thin(self, dims=None, group="posterior", **kwargs): + """Perform thinning for all variables in a group of the DataTree.""" + return self._apply("thin", dims=dims, group=group, **kwargs) diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py index eef6540..858a808 100644 --- a/src/arviz_stats/base/dataarray.py +++ b/src/arviz_stats/base/dataarray.py @@ -3,6 +3,8 @@ "dataarray" functions take :class:`xarray.DataArray` as inputs. """ +import warnings + import numpy as np from arviz_base import rcParams from xarray import DataArray, apply_ufunc, concat @@ -193,5 +195,36 @@ def kde(self, da, dims=None, circular=False, grid_len=512, **kwargs): out = concat((grid, pdf), dim=plot_axis) return out.assign_coords({"bw" if da.name is None else f"bw_{da.name}": bw}) + def thin(self, da, factor="auto", dims=None): + """Perform thinning on DataArray input.""" + if factor == "auto" and dims is not None: + warnings.warn("dims are ignored if factor is auto") + + if factor == "auto": + n_samples = da.sizes["chain"] * da.sizes["draw"] + ess_ave = np.minimum( + self.ess(da, method="bulk", dims=["chain", "draw"]), + self.ess(da, method="tail", dims=["chain", "draw"]), + ).mean() + factor = int(np.ceil(n_samples / ess_ave)) + dims = "draw" + + elif isinstance(factor, (float | int)): + if dims is None: + dims = rcParams["data.sample_dims"] + if not isinstance(dims, str): + if len(dims) >= 2: + raise ValueError("dims must be of length 1") + if len(dims) == 1: + dims = dims[0] + + factor = int(factor) + if factor == 1: + return da + if factor < 1: + raise ValueError("factor must be greater than 1") + + return da.sel({dims: slice(None, None, factor)}) + dataarray_stats = BaseDataArray(array_class=array_stats)