add priorsense related functions
aloctavodia committed Sep 20, 2024
1 parent 686f893 commit 86f4a6d
Showing 6 changed files with 231 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/arviz_stats/__init__.py
@@ -4,5 +4,8 @@
try:
from arviz_stats.utils import *
from arviz_stats.accessors import *

except ModuleNotFoundError:
pass

from arviz_stats.psens import *
8 changes: 8 additions & 0 deletions src/arviz_stats/accessors.py
@@ -227,3 +227,11 @@ def thin(self, dims=None, group="posterior", **kwargs):
def pareto_min_ss(self, dims=None, group="posterior"):
"""Compute the min sample size for all variables in a group of the DataTree."""
return self._apply("pareto_min_ss", dims=dims, group=group)

def power_scale_lw(self, dims=None, group="log_likelihood", **kwargs):
"""Compute log weights for power-scaling of the DataTree."""
return self._apply("power_scale_lw", dims=dims, group=group, **kwargs)

def power_scale_sens(self, dims=None, group="posterior", **kwargs):
"""Compute power-scaling sensitivity."""
return self._apply("power_scale_sens", dims=dims, group=group, **kwargs)
20 changes: 20 additions & 0 deletions src/arviz_stats/base/array.py
@@ -141,6 +141,26 @@ def pareto_min_ss(self, ary, chain_axis=-2, draw_axis=-1):
pms_array = make_ufunc(self._pareto_min_ss, n_output=1, n_input=1, n_dims=2, ravel=False)
return pms_array(ary)

def power_scale_lw(self, ary, alpha=0, axes=-1):
"""Compute ranks of MCMC samples."""
ary, axes = process_ary_axes(ary, axes)
psl_ufunc = make_ufunc(
self._power_scale_lw,
n_output=1,
n_input=1,
n_dims=len(axes),
)
return psl_ufunc(ary, out_shape=tuple(ary.shape[i] for i in axes), alpha=alpha)

def power_scale_sens(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
"""Compute power-scaling sensitivity."""
if chain_axis is None:
ary = np.expand_dims(ary, axis=0)
chain_axis = 0
ary, _ = process_ary_axes(ary, [chain_axis, draw_axis])
pss_array = make_ufunc(self._power_scale_sens, n_output=1, n_input=1, n_dims=2, ravel=False)
return pss_array(ary, lower_w=lower_w, upper_w=upper_w, delta=delta)

def compute_ranks(self, ary, axes=-1, relative=False):
"""Compute ranks of MCMC samples."""
ary, axes = process_ary_axes(ary, axes)
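
At the array level the same operations act on bare NumPy arrays. A minimal sketch, assuming `array_stats` is the module-level BaseArray instance that dataarray.py wires in below (the import path is an assumption):

    import numpy as np
    from arviz_stats.base.array import array_stats

    rng = np.random.default_rng(0)
    log_dens = rng.normal(size=(4, 500))  # (chain, draw) log-density values
    draws = rng.normal(size=(4, 500))     # posterior draws, same layout

    # log weights for power-scaling, computed over both sample axes
    lw_low = array_stats.power_scale_lw(log_dens, alpha=1 / 1.01, axes=[-2, -1])
    lw_up = array_stats.power_scale_lw(log_dens, alpha=1.01, axes=[-2, -1])

    # the sensitivity helper expects one normalized weight per flattened draw
    w_low = np.exp(lw_low.ravel())
    w_low = w_low / w_low.sum()
    w_up = np.exp(lw_up.ravel())
    w_up = w_up / w_up.sum()
    sens = array_stats.power_scale_sens(draws, lower_w=w_low, upper_w=w_up, delta=0.01)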
26 changes: 26 additions & 0 deletions src/arviz_stats/base/dataarray.py
@@ -251,5 +251,31 @@ def pareto_min_ss(self, da, dims=None):
output_core_dims=[[]],
)

def power_scale_lw(self, da, alpha=0, dims=None):
"""Compute log weights for power-scaling component by alpha."""
if dims is None:
dims = rcParams["data.sample_dims"]
return apply_ufunc(
self.array_class.power_scale_lw,
da,
alpha,
input_core_dims=[dims, []],
output_core_dims=[dims],
)

def power_scale_sens(self, da, lower_w, upper_w, delta, dims=None):
"""Compute power-scaling sensitivity."""
if dims is None:
dims = rcParams["data.sample_dims"]
return apply_ufunc(
self.array_class.power_scale_sens,
da,
lower_w,
upper_w,
delta,
input_core_dims=[dims, [], [], []],
output_core_dims=[[]],
)


dataarray_stats = BaseDataArray(array_class=array_stats)
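
The DataArray layer wraps the array versions with apply_ufunc, so dims replace axes. A small sketch (the dimension names follow the usual sample dims; the DataArray itself is illustrative):

    import numpy as np
    import xarray as xr
    from arviz_stats.base.dataarray import dataarray_stats  # defined just above

    da = xr.DataArray(
        np.random.default_rng(0).normal(size=(4, 500)),
        dims=("chain", "draw"),
    )
    lw = dataarray_stats.power_scale_lw(da, alpha=1.1, dims=["chain", "draw"])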
88 changes: 80 additions & 8 deletions src/arviz_stats/base/diagnostics.py
@@ -370,6 +370,20 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):

n_draws = len(ary)

n_draws_tail = self._get_ps_tails(n_draws, r_eff, tail=tail)

if tail == "both":
khat = max(
self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
for t in ("left", "right")
)
else:
_, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)

return khat

@staticmethod
def _get_ps_tails(n_draws, r_eff, tail):
if n_draws > 255:
n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int)
else:
@@ -388,14 +402,7 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):
warnings.warn("Number of tail draws cannot be less than 5. Changing to 5")
n_draws_tail = 5

khat = max(
self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
for t in ("left", "right")
)
else:
_, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)

return khat
return n_draws_tail

def _ps_tail(
self, ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False
@@ -542,3 +549,68 @@ def _gpinv(probs, kappa, sigma, mu):
q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa

return q

def _power_scale_sens(self, ary, lower_w=None, upper_w=None, delta=0.01):
"""Compute power-scaling sensitivity by finite difference second derivative of CJS."""
ary = np.ravel(ary)
lower_cjs = max(self._cjs_dist(ary, lower_w), self._cjs_dist(-1 * ary, lower_w))
upper_cjs = max(self._cjs_dist(ary, upper_w), self._cjs_dist(-1 * ary, upper_w))
grad = (lower_cjs + upper_cjs) / (2 * np.log2(1 + delta))

return grad

def _power_scale_lw(self, ary, alpha):
"""Compute log weights for power-scaling component by alpha."""
ary = np.ravel(ary)
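# power-scaling a density as p(θ)^alpha scales its log density by alpha, so
# the importance ratio p^alpha / p contributes log weights (alpha - 1) * log p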
log_weights = (alpha - 1) * ary
n_draws = len(log_weights)
n_draws_tail = self._get_ps_tails(n_draws, 1, tail="both")
log_weights, _ = self._ps_tail(
log_weights,
n_draws,
n_draws_tail,
smooth_draws=True, # check!
log_weights=True,
)

return log_weights

@staticmethod
def _cjs_dist(ary, weights):
"""Calculate the cumulative Jensen-Shannon distance between original and weighted draws."""
# sort draws and weights
order = np.argsort(ary)
ary = ary[order]
weights = weights[order]

binwidth = np.diff(ary)

# ecdfs
cdf_p = np.linspace(1 / len(ary), 1 - 1 / len(ary), len(ary) - 1)
cdf_q = np.cumsum(weights / np.sum(weights))[:-1]

# integrals of ecdfs
cdf_p_int = np.dot(cdf_p, binwidth)
cdf_q_int = np.dot(cdf_q, binwidth)

# cjs calculation
pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)

denom = 0.5 * (cdf_p + cdf_q)
denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)

cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
cdf_q_int - cdf_p_int
)

cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
cdf_p_int - cdf_q_int
)

cjs_pq = max(0, cjs_pq)
cjs_qp = max(0, cjs_qp)

bound = cdf_p_int + cdf_q_int

return np.sqrt((cjs_pq + cjs_qp) / bound)
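
A quick numeric sanity check on the distance above: with uniform weights the weighted ECDF coincides with the plain ECDF, so the distance should be ~0, while tail-favoring weights give a clearly positive value. A sketch, assuming `diags` is an instance of the class defining _cjs_dist (the class name lies outside this hunk):

    import numpy as np

    rng = np.random.default_rng(42)
    draws = rng.normal(size=1_000)

    uniform_w = np.full(draws.size, 1 / draws.size)
    print(diags._cjs_dist(draws, uniform_w))  # ~0.0, ECDFs coincide

    skewed_w = np.exp(draws)  # overweight the right tail
    print(diags._cjs_dist(draws, skewed_w / skewed_w.sum()))  # clearly > 0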
94 changes: 94 additions & 0 deletions src/arviz_stats/psens.py
@@ -0,0 +1,94 @@
"""Power-scaling sensitivity diagnostics."""

from typing import cast

import numpy as np
import pandas as pd
import xarray as xr
from arviz_base.labels import BaseLabeller
from arviz_base.sel_utils import xarray_var_iter

labeller = BaseLabeller()


def psens(dt, group="log_likelihood"):
"""
Compute power-scaling sensitivity values.
dt : DataTree
group : str
"log_likelihood" or "log_prior".
"""
# calculate lower and upper alpha values
delta = 0.1
lower_alpha = 1 / (1 + delta)
upper_alpha = 1 + delta

# calculate importance sampling weights for lower and upper alpha power-scaling
lower_w = np.exp(dt.azstats.power_scale_lw(alpha=lower_alpha, group=group))
lower_w = lower_w / np.sum(lower_w)

upper_w = np.exp(dt.azstats.power_scale_lw(alpha=upper_alpha, group=group))
upper_w = upper_w / np.sum(upper_w)

# calculate the sensitivity diagnostic based on the importance weights and draws
return dt.azstats.power_scale_sens(
lower_w=lower_w[group]["obs"].values.flatten(), # FIXME
upper_w=upper_w[group]["obs"].values.flatten(), # FIXME
delta=delta,
)
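
A minimal usage sketch for this entry point (assuming `dt` is a DataTree with posterior, log_prior, and log_likelihood groups and, per the FIXME above, an observed variable named "obs"):

    prior_sens = psens(dt, group="log_prior")
    lik_sens = psens(dt, group="log_likelihood")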


def psens_summary(data, threshold=0.05, round_to=3):
"""
Compute the prior/likelihood sensitivity based on power-scaling perturbations.
Parameters
----------
data : DataTree
threshold : float, optional
Threshold value to determine the sensitivity diagnosis. Default is 0.05.
round_to : int, optional
Number of decimal places to round the sensitivity values. Default is 3.
Returns
-------
psens_df : DataFrame
DataFrame containing the prior and likelihood sensitivity values for each variable
in the data. And a diagnosis column with the following values:
- "prior-data conflict" if both prior and likelihood sensitivity are above threshold
- "strong prior / weak likelihood" if the prior sensitivity is above threshold
and the likelihood sensitivity is below the threshold
- "-" otherwise
"""
pssdp = psens(data, group="log_prior")
pssdl = psens(data, group="log_likelihood")

joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
component=["prior", "likelihood"]
)
n_vars = np.sum([joined[var].size // 2 for var in joined.data_vars])

psens_df = pd.DataFrame(
np.full((cast(int, n_vars), 2), np.nan), columns=["prior", "likelihood"]
)

indices = []
for i, (var_name, sel, isel, values) in enumerate(
xarray_var_iter(joined, skip_dims={"component"})
):
psens_df.iloc[i] = values
indices.append(labeller.make_label_flat(var_name, sel, isel))
psens_df.index = indices

def _diagnose(row):
if row["prior"] >= threshold and row["likelihood"] >= threshold:
return "prior-data conflict"
if row["prior"] > threshold > row["likelihood"]:
return "strong prior / weak likelihood"

return "-"

psens_df["diagnosis"] = psens_df.apply(_diagnose, axis=1)

return psens_df.round(round_to)
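
And the summary table built on top of it, under the same assumptions about `dt`:

    summary = psens_summary(dt, threshold=0.05, round_to=3)
    print(summary)
    # one row per (variable, coordinate) label, with columns
    # "prior", "likelihood", and "diagnosis"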
