Skip to content

Commit

Permalink
Add thinning (#18)
Browse files Browse the repository at this point in the history
* Add thinning

* remove commented code

* use dims

* ignore dims if factor is auto

* fix dims
  • Loading branch information
aloctavodia authored Aug 10, 2024
1 parent 11719dc commit 1fb6c7d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def kde(self, dims=None, **kwargs):
"""Compute the KDE on the DataArray."""
return get_function("kde")(self._obj, dims=dims, **kwargs)

def thin(self, factor="auto", dims=None, **kwargs):
"""Perform thinning on the DataArray."""
return get_function("thin")(self._obj, factor=factor, dims=dims, **kwargs)


@xr.register_dataset_accessor("azstats")
class AzStatsDsAccessor(_BaseAccessor):
Expand Down Expand Up @@ -135,6 +139,10 @@ def ecdf(self, dims=None, **kwargs):
# TODO: implement ecdf here so it doesn't depend on numba
return self._apply(ecdf, dims=dims, **kwargs).rename(ecdf_axis="plot_axis")

def thin(self, dims=None, factor="auto"):
"""Perform thinning for all the variables in the dataset."""
return self._apply(get_function("thin"), dims=dims, factor=factor)


@register_datatree_accessor("azstats")
class AzStatsDtAccessor(_BaseAccessor):
Expand Down Expand Up @@ -193,3 +201,7 @@ def kde(self, dims=None, group="posterior", **kwargs):
def histogram(self, dims=None, group="posterior", **kwargs):
"""Compute the KDE for all variables in a group of the DataTree."""
return self._apply("histogram", dims=dims, group=group, **kwargs)

def thin(self, dims=None, group="posterior", **kwargs):
"""Perform thinning for all variables in a group of the DataTree."""
return self._apply("thin", dims=dims, group=group, **kwargs)
33 changes: 33 additions & 0 deletions src/arviz_stats/base/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"dataarray" functions take :class:`xarray.DataArray` as inputs.
"""

import warnings

import numpy as np
from arviz_base import rcParams
from xarray import DataArray, apply_ufunc, concat
Expand Down Expand Up @@ -193,5 +195,36 @@ def kde(self, da, dims=None, circular=False, grid_len=512, **kwargs):
out = concat((grid, pdf), dim=plot_axis)
return out.assign_coords({"bw" if da.name is None else f"bw_{da.name}": bw})

def thin(self, da, factor="auto", dims=None):
"""Perform thinning on DataArray input."""
if factor == "auto" and dims is not None:
warnings.warn("dims are ignored if factor is auto")

if factor == "auto":
n_samples = da.sizes["chain"] * da.sizes["draw"]
ess_ave = np.minimum(
self.ess(da, method="bulk", dims=["chain", "draw"]),
self.ess(da, method="tail", dims=["chain", "draw"]),
).mean()
factor = int(np.ceil(n_samples / ess_ave))
dims = "draw"

elif isinstance(factor, (float | int)):
if dims is None:
dims = rcParams["data.sample_dims"]
if not isinstance(dims, str):
if len(dims) >= 2:
raise ValueError("dims must be of length 1")
if len(dims) == 1:
dims = dims[0]

factor = int(factor)
if factor == 1:
return da
if factor < 1:
raise ValueError("factor must be greater than 1")

return da.sel({dims: slice(None, None, factor)})


dataarray_stats = BaseDataArray(array_class=array_stats)

0 comments on commit 1fb6c7d

Please sign in to comment.