Skip to content

Commit

Permalink
get_bins dataarray version (#21)
Browse files Browse the repository at this point in the history
* get_bins dataarray draft

* behaviour fixes

* pre-commit

* remove repeated .items()

* ruff
  • Loading branch information
OriolAbril authored Oct 25, 2024
1 parent feb3084 commit 5b110df
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 33 deletions.
56 changes: 36 additions & 20 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,29 @@ def update_kwargs_with_dims(da, kwargs):
return kwargs


def check_var_name_subset(obj, var_name):
    """Pick the `var_name` entry out of `obj` when `obj` holds multiple variables.

    Parameters
    ----------
    obj : any
        Value to inspect. Only :class:`xarray.Dataset` and ``DataTree``
        instances are indexed; anything else is passed through untouched.
    var_name : hashable
        Name of the variable to extract.

    Returns
    -------
    any
        ``obj[var_name]`` for a Dataset, ``obj.ds[var_name]`` for a DataTree,
        otherwise `obj` itself.
    """
    if isinstance(obj, xr.Dataset):
        container = obj
    elif isinstance(obj, DataTree):
        # a DataTree node exposes its variables through the ``.ds`` view
        container = obj.ds
    else:
        return obj
    return container[var_name]


def apply_function_to_dataset(func, ds, kwargs):
    """Call `func` on every variable of `ds` and gather the results in a Dataset.

    Parameters
    ----------
    func : callable
        Called as ``func(da, **var_kwargs)`` once per variable.
    ds : mapping of {hashable: DataArray}
        Object whose ``.items()`` yields ``(var_name, da)`` pairs.
    kwargs : dict
        Keyword arguments for `func`. They are first passed through
        ``update_kwargs_with_dims`` for the current variable, then every value
        goes through ``check_var_name_subset`` so that Dataset/DataTree valued
        arguments are reduced to the entry matching the current variable.

    Returns
    -------
    xarray.Dataset
        One entry per variable in `ds`, holding the output of `func`.
    """
    results = {}
    for var_name, da in ds.items():
        da_kwargs = update_kwargs_with_dims(da, kwargs)
        var_kwargs = {
            key: check_var_name_subset(value, var_name) for key, value in da_kwargs.items()
        }
        results[var_name] = func(da, **var_kwargs)
    return xr.Dataset(results)


unset = UnsetDefault()


Expand Down Expand Up @@ -121,12 +144,7 @@ def _apply(self, fun, **kwargs):
"""Apply a function to all variables subsetting dims to existing dimensions."""
if isinstance(fun, str):
fun = get_function(fun)
return xr.Dataset(
{
var_name: fun(da, **update_kwargs_with_dims(da, kwargs))
for var_name, da in self._obj.items()
}
)
return apply_function_to_dataset(fun, self._obj, kwargs=kwargs)

def eti(self, prob=None, dims=None, **kwargs):
"""Compute the equal tail interval of all the variables in the dataset."""
Expand Down Expand Up @@ -154,8 +172,12 @@ def kde(self, dims=None, **kwargs):
"""Compute the KDE for all variables in the dataset."""
return self._apply("kde", dims=dims, **kwargs)

def get_bins(self, dims=None, **kwargs):
    """Compute the histogram bin edges for all variables in the dataset.

    Parameters
    ----------
    dims : optional
        Forwarded to the underlying ``get_bins`` function as ``dims``.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    Whatever ``self._apply`` returns for the ``"get_bins"`` function.
    """
    # Pass the function name like the sibling methods (kde, histogram) do;
    # ``_apply`` resolves string names through ``get_function`` itself.
    return self._apply("get_bins", dims=dims, **kwargs)

def histogram(self, dims=None, **kwargs):
"""Compute the KDE for all variables in the dataset."""
"""Compute the histogram for all variables in the dataset."""
return self._apply("histogram", dims=dims, **kwargs)

def compute_ranks(self, dims=None, relative=False):
Expand Down Expand Up @@ -219,25 +241,19 @@ def _process_input(self, group, method, allow_non_matching=True):
f"and the DataTree itself is named {self._obs.name}"
)

def _apply(self, fun_name, group, **kwargs):
def _apply(self, func_name, group, **kwargs):
hashable_group = False
if isinstance(group, Hashable):
group = [group]
hashable_group = True
out_dt = DataTree.from_dict(
{
group_i: xr.Dataset(
{
var_name: get_function(fun_name)(da, **update_kwargs_with_dims(da, kwargs))
for var_name, da in self._process_input(
# if group is a single str/hashable that doesn't match the group
# name, still allow it and apply the function to the top level of
# the provided datatree
group_i,
fun_name,
allow_non_matching=hashable_group,
).items()
}
group_i: apply_function_to_dataset(
get_function(func_name),
# if group is a single str/hashable that doesn't match the group name,
# still allow it and apply the function to the top level of the provided input
self._process_input(group_i, func_name, allow_non_matching=hashable_group),
kwargs=kwargs,
)
for group_i in group
}
Expand Down
7 changes: 4 additions & 3 deletions src/arviz_stats/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def compute_ranks(self, ary, axes=-1, relative=False):
)
return compute_ranks_ufunc(ary, out_shape=(ary.shape[i] for i in axes), relative=relative)

def get_bins(self, ary, axes=-1):
def get_bins(self, ary, axes=-1, bins="arviz"):
"""Compute default bins."""
ary, axes = process_ary_axes(ary, axes)
get_bininfo_ufunc = make_ufunc(
Expand All @@ -188,10 +188,11 @@ def get_bins(self, ary, axes=-1):
n_input=1,
n_dims=len(axes),
)
x_min, x_max, width = get_bininfo_ufunc(ary)
# TODO: improve handling of array_like bins
x_min, x_max, width = get_bininfo_ufunc(ary, bins=bins)
n_bins = np.ceil((x_max - x_min) / width)
n_bins = np.ceil(np.mean(n_bins)).astype(int)
return np.moveaxis(np.linspace(x_min, x_max, n_bins), 0, -1)
return np.moveaxis(np.linspace(x_min, x_max, n_bins + 1), 0, -1)

# pylint: disable=redefined-builtin, too-many-return-statements
# noqa: PLR0911
Expand Down
27 changes: 23 additions & 4 deletions src/arviz_stats/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,28 @@ def _compute_ranks(self, ary, relative=False):
return out / out.size
return out

def _get_bininfo(self, values):
def _get_bininfo(self, values, bins="arviz"):
dtype = values.dtype.kind

if isinstance(bins, str) and bins != "arviz":
bins = np.histogram_bin_edges(values, bins=bins)

if isinstance(bins, np.ndarray):
return bins[0], bins[-1], bins[1] - bins[0]

if dtype == "i":
x_min = values.min().astype(int)
x_max = values.max().astype(int)
else:
x_min = values.min().astype(float)
x_max = values.max().astype(float)

if isinstance(bins, int):
width = (x_max - x_min) / bins
if dtype == "i":
width = max(1, width)
return x_min, x_max, width

# Sturges histogram bin estimator
width_sturges = (x_max - x_min) / (np.log2(values.size) + 1)

Expand All @@ -161,13 +173,20 @@ def _get_bininfo(self, values):

return x_min, x_max, width

def _get_bins(self, values):
def _get_bins(self, values, bins="arviz"):
"""
Automatically compute the number of bins for histograms.
Parameters
----------
values = array_like
values : array_like
bins : int, str or array_like, default "arviz"
If `bins` is "arviz", use the ArviZ default rule (explained in detail in notes),
if it is a different string it is passed to :func:`numpy.histogram_bin_edges`.
If `bins` is an integer it is interpreted as the number of bins, however,
if `values` holds discrete data, there is an extra check to prevent
the width of the bins to be smaller than ``1``.
If it is an array it is returned as is.
Returns
-------
Expand All @@ -190,7 +209,7 @@ def _get_bins(self, values):
"""
dtype = values.dtype.kind

x_min, x_max, width = self._get_bininfo(values)
x_min, x_max, width = self._get_bininfo(values, bins)

if dtype == "i":
bins = np.arange(x_min, x_max + width + 1, width)
Expand Down
29 changes: 23 additions & 6 deletions src/arviz_stats/base/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,31 @@ def mcse(self, da, dims=None, method="mean", prob=None):
},
)

def get_bins(self, da, dims=None, bins="arviz"):
    """Compute bins or align provided ones with DataArray input.

    Parameters
    ----------
    da : DataArray
        Input data. Its ``name`` is used to build the name of the output dimension.
    dims : optional
        Dimensions to reduce, normalized through ``validate_dims``.
    bins : default "arviz"
        Forwarded unchanged to ``self.array_class.get_bins``.

    Returns
    -------
    Result of ``apply_ufunc``, with the reduced `dims` replaced by a single
    dimension named ``"edges_dim"`` when ``da.name`` is None and
    ``f"edges_dim_{da.name}"`` otherwise.
    """
    dims = validate_dims(dims)
    edges_dim = "edges_dim" if da.name is None else f"edges_dim_{da.name}"
    # negative indices of the trailing ``len(dims)`` axes
    # NOTE(review): presumably apply_ufunc moves the core dims last -- confirm
    reduce_axes = np.arange(-len(dims), 0, 1)
    return apply_ufunc(
        self.array_class.get_bins,
        da,
        input_core_dims=[dims],
        output_core_dims=[[edges_dim]],
        kwargs={"bins": bins, "axes": reduce_axes},
    )

# pylint: disable=redefined-builtin
def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=None):
"""Compute histogram on DataArray input."""
dims = validate_dims(dims)
edges_dim = "edges_dim"
hist_dim = "hist_dim"
edges_dim = "edges_dim" if da.name is None else f"edges_dim_{da.name}"
hist_dim = "hist_dim" if da.name is None else f"hist_dim_{da.name}"
input_core_dims = [dims]
if isinstance(bins, DataArray):
bins_dims = [dim for dim in bins.dims if dim not in dims + ["plot_axis"]]
assert len(bins_dims) == 1
if "plot_axis" in bins.dims:
bins_dims = [dim for dim in bins.dims if dim not in dims + ["plot_axis"]]
assert len(bins_dims) == 1
hist_dim = bins_dims[0]
bins = (
concat(
Expand All @@ -138,8 +152,11 @@ def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=
.rename({hist_dim: edges_dim})
.drop_vars(edges_dim)
)
else:
edges_dim = bins_dims[0]
elif edges_dim not in bins.dims:
raise ValueError(
"Invalid 'bins' DataArray, it should contain either 'plot_axis' or "
f"'{edges_dim}' dimension"
)
input_core_dims.append([edges_dim])
else:
input_core_dims.append([])
Expand Down

0 comments on commit 5b110df

Please sign in to comment.