From b8825930641c395825a1a15331e7f6bfd250866e Mon Sep 17 00:00:00 2001
From: aloctavodia
Date: Wed, 18 Sep 2024 16:07:53 -0300
Subject: [PATCH] add priorsense related functions

---
 src/arviz_stats/__init__.py         |   3 +
 src/arviz_stats/accessors.py        |  12 ++
 src/arviz_stats/base/array.py       |  23 ++-
 src/arviz_stats/base/dataarray.py   |  29 +++-
 src/arviz_stats/base/density.py     |   3 +-
 src/arviz_stats/base/diagnostics.py |  90 ++++++++++--
 src/arviz_stats/psens.py            | 213 ++++++++++++++++++++++++++++
 7 files changed, 362 insertions(+), 11 deletions(-)
 create mode 100644 src/arviz_stats/psens.py

diff --git a/src/arviz_stats/__init__.py b/src/arviz_stats/__init__.py
index 90e6214..54c1bcc 100644
--- a/src/arviz_stats/__init__.py
+++ b/src/arviz_stats/__init__.py
@@ -4,5 +4,8 @@
 try:
     from arviz_stats.utils import *
     from arviz_stats.accessors import *
+
 except ModuleNotFoundError:
     pass
+
+from arviz_stats.psens import *
diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py
index e0272b7..89ee83f 100644
--- a/src/arviz_stats/accessors.py
+++ b/src/arviz_stats/accessors.py
@@ -59,6 +59,10 @@ def pareto_min_ss(self, dims=None):
         """Compute the minimum effective sample size on the DataArray."""
         return get_function("pareto_min_ss")(self._obj, dims=dims)
 
+    def power_scale_lw(self, alpha=1, dims=None):
+        """Compute log weights for power-scaling of the DataArray."""
+        return get_function("power_scale_lw")(self._obj, alpha=alpha, dims=dims)
+
 
 @xr.register_dataset_accessor("azstats")
 class AzStatsDsAccessor(_BaseAccessor):
@@ -227,3 +231,11 @@ def thin(self, dims=None, group="posterior", **kwargs):
     def pareto_min_ss(self, dims=None, group="posterior"):
         """Compute the min sample size for all variables in a group of the DataTree."""
         return self._apply("pareto_min_ss", dims=dims, group=group)
+
+    def power_scale_lw(self, dims=None, group="log_likelihood", **kwargs):
+        """Compute log weights for power-scaling of the DataTree."""
+        return self._apply("power_scale_lw", dims=dims, group=group, **kwargs)
+
+    def power_scale_sens(self, dims=None, group="posterior", **kwargs):
+        """Compute power-scaling sensitivity."""
+        return self._apply("power_scale_sens", dims=dims, group=group, **kwargs)
diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py
index 050d322..7f2b580 100644
--- a/src/arviz_stats/base/array.py
+++ b/src/arviz_stats/base/array.py
@@ -141,6 +141,26 @@ def pareto_min_ss(self, ary, chain_axis=-2, draw_axis=-1):
         pms_array = make_ufunc(self._pareto_min_ss, n_output=1, n_input=1, n_dims=2, ravel=False)
         return pms_array(ary)
 
+    def power_scale_lw(self, ary, alpha=0, axes=-1):
+        """Compute log weights for power-scaling component by alpha."""
+        ary, axes = process_ary_axes(ary, axes)
+        psl_ufunc = make_ufunc(
+            self._power_scale_lw,
+            n_output=1,
+            n_input=1,
+            n_dims=len(axes),
+        )
+        return psl_ufunc(ary, out_shape=(ary.shape[i] for i in axes), alpha=alpha)
+
+    def power_scale_sens(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
+        """Compute power-scaling sensitivity."""
+        if chain_axis is None:
+            ary = np.expand_dims(ary, axis=0)
+            chain_axis = 0
+        ary, _ = process_ary_axes(ary, [chain_axis, draw_axis])
+        pss_array = make_ufunc(self._power_scale_sens, n_output=1, n_input=1, n_dims=2, ravel=False)
+        return pss_array(ary, lower_w=lower_w, upper_w=upper_w, delta=delta)
+
     def compute_ranks(self, ary, axes=-1, relative=False):
         """Compute ranks of MCMC samples."""
         ary, axes = process_ary_axes(ary, axes)
@@ -270,7 +290,7 @@ def histogram(self, ary, bins=None, range=None, weights=None, axes=-1, density=N
         )
         return histogram_ufunc(ary, bins, range, shape_from_1st=True)
 
-    def kde(self, ary, axes=-1, circular=False, grid_len=512, **kwargs):
+    def kde(self, ary, axes=-1, circular=False, grid_len=512, weights=None, **kwargs):
         """Compute of kde on array-like inputs."""
         ary, axes = process_ary_axes(ary, axes)
         kde_ufunc = make_ufunc(
@@ -284,6 +304,7 @@ def kde(self, ary, axes=-1, circular=False, grid_len=512, **kwargs):
             out_shape=((grid_len,), (grid_len,), ()),
             grid_len=grid_len,
             circular=circular,
+            weights=weights,
             **kwargs,
         )
 
diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py
index 4be1071..7fe8cb2 100644
--- a/src/arviz_stats/base/dataarray.py
+++ b/src/arviz_stats/base/dataarray.py
@@ -179,7 +179,7 @@ def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=
         )
         return out
 
-    def kde(self, da, dims=None, circular=False, grid_len=512, **kwargs):
+    def kde(self, da, dims=None, circular=False, grid_len=512, weights=None, **kwargs):
         """Compute kde on DataArray input."""
         dims = validate_dims(dims)
         grid, pdf, bw = apply_ufunc(
@@ -188,6 +188,7 @@ def kde(self, da, dims=None, circular=False, grid_len=512, **kwargs):
             kwargs={
                 "circular": circular,
                 "grid_len": grid_len,
+                "weights": weights,
                 "axes": np.arange(-len(dims), 0, 1),
                 **kwargs,
             },
@@ -240,5 +241,31 @@ def pareto_min_ss(self, da, dims=None):
             kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
         )
 
+    def power_scale_lw(self, da, alpha=0, dims=None):
+        """Compute log weights for power-scaling component by alpha."""
+        if dims is None:
+            dims = rcParams["data.sample_dims"]
+        return apply_ufunc(
+            self.array_class.power_scale_lw,
+            da,
+            alpha,
+            input_core_dims=[dims, []],
+            output_core_dims=[dims],
+        )
+
+    def power_scale_sens(self, da, lower_w, upper_w, delta, dims=None):
+        """Compute power-scaling sensitivity."""
+        if dims is None:
+            dims = rcParams["data.sample_dims"]
+        return apply_ufunc(
+            self.array_class.power_scale_sens,
+            da,
+            lower_w,
+            upper_w,
+            delta,
+            input_core_dims=[dims, [], [], []],
+            output_core_dims=[[]],
+        )
+
 
 dataarray_stats = BaseDataArray(array_class=array_stats)
diff --git a/src/arviz_stats/base/density.py b/src/arviz_stats/base/density.py
index fed1107..a6793ba 100644
--- a/src/arviz_stats/base/density.py
+++ b/src/arviz_stats/base/density.py
@@ -389,6 +389,7 @@ def kde_linear(
         bw_fct=1,
         custom_lims=None,
         cumulative=False,
+        weights=None,
         grid_len=512,
         **kwargs,  # pylint: disable=unused-argument
     ):
@@ -456,7 +457,7 @@ def kde_linear(
             x_min, x_max, x_std, extend_fct, grid_len, custom_lims, extend, bound_correction
         )
         grid_counts, grid_edges = self._histogram(
-            x, bins=grid_len, range=(grid_min, grid_max), density=False
+            x, bins=grid_len, weights=weights, range=(grid_min, grid_max), density=False
         )
 
         # Bandwidth estimation
diff --git a/src/arviz_stats/base/diagnostics.py b/src/arviz_stats/base/diagnostics.py
index da71b52..b4ce603 100644
--- a/src/arviz_stats/base/diagnostics.py
+++ b/src/arviz_stats/base/diagnostics.py
@@ -370,6 +370,20 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):
 
         n_draws = len(ary)
 
+        n_draws_tail = self._get_ps_tails(n_draws, r_eff, tail=tail)
+
+        if tail == "both":
+            khat = max(
+                self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
+                for t in ("left", "right")
+            )
+        else:
+            _, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)
+
+        return khat
+
+    @staticmethod
+    def _get_ps_tails(n_draws, r_eff, tail):
         if n_draws > 255:
             n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int)
         else:
@@ -388,14 +402,7 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):
                 warnings.warn("Number of tail draws cannot be less than 5. Changing to 5")
                 n_draws_tail = 5
 
-            khat = max(
-                self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
-                for t in ("left", "right")
-            )
-        else:
-            _, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)
-
-        return khat
+        return n_draws_tail
 
     def _ps_tail(
         self, ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False
@@ -542,3 +549,70 @@ def _gpinv(probs, kappa, sigma, mu):
             q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa
 
         return q
+
+    def _power_scale_sens(self, ary, lower_w=None, upper_w=None, delta=0.01):
+        """Compute power-scaling sensitivity by finite difference second derivative of CJS."""
+        ary = np.ravel(ary)
+        lower_w = np.ravel(lower_w)
+        upper_w = np.ravel(upper_w)
+        lower_cjs = max(self._cjs_dist(ary, lower_w), self._cjs_dist(-1 * ary, lower_w))
+        upper_cjs = max(self._cjs_dist(ary, upper_w), self._cjs_dist(-1 * ary, upper_w))
+        grad = (lower_cjs + upper_cjs) / (2 * np.log2(1 + delta))
+        return grad
+
+    def _power_scale_lw(self, ary, alpha):
+        """Compute log weights for power-scaling component by alpha."""
+        ary = np.ravel(ary)
+        log_weights = (alpha - 1) * ary
+        n_draws = len(log_weights)
+        r_eff = self._ess_tail(ary, relative=True)
+        n_draws_tail = self._get_ps_tails(n_draws, r_eff, tail="both")
+        log_weights, _ = self._ps_tail(
+            log_weights,
+            n_draws,
+            n_draws_tail,
+            smooth_draws=False,
+            log_weights=True,
+        )
+
+        return log_weights
+
+    @staticmethod
+    def _cjs_dist(ary, weights):
+        """Calculate the cumulative Jensen-Shannon distance between original and weighted draws."""
+        # sort draws and weights
+        order = np.argsort(ary)
+        ary = ary[order]
+        weights = weights[order]
+
+        binwidth = np.diff(ary)
+
+        # ecdfs
+        cdf_p = np.linspace(1 / len(ary), 1 - 1 / len(ary), len(ary) - 1)
+        cdf_q = np.cumsum(weights / np.sum(weights))[:-1]
+
+        # integrals of ecdfs
+        cdf_p_int = np.dot(cdf_p, binwidth)
+        cdf_q_int = np.dot(cdf_q, binwidth)
+
+        # cjs calculation
+        pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
+        qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)
+
+        denom = 0.5 * (cdf_p + cdf_q)
+        denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)
+
+        cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
+            cdf_q_int - cdf_p_int
+        )
+
+        cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
+            cdf_p_int - cdf_q_int
+        )
+
+        cjs_pq = max(0, cjs_pq)
+        cjs_qp = max(0, cjs_qp)
+
+        bound = cdf_p_int + cdf_q_int
+
+        return np.sqrt((cjs_pq + cjs_qp) / bound)
diff --git a/src/arviz_stats/psens.py b/src/arviz_stats/psens.py
new file mode 100644
index 0000000..abb57d7
--- /dev/null
+++ b/src/arviz_stats/psens.py
@@ -0,0 +1,213 @@
+"""Power-scaling sensitivity diagnostics."""
+
+import warnings
+from typing import cast
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+from arviz_base import extract
+from arviz_base.labels import BaseLabeller
+from arviz_base.sel_utils import xarray_var_iter
+
+labeller = BaseLabeller()
+
+
+def psens(
+    dt,
+    group="log_prior",
+    component_var_names=None,
+    component_coords=None,
+    var_names=None,
+    coords=None,
+    filter_vars=None,
+    delta=0.01,
+):
+    """
+    Compute power-scaling sensitivity values.
+
+    Parameters
+    ----------
+    dt : DataTree
+        Input data. It should contain the ``posterior`` group and the component
+        group selected with `group`: the pointwise ``log_likelihood`` values
+        and/or the ``log_prior`` values, each with ``chain`` and ``draw``
+        dimensions.
+    group : {"log_prior", "log_likelihood"}, default "log_prior"
+        When "log_likelihood", the log likelihood values are retrieved
+        from the ``log_likelihood`` group as pointwise log likelihood and added
+        together. When "log_prior", the log prior values are retrieved from the
+        ``log_prior`` group.
+    component_var_names : str, optional
+        Names of the prior or log likelihood variables to use.
+    component_coords : dict, optional
+        Coordinates defining a subset over the component element for which to
+        compute the prior sensitivity diagnostic.
+    var_names : list of str, optional
+        Names of posterior variables to include in the power-scaling sensitivity diagnostic.
+    coords : dict, optional
+        Coordinates defining a subset over the posterior. Only this subset will
+        be used when computing the sensitivity diagnostic.
+    filter_vars : {None, "like", "regex"}, default None
+        If ``None`` (default), interpret `var_names` as the real variable names.
+        If "like", interpret `var_names` as substrings of the real variable names.
+        If "regex", interpret `var_names` as regular expressions on the real variable names.
+    delta : float
+        Value for finite difference derivative calculation.
+
+
+    Returns
+    -------
+    xarray.DataTree
+        DataTree of power-scaling sensitivity diagnostic values.
+        Higher values indicate greater sensitivity.
+        Prior sensitivity above 0.05 indicates an informative prior.
+        Likelihood sensitivity below 0.05 indicates a weak or uninformative likelihood.
+
+    Notes
+    -----
+    The diagnostic is computed by power-scaling the specified component (prior or likelihood)
+    and determining the degree to which the posterior changes as described in [1]_.
+    It uses Pareto-smoothed importance sampling to avoid refitting the model.
+
+    References
+    ----------
+    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
+       power-scaling*, 2022, https://arxiv.org/abs/2107.14054
+    """
+    dataset = extract(dt, var_names=var_names, filter_vars=filter_vars, group="posterior")
+    if coords is not None:
+        dataset = dataset.sel(coords)
+
+    if group == "log_likelihood":
+        component_draws = get_log_likelihood(dt, var_name=component_var_names, single_var=False)
+    elif group == "log_prior":
+        component_draws = get_log_prior(dt, var_names=component_var_names)
+    else:
+        raise ValueError("Value for `group` argument not recognized")
+
+    component_draws = component_draws.stack(__sample__=("chain", "draw"))
+    if component_coords is not None:
+        component_draws = component_draws.sel(component_coords)
+    if isinstance(component_draws, xr.DataArray):
+        component_draws = component_draws.to_dataset()
+    if len(component_draws.dims):
+        component_draws = component_draws.to_stacked_array(
+            "latent-obs_var", sample_dims=("__sample__",)
+        ).sum("latent-obs_var")
+
+    component_draws = component_draws.unstack()
+
+    # calculate lower and upper alpha values
+    lower_alpha = 1 / (1 + delta)
+    upper_alpha = 1 + delta
+
+    # calculate importance sampling weights for lower and upper alpha power-scaling
+    lower_w = np.exp(component_draws.azstats.power_scale_lw(alpha=lower_alpha)).values.flatten()
+    lower_w = lower_w / np.sum(lower_w)
+
+    upper_w = np.exp(component_draws.azstats.power_scale_lw(alpha=upper_alpha)).values.flatten()
+    upper_w = upper_w / np.sum(upper_w)
+
+    return dt.azstats.power_scale_sens(
+        lower_w=lower_w,
+        upper_w=upper_w,
+        delta=delta,
+    )
+
+
+def psens_summary(data, threshold=0.05, round_to=3):
+    """
+    Compute the prior/likelihood sensitivity based on power-scaling perturbations.
+
+    Parameters
+    ----------
+    data : DataTree
+    threshold : float, optional
+        Threshold value to determine the sensitivity diagnosis. Default is 0.05.
+    round_to : int, optional
+        Number of decimal places to round the sensitivity values. Default is 3.
+
+    Returns
+    -------
+    psens_df : DataFrame
+        DataFrame containing the prior and likelihood sensitivity values for each variable
+        in the data, plus a diagnosis column with the following values:
+        - "prior-data conflict" if both prior and likelihood sensitivity are above threshold
+        - "strong prior / weak likelihood" if the prior sensitivity is above threshold
+          and the likelihood sensitivity is below the threshold
+        - "-" otherwise
+    """
+    pssdp = psens(data, group="log_prior")["posterior"].to_dataset()
+    pssdl = psens(data, group="log_likelihood")["posterior"].to_dataset()
+
+    joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
+        component=["prior", "likelihood"]
+    )
+
+    n_vars = np.sum([joined[var].size // 2 for var in joined.data_vars])
+
+    psens_df = pd.DataFrame(
+        (np.full((cast(int, n_vars), 2), np.nan)), columns=["prior", "likelihood"]
+    )
+
+    indices = []
+    for i, (var_name, sel, isel, values) in enumerate(
+        xarray_var_iter(joined, skip_dims={"component"})
+    ):
+        psens_df.iloc[i] = values
+        indices.append(labeller.make_label_flat(var_name, sel, isel))
+    psens_df.index = indices
+
+    def _diagnose(row):
+        if row["prior"] >= threshold and row["likelihood"] >= threshold:
+            return "prior-data conflict"
+        if row["prior"] > threshold > row["likelihood"]:
+            return "strong prior / weak likelihood"
+
+        return "-"
+
+    psens_df["diagnosis"] = psens_df.apply(_diagnose, axis=1)
+
+    return psens_df.round(round_to)
+
+
+# get_log_likelihood and get_log_prior functions should be somewhere else
+def get_log_likelihood(idata, var_name=None, single_var=True):
+    """Retrieve the log likelihood dataarray of a given variable."""
+    if (
+        not hasattr(idata, "log_likelihood")
+        and hasattr(idata, "sample_stats")
+        and hasattr(idata.sample_stats, "log_likelihood")
+    ):
+        warnings.warn(
+            "Storing the log_likelihood in sample_stats groups has been deprecated",
+            DeprecationWarning,
+        )
+        return idata.sample_stats.log_likelihood
+    if not hasattr(idata, "log_likelihood"):
+        raise TypeError("log likelihood not found in inference data object")
+    if var_name is None:
+        var_names = list(idata.log_likelihood.data_vars)
+        if len(var_names) > 1:
+            if single_var:
+                raise TypeError(
+                    f"Found several log likelihood arrays {var_names}, var_name cannot be None"
+                )
+            return idata.log_likelihood[var_names]
+        return idata.log_likelihood[var_names[0]]
+
+    try:
+        log_likelihood = idata.log_likelihood[var_name]
+    except KeyError as err:
+        raise TypeError(f"No log likelihood data named {var_name} found") from err
+    return log_likelihood
+
+
+def get_log_prior(idata, var_names=None):
+    """Retrieve the log prior dataarray of a given variable."""
+    if not hasattr(idata, "log_prior"):
+        raise TypeError("log prior not found in inference data object")
+    if var_names is None:
+        var_names = list(idata.log_prior.data_vars)
+    return idata.log_prior.to_dataset()[var_names]
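
Usage sketch (not part of the diff): a minimal example of calling the two public functions this patch adds, psens and psens_summary. The data construction is illustrative and rests on assumptions not established by the patch itself: that xarray.DataTree.from_dict is available in the installed xarray, and that a tree with "posterior", "log_prior", and pointwise "log_likelihood" groups (each with chain and draw dimensions) is all that is needed once arviz_stats is imported and its accessors are registered. The variable and dimension names (mu, obs, obs_id) and the synthetic log densities are made up for the example.

    # Illustrative sketch only; any converter that produces the same group layout would do.
    import numpy as np
    import xarray as xr

    import arviz_stats  # noqa: F401  (importing registers the .azstats accessors)
    from arviz_stats import psens, psens_summary

    rng = np.random.default_rng(0)
    chains, draws, n_obs = 4, 500, 20

    mu = rng.normal(size=(chains, draws))          # stand-in posterior draws
    obs = rng.normal(size=n_obs)                   # stand-in observed data

    posterior = xr.Dataset({"mu": (("chain", "draw"), mu)})
    log_prior = xr.Dataset({"mu": (("chain", "draw"), -0.5 * mu**2)})
    log_likelihood = xr.Dataset(
        {"obs": (("chain", "draw", "obs_id"), -0.5 * (obs - mu[..., None]) ** 2)}
    )

    dt = xr.DataTree.from_dict(
        {"posterior": posterior, "log_prior": log_prior, "log_likelihood": log_likelihood}
    )

    prior_sens = psens(dt, group="log_prior")       # sensitivity to power-scaling the prior
    lik_sens = psens(dt, group="log_likelihood")    # sensitivity to power-scaling the likelihood
    print(psens_summary(dt, threshold=0.05))        # per-variable table with a diagnosis column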
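
For context on the importance weights computed inside psens: power-scaling a component by alpha raises its density to the power alpha, so the importance weight of draw i is p(component_i)**(alpha - 1), i.e. log w_i = (alpha - 1) * log p(component_i). This is the quantity _power_scale_lw starts from before Pareto-smoothing the tail; the smoothing and the CJS-based sensitivity measure are in the diff above. A small numpy-only sketch of the unsmoothed, normalized weights, using the same lower_alpha/upper_alpha choices as psens:

    # Unsmoothed power-scaling importance weights, following the formulas in the patch:
    # lower_alpha = 1 / (1 + delta), upper_alpha = 1 + delta, log w = (alpha - 1) * log p.
    import numpy as np

    rng = np.random.default_rng(1)
    draws = rng.normal(size=1000)          # stand-in posterior draws of one parameter
    log_prior = -0.5 * draws**2            # stand-in pointwise log prior evaluated at the draws

    delta = 0.01
    lower_alpha = 1 / (1 + delta)          # slightly weaken the component
    upper_alpha = 1 + delta                # slightly strengthen the component

    lower_w = np.exp((lower_alpha - 1) * log_prior)
    upper_w = np.exp((upper_alpha - 1) * log_prior)
    lower_w /= lower_w.sum()               # normalized, as done in psens
    upper_w /= upper_w.sum()

    # Draws far in the tails of the component get the largest relative change in weight,
    # which is what the CJS-based sensitivity diagnostic picks up.
    print(lower_w.max(), upper_w.max(), 1 / len(draws))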