diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index e22fee9..d3e824c 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -35,6 +35,29 @@ def update_kwargs_with_dims(da, kwargs): return kwargs +def check_var_name_subset(obj, var_name): + if isinstance(obj, xr.Dataset): + return obj[var_name] + if isinstance(obj, DataTree): + return obj.ds[var_name] + return obj + + +def apply_function_to_dataset(func, ds, kwargs): + return xr.Dataset( + { + var_name: func( + da, + **{ + key: check_var_name_subset(value, var_name) + for key, value in update_kwargs_with_dims(da, kwargs).items() + }, + ) + for var_name, da in ds.items() + } + ) + + unset = UnsetDefault() @@ -121,12 +144,7 @@ def _apply(self, fun, **kwargs): """Apply a function to all variables subsetting dims to existing dimensions.""" if isinstance(fun, str): fun = get_function(fun) - return xr.Dataset( - { - var_name: fun(da, **update_kwargs_with_dims(da, kwargs)) - for var_name, da in self._obj.items() - } - ) + return apply_function_to_dataset(fun, self._obj, kwargs=kwargs) def eti(self, prob=None, dims=None, **kwargs): """Compute the equal tail interval of all the variables in the dataset.""" @@ -154,8 +172,12 @@ def kde(self, dims=None, **kwargs): """Compute the KDE for all variables in the dataset.""" return self._apply("kde", dims=dims, **kwargs) + def get_bins(self, dims=None, **kwargs): + """Compute the histogram bin edges for all variables in the dataset.""" + return self._apply(get_function("get_bins"), dims=dims, **kwargs) + def histogram(self, dims=None, **kwargs): - """Compute the KDE for all variables in the dataset.""" + """Compute the histogram for all variables in the dataset.""" return self._apply("histogram", dims=dims, **kwargs) def compute_ranks(self, dims=None, relative=False): @@ -219,25 +241,19 @@ def _process_input(self, group, method, allow_non_matching=True): f"and the DataTree itself is named {self._obs.name}" ) - def _apply(self, fun_name, group, **kwargs): + def _apply(self, func_name, group, **kwargs): hashable_group = False if isinstance(group, Hashable): group = [group] hashable_group = True out_dt = DataTree.from_dict( { - group_i: xr.Dataset( - { - var_name: get_function(fun_name)(da, **update_kwargs_with_dims(da, kwargs)) - for var_name, da in self._process_input( - # if group is a single str/hashable that doesn't match the group - # name, still allow it and apply the function to the top level of - # the provided datatree - group_i, - fun_name, - allow_non_matching=hashable_group, - ).items() - } + group_i: apply_function_to_dataset( + get_function(func_name), + # if group is a single str/hashable that doesn't match the group name, + # still allow it and apply the function to the top level of the provided input + self._process_input(group_i, func_name, allow_non_matching=hashable_group), + kwargs=kwargs, ) for group_i in group } diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index 3b0f28a..00c7c32 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -179,7 +179,7 @@ def compute_ranks(self, ary, axes=-1, relative=False): ) return compute_ranks_ufunc(ary, out_shape=(ary.shape[i] for i in axes), relative=relative) - def get_bins(self, ary, axes=-1): + def get_bins(self, ary, axes=-1, bins="arviz"): """Compute default bins.""" ary, axes = process_ary_axes(ary, axes) get_bininfo_ufunc = make_ufunc( @@ -188,10 +188,11 @@ def get_bins(self, ary, axes=-1): n_input=1, n_dims=len(axes), ) - x_min, x_max, width = get_bininfo_ufunc(ary) + # TODO: improve handling of array_like bins + x_min, x_max, width = get_bininfo_ufunc(ary, bins=bins) n_bins = np.ceil((x_max - x_min) / width) n_bins = np.ceil(np.mean(n_bins)).astype(int) - return np.moveaxis(np.linspace(x_min, x_max, n_bins), 0, -1) + return np.moveaxis(np.linspace(x_min, x_max, n_bins + 1), 0, -1) # pylint: disable=redefined-builtin, too-many-return-statements # noqa: PLR0911 diff --git a/src/arviz_stats/base/core.py b/src/arviz_stats/base/core.py index 4f316f5..6c83bc2 100644 --- a/src/arviz_stats/base/core.py +++ b/src/arviz_stats/base/core.py @@ -137,9 +137,15 @@ def _compute_ranks(self, ary, relative=False): return out / out.size return out - def _get_bininfo(self, values): + def _get_bininfo(self, values, bins="arviz"): dtype = values.dtype.kind + if isinstance(bins, str) and bins != "arviz": + bins = np.histogram_bin_edges(values, bins=bins) + + if isinstance(bins, np.ndarray): + return bins[0], bins[-1], bins[1] - bins[0] + if dtype == "i": x_min = values.min().astype(int) x_max = values.max().astype(int) @@ -147,6 +153,12 @@ def _get_bininfo(self, values): x_min = values.min().astype(float) x_max = values.max().astype(float) + if isinstance(bins, int): + width = (x_max - x_min) / bins + if dtype == "i": + width = max(1, width) + return x_min, x_max, width + # Sturges histogram bin estimator width_sturges = (x_max - x_min) / (np.log2(values.size) + 1) @@ -161,13 +173,20 @@ def _get_bininfo(self, values): return x_min, x_max, width - def _get_bins(self, values): + def _get_bins(self, values, bins="arviz"): """ Automatically compute the number of bins for histograms. Parameters ---------- - values = array_like + values : array_like + bins : int, str or array_like, default "arviz" + If `bins` "arviz", use ArviZ default rule (explained in detail in notes), + if it is a different string it is passed to :func:`numpy.histogram_bin_edges`. + If `bins` is an integer it is interpreted as the number of bins, however, + if `values` holds discrete data, there is an extra check to prevent + the width of the bins to be smaller than ``1``. + If it is an array it is returned as it. Returns ------- @@ -190,7 +209,7 @@ def _get_bins(self, values): """ dtype = values.dtype.kind - x_min, x_max, width = self._get_bininfo(values) + x_min, x_max, width = self._get_bininfo(values, bins) if dtype == "i": bins = np.arange(x_min, x_max + width + 1, width) diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py index 5008c34..c5b0815 100644 --- a/src/arviz_stats/base/dataarray.py +++ b/src/arviz_stats/base/dataarray.py @@ -115,17 +115,31 @@ def mcse(self, da, dims=None, method="mean", prob=None): }, ) + def get_bins(self, da, dims=None, bins="arviz"): + """Compute bins or align provided ones with DataArray input.""" + dims = validate_dims(dims) + return apply_ufunc( + self.array_class.get_bins, + da, + input_core_dims=[dims], + output_core_dims=[["edges_dim" if da.name is None else f"edges_dim_{da.name}"]], + kwargs={ + "bins": bins, + "axes": np.arange(-len(dims), 0, 1), + }, + ) + # pylint: disable=redefined-builtin def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=None): """Compute histogram on DataArray input.""" dims = validate_dims(dims) - edges_dim = "edges_dim" - hist_dim = "hist_dim" + edges_dim = "edges_dim" if da.name is None else f"edges_dim_{da.name}" + hist_dim = "hist_dim" if da.name is None else f"hist_dim_{da.name}" input_core_dims = [dims] if isinstance(bins, DataArray): - bins_dims = [dim for dim in bins.dims if dim not in dims + ["plot_axis"]] - assert len(bins_dims) == 1 if "plot_axis" in bins.dims: + bins_dims = [dim for dim in bins.dims if dim not in dims + ["plot_axis"]] + assert len(bins_dims) == 1 hist_dim = bins_dims[0] bins = ( concat( @@ -138,8 +152,11 @@ def histogram(self, da, dims=None, bins=None, range=None, weights=None, density= .rename({hist_dim: edges_dim}) .drop_vars(edges_dim) ) - else: - edges_dim = bins_dims[0] + elif edges_dim not in bins.dims: + raise ValueError( + "Invalid 'bins' DataArray, it should contain either 'plot_axis' or " + f"'{edges_dim}' dimension" + ) input_core_dims.append([edges_dim]) else: input_core_dims.append([])