diff --git a/src/arviz_stats/__init__.py b/src/arviz_stats/__init__.py
index 54c1bcc..b791f63 100644
--- a/src/arviz_stats/__init__.py
+++ b/src/arviz_stats/__init__.py
@@ -4,8 +4,7 @@
 try:
     from arviz_stats.utils import *
     from arviz_stats.accessors import *
+    from arviz_stats.psens import psens, psens_summary
 except ModuleNotFoundError:
     pass
-
-from arviz_stats.psens import *
diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py
index 89ee83f..3b2f4c9 100644
--- a/src/arviz_stats/accessors.py
+++ b/src/arviz_stats/accessors.py
@@ -1,6 +1,7 @@
 """ArviZ stats accessors."""
 
 import warnings
+from collections.abc import Hashable
 
 import xarray as xr
 from arviz_base.utils import _var_names
@@ -159,6 +160,14 @@ def pareto_min_ss(self, dims=None):
         """Compute the min sample size for all variables in the dataset."""
         return self._apply("pareto_min_ss", dims=dims)
 
+    def power_scale_lw(self, dims=None, **kwargs):
+        """Compute log weights for power-scaling of the DataTree."""
+        return self._apply("power_scale_lw", dims=dims, **kwargs)
+
+    def power_scale_sens(self, dims=None, **kwargs):
+        """Compute power-scaling sensitivity."""
+        return self._apply("power_scale_sens", dims=dims, **kwargs)
+
 
 @register_datatree_accessor("azstats")
 class AzStatsDtAccessor(_BaseAccessor):
@@ -176,9 +185,11 @@ def _process_input(self, group, method):
         return self._obj
 
     def _apply(self, fun_name, dims, group, **kwargs):
-        if isinstance(group, str):
+        hashable_group = False
+        if isinstance(group, Hashable):
             group = [group]
-        return DataTree.from_dict(
+            hashable_group = True
+        out_dt = DataTree.from_dict(
             {
                 group_i: xr.Dataset(
                     {
@@ -189,6 +200,13 @@ def _apply(self, fun_name, dims, group, **kwargs):
                 for group_i in group
             }
         )
+        if hashable_group:
+            # if group was a string/hashable, return a datatree with a single node
+            # (from the provided group) as the root of the DataTree
+            return out_dt[group[0]]
+        # if group was a sequence, return a DataTree with multiple groups in the 1st level,
+        # as many groups as requested
+        return out_dt
 
     def filter_vars(self, group="posterior", var_names=None, filter_vars=None):
         """Access and filter variables of the provided group."""
diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py
index 7f2b580..e617c6d 100644
--- a/src/arviz_stats/base/array.py
+++ b/src/arviz_stats/base/array.py
@@ -149,6 +149,7 @@ def power_scale_lw(self, ary, alpha=0, axes=-1):
             n_output=1,
             n_input=1,
             n_dims=len(axes),
+            ravel=False,
         )
         return psl_ufunc(ary, out_shape=(ary.shape[i] for i in axes), alpha=alpha)
@@ -156,10 +157,14 @@ def power_scale_sens(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
         """Compute power-scaling sensitivity."""
         if chain_axis is None:
             ary = np.expand_dims(ary, axis=0)
+            lower_w = np.expand_dims(lower_w, axis=0)
+            upper_w = np.expand_dims(upper_w, axis=0)
             chain_axis = 0
         ary, _ = process_ary_axes(ary, [chain_axis, draw_axis])
-        pss_array = make_ufunc(self._power_scale_sens, n_output=1, n_input=1, n_dims=2, ravel=False)
-        return pss_array(ary, lower_w=lower_w, upper_w=upper_w, delta=delta)
+        lower_w, _ = process_ary_axes(lower_w, [chain_axis, draw_axis])
+        upper_w, _ = process_ary_axes(upper_w, [chain_axis, draw_axis])
+        pss_array = make_ufunc(self._power_scale_sens, n_output=1, n_input=3, n_dims=2, ravel=False)
+        return pss_array(ary, lower_w, upper_w, delta=delta)
 
     def compute_ranks(self, ary, axes=-1, relative=False):
         """Compute ranks of MCMC samples."""
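For orientation, the accessor methods added above are what `psens` calls internally. A minimal sketch of how they might be used directly on a DataArray of per-draw log-prior values (the toy data and alpha value are made up; importing `arviz_stats` is assumed to register the `azstats` accessor):

```python
import numpy as np
import xarray as xr

import arviz_stats  # noqa: F401  # assumed to register the .azstats accessor on import

# toy per-draw log-prior values with the usual (chain, draw) sample dims
rng = np.random.default_rng(0)
log_prior = xr.DataArray(rng.normal(size=(4, 100)), dims=("chain", "draw"))

# log importance weights for power-scaling by alpha, normalized over the sample dims,
# mirroring what psens() does for lower_alpha and upper_alpha
alpha = 1 / (1 + 0.01)
log_w = log_prior.azstats.power_scale_lw(alpha=alpha, dims=["chain", "draw"])
weights = np.exp(log_w)
weights = weights / weights.sum(["chain", "draw"])
```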
diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py
index 7fe8cb2..3689088 100644
--- a/src/arviz_stats/base/dataarray.py
+++ b/src/arviz_stats/base/dataarray.py
@@ -7,7 +7,7 @@
 import numpy as np
 from arviz_base import rcParams
-from xarray import DataArray, apply_ufunc, concat
+from xarray import DataArray, apply_ufunc, broadcast, concat
 from xarray_einstats.stats import _apply_nonreduce_func
 
 from arviz_stats.base.array import array_stats
@@ -243,28 +243,26 @@ def pareto_min_ss(self, da, dims=None):
 
     def power_scale_lw(self, da, alpha=0, dims=None):
         """Compute log weights for power-scaling component by alpha."""
-        if dims is None:
-            dims = rcParams["data.sample_dims"]
+        dims = validate_dims(dims)
         return apply_ufunc(
             self.array_class.power_scale_lw,
             da,
             alpha,
             input_core_dims=[dims, []],
             output_core_dims=[dims],
+            kwargs={"axes": np.arange(-len(dims), 0, 1)},
         )
 
     def power_scale_sens(self, da, lower_w, upper_w, delta, dims=None):
         """Compute power-scaling sensitivity."""
-        if dims is None:
-            dims = rcParams["data.sample_dims"]
+        dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(dims)
         return apply_ufunc(
             self.array_class.power_scale_sens,
-            da,
-            lower_w,
-            upper_w,
+            *broadcast(da, lower_w, upper_w),
             delta,
-            input_core_dims=[dims, [], [], []],
+            input_core_dims=[dims, dims, dims, []],
             output_core_dims=[[]],
+            kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
         )
diff --git a/src/arviz_stats/base/diagnostics.py b/src/arviz_stats/base/diagnostics.py
index b4ce603..884a412 100644
--- a/src/arviz_stats/base/diagnostics.py
+++ b/src/arviz_stats/base/diagnostics.py
@@ -550,7 +550,7 @@ def _gpinv(probs, kappa, sigma, mu):
         return q
 
-    def _power_scale_sens(self, ary, lower_w=None, upper_w=None, delta=0.01):
+    def _power_scale_sens(self, ary, lower_w, upper_w, delta=0.01):
         """Compute power-scaling sensitivity by finite difference second derivative of CJS."""
         ary = np.ravel(ary)
         lower_w = np.ravel(lower_w)
@@ -562,6 +562,7 @@ def _power_scale_sens(self, ary, lower_w=None, upper_w=None, delta=0.01):
 
     def _power_scale_lw(self, ary, alpha):
         """Compute log weights for power-scaling component by alpha."""
+        shape = ary.shape
         ary = np.ravel(ary)
         log_weights = (alpha - 1) * ary
         n_draws = len(log_weights)
@@ -575,7 +576,7 @@ def _power_scale_lw(self, ary, alpha):
             log_weights=True,
         )
 
-        return log_weights
+        return log_weights.reshape(shape)
 
     @staticmethod
     def _cjs_dist(ary, weights):
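The `_power_scale_lw` helper patched above implements the core of the power-scaling idea from Kallioinen et al.: raising a component p(phi) to a power alpha is equivalent to reweighting the existing draws with importance weights w_i proportional to p(phi_i)**(alpha - 1). A rough standalone sketch of that weighting, leaving out the Pareto smoothing of the log weights that the real helper applies (all values below are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
log_prior = rng.normal(size=(4, 100))  # toy per-draw log-prior values, shape (chain, draw)

alpha = 1 + 0.01  # upper_alpha for delta=0.01, as used in psens()
log_w = (alpha - 1) * log_prior  # same core expression as in _power_scale_lw
log_w = log_w - np.max(log_w)  # stabilize before exponentiating
weights = np.exp(log_w)
weights = weights / weights.sum()  # normalized importance weights over all draws
```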
diff --git a/src/arviz_stats/psens.py b/src/arviz_stats/psens.py
index abb57d7..c9fc01e 100644
--- a/src/arviz_stats/psens.py
+++ b/src/arviz_stats/psens.py
@@ -1,6 +1,7 @@
 """Power-scaling sensitivity diagnostics."""
 
 import warnings
+from collections.abc import Hashable
 from typing import cast
 
 import numpy as np
@@ -10,12 +11,17 @@
 from arviz_base.labels import BaseLabeller
 from arviz_base.sel_utils import xarray_var_iter
 
+from arviz_stats.validate import validate_dims
+
 labeller = BaseLabeller()
 
+__all__ = ["psens", "psens_summary"]
+
 
 def psens(
     dt,
     group="log_prior",
+    sample_dims=None,
     component_var_names=None,
     component_coords=None,
     var_names=None,
@@ -75,44 +81,45 @@
     .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity
        with power-scaling*, 2022, https://arxiv.org/abs/2107.14054
     """
-    dataset = extract(dt, var_names=var_names, filter_vars=filter_vars, group="posterior")
-    if coords is None:
+    dataset = extract(
+        dt, var_names=var_names, filter_vars=filter_vars, group="posterior", combined=False
+    )
+    sample_dims = validate_dims(sample_dims)
+    if coords is not None:
         dataset = dataset.sel(coords)
 
     if group == "log_likelihood":
-        component_draws = get_log_likelihood(dt, var_name=component_var_names, single_var=False)
+        component_draws = get_log_likelihood_dataset(dt, var_names=component_var_names)
     elif group == "log_prior":
         component_draws = get_log_prior(dt, var_names=component_var_names)
     else:
         raise ValueError("Value for `group` argument not recognized")
 
-    component_draws = component_draws.stack(__sample__=("chain", "draw"))
-    if component_coords is None:
+    if component_coords is not None:
         component_draws = component_draws.sel(component_coords)
-    if isinstance(component_draws, xr.DataArray):
-        component_draws = component_draws.to_dataset()
-    if len(component_draws.dims):
-        component_draws = component_draws.to_stacked_array(
-            "latent-obs_var", sample_dims=("__sample__",)
-        ).sum("latent-obs_var")
-
-    component_draws = component_draws.unstack()
+    # we stack the different variables (if any) and dimensions in each variable (if any)
+    # into a flat dimension "latent-obs_var", over which we sum afterwards.
+    # Consequently, after this, component_draws is a dataarray with only sample_dims as dims
+    component_draws = component_draws.to_stacked_array(
+        "latent-obs_var", sample_dims=sample_dims
+    ).sum("latent-obs_var")
 
     # calculate lower and upper alpha values
     lower_alpha = 1 / (1 + delta)
     upper_alpha = 1 + delta
 
     # calculate importance sampling weights for lower and upper alpha power-scaling
-    lower_w = np.exp(component_draws.azstats.power_scale_lw(alpha=lower_alpha)).values.flatten()
-    lower_w = lower_w / np.sum(lower_w)
+    lower_w = np.exp(component_draws.azstats.power_scale_lw(alpha=lower_alpha, dims=sample_dims))
+    lower_w = lower_w / lower_w.sum(sample_dims)
 
-    upper_w = np.exp(component_draws.azstats.power_scale_lw(alpha=upper_alpha)).values.flatten()
-    upper_w = upper_w / np.sum(upper_w)
+    upper_w = np.exp(component_draws.azstats.power_scale_lw(alpha=upper_alpha, dims=sample_dims))
+    upper_w = upper_w / upper_w.sum(sample_dims)
 
-    return dt.azstats.power_scale_sens(
+    return dataset.azstats.power_scale_sens(
         lower_w=lower_w,
         upper_w=upper_w,
         delta=delta,
+        dims=sample_dims,
     )
@@ -138,8 +145,8 @@ def psens_summary(data, threshold=0.05, round_to=3):
           and the likelihood sensitivity is below the threshold
         - "-" otherwise
     """
-    pssdp = psens(data, group="log_prior")["posterior"].to_dataset()
-    pssdl = psens(data, group="log_likelihood")["posterior"].to_dataset()
+    pssdp = psens(data, group="log_prior")
+    pssdl = psens(data, group="log_likelihood")
 
     joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
         component=["prior", "likelihood"]
@@ -173,7 +180,7 @@ def _diagnose(row):
 
 
 # get_log_likelihood and get_log_prior functions should be somewhere else
-def get_log_likelihood(idata, var_name=None, single_var=True):
+def get_log_likelihood_dataset(idata, var_names=None):
     """Retrieve the log likelihood dataarray of a given variable."""
     if (
         not hasattr(idata, "log_likelihood")
@@ -184,21 +191,29 @@
             "Storing the log_likelihood in sample_stats groups has been deprecated",
             DeprecationWarning,
         )
-        return idata.sample_stats.log_likelihood
+        log_lik_ds = idata.sample_stats.ds[["log_likelihood"]]
     if not hasattr(idata, "log_likelihood"):
         raise TypeError("log likelihood not found in inference data object")
+    log_lik_ds = idata.log_likelihood.ds
+    if var_names is None:
+        return log_lik_ds
+    if isinstance(var_names, Hashable):
+        return log_lik_ds[[var_names]]
+    return log_lik_ds[var_names]
+
+
+def get_log_likelihood_dataarray(data, var_name=None):
+    log_lik_ds = get_log_likelihood_dataset(data)
     if var_name is None:
-        var_names = list(idata.log_likelihood.data_vars)
+        var_names = list(log_lik_ds.data_vars)
         if len(var_names) > 1:
-            if single_var:
-                raise TypeError(
-                    f"Found several log likelihood arrays {var_names}, var_name cannot be None"
-                )
-            return idata.log_likelihood[var_names]
-        return idata.log_likelihood[var_names[0]]
+            raise TypeError(
+                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
+            )
+        return log_lik_ds[var_names[0]]
     try:
-        log_likelihood = idata.log_likelihood[var_name]
+        log_likelihood = log_lik_ds[var_name]
     except KeyError as err:
         raise TypeError(f"No log likelihood data named {var_name} found") from err
     return log_likelihood
@@ -209,5 +224,7 @@ def get_log_prior(idata, var_names=None):
     if not hasattr(idata, "log_prior"):
         raise TypeError("log prior not found in inference data object")
     if var_names is None:
-        var_names = list(idata.log_prior.data_vars)
-    return idata.log_prior.to_dataset()[var_names]
+        return idata.log_prior.ds
+    if isinstance(var_names, Hashable):
+        return idata.log_prior.ds[[var_names]]
+    return idata.log_prior.ds[var_names]
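Taken together, the patch makes `psens` and `psens_summary` the public entry points re-exported from `arviz_stats`. A hypothetical call pattern, assuming an `idata` DataTree/InferenceData with posterior, log_likelihood and log_prior groups (it is not defined here):

```python
from arviz_stats import psens, psens_summary

# "idata" is assumed to exist already (e.g., sampled with a PPL and converted to a DataTree)

# per-variable power-scaling sensitivity of the posterior to the prior and to the likelihood
prior_sens = psens(idata, group="log_prior")
lik_sens = psens(idata, group="log_likelihood")

# combined summary of prior/likelihood sensitivities with a diagnosis column,
# flagging values above the 0.05 threshold
print(psens_summary(idata, threshold=0.05, round_to=3))
```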