Skip to content

Commit

Permalink
add priorsense related functions
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored and OriolAbril committed Oct 5, 2024
1 parent a0ff6f5 commit b882593
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 11 deletions.
3 changes: 3 additions & 0 deletions src/arviz_stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@
# Re-export the public API of the utility and accessor modules.
# The guard keeps the package importable even when optional runtime
# dependencies of those modules are missing.
try:
    from arviz_stats.utils import *
    from arviz_stats.accessors import *

except ModuleNotFoundError:
    pass

# NOTE(review): psens is imported outside the guard, so a missing
# dependency there raises at import time — confirm this is intended.
from arviz_stats.psens import *
12 changes: 12 additions & 0 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def pareto_min_ss(self, dims=None):
"""Compute the minimum effective sample size on the DataArray."""
return get_function("pareto_min_ss")(self._obj, dims=dims)

def power_scale_lw(self, alpha=1, dims=None):
    """Compute log weights for power-scaling of the DataArray.

    Parameters
    ----------
    alpha : float, default 1
        Power-scaling exponent passed through to the backend function.
    dims : str or sequence of str, optional
        Dimensions to reduce over; interpretation is delegated to the
        registered ``power_scale_lw`` implementation.
    """
    # Docstring fixed: this accessor wraps a DataArray (see pareto_min_ss
    # above), not a DataTree — the old text was a copy-paste slip.
    return get_function("power_scale_lw")(self._obj, alpha=alpha, dims=dims)


@xr.register_dataset_accessor("azstats")
class AzStatsDsAccessor(_BaseAccessor):
Expand Down Expand Up @@ -227,3 +231,11 @@ def thin(self, dims=None, group="posterior", **kwargs):
def pareto_min_ss(self, dims=None, group="posterior"):
    """Compute the min sample size for all variables in a group of the DataTree."""
    # Dispatch through the shared per-group applier.
    result = self._apply("pareto_min_ss", dims=dims, group=group)
    return result

def power_scale_lw(self, dims=None, group="log_likelihood", **kwargs):
    """Compute log weights for power-scaling of the DataTree."""
    # Forward any extra keyword arguments (e.g. ``alpha``) untouched to
    # the per-group dispatcher.
    call_kwargs = {"dims": dims, "group": group, **kwargs}
    return self._apply("power_scale_lw", **call_kwargs)

def power_scale_sens(self, dims=None, group="posterior", **kwargs):
    """Compute power-scaling sensitivity."""
    # Forward extra keyword arguments (weights, delta, ...) untouched to
    # the per-group dispatcher.
    call_kwargs = {"dims": dims, "group": group, **kwargs}
    return self._apply("power_scale_sens", **call_kwargs)
23 changes: 22 additions & 1 deletion src/arviz_stats/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,26 @@ def pareto_min_ss(self, ary, chain_axis=-2, draw_axis=-1):
pms_array = make_ufunc(self._pareto_min_ss, n_output=1, n_input=1, n_dims=2, ravel=False)
return pms_array(ary)

def power_scale_lw(self, ary, alpha=0, axes=-1):
    """Compute log weights for power-scaling of array-like input by ``alpha``.

    Parameters
    ----------
    ary : array-like
        Samples (e.g. pointwise log-likelihood or log-prior values).
    alpha : float, default 0
        Power-scaling exponent forwarded to ``_power_scale_lw``.
    axes : int or sequence of int, default -1
        Axes holding the sample dimensions.
    """
    # Docstring fixed: it previously read "Compute ranks of MCMC samples",
    # a copy-paste slip from compute_ranks.
    ary, axes = process_ary_axes(ary, axes)
    psl_ufunc = make_ufunc(
        self._power_scale_lw,
        n_output=1,
        n_input=1,
        n_dims=len(axes),
    )
    # Materialize out_shape as a tuple: a generator would be exhausted
    # after a single iteration inside make_ufunc.
    return psl_ufunc(ary, out_shape=tuple(ary.shape[i] for i in axes), alpha=alpha)

def power_scale_sens(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
    """Compute power-scaling sensitivity."""
    # Promote chain-less input to a singleton chain dimension so the
    # 2-D ufunc machinery below can treat every input uniformly.
    if chain_axis is None:
        chain_axis = 0
        ary = np.expand_dims(ary, axis=chain_axis)
    ary, _ = process_ary_axes(ary, [chain_axis, draw_axis])
    sens_ufunc = make_ufunc(
        self._power_scale_sens,
        n_output=1,
        n_input=1,
        n_dims=2,
        ravel=False,
    )
    return sens_ufunc(ary, lower_w=lower_w, upper_w=upper_w, delta=delta)

def compute_ranks(self, ary, axes=-1, relative=False):
"""Compute ranks of MCMC samples."""
ary, axes = process_ary_axes(ary, axes)
Expand Down Expand Up @@ -270,7 +290,7 @@ def histogram(self, ary, bins=None, range=None, weights=None, axes=-1, density=N
)
return histogram_ufunc(ary, bins, range, shape_from_1st=True)

def kde(self, ary, axes=-1, circular=False, grid_len=512, **kwargs):
def kde(self, ary, axes=-1, circular=False, grid_len=512, weights=None, **kwargs):
"""Compute of kde on array-like inputs."""
ary, axes = process_ary_axes(ary, axes)
kde_ufunc = make_ufunc(
Expand All @@ -284,6 +304,7 @@ def kde(self, ary, axes=-1, circular=False, grid_len=512, **kwargs):
out_shape=((grid_len,), (grid_len,), ()),
grid_len=grid_len,
circular=circular,
weights=weights,
**kwargs,
)

Expand Down
29 changes: 28 additions & 1 deletion src/arviz_stats/base/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def histogram(self, da, dims=None, bins=None, range=None, weights=None, density=
)
return out

def kde(self, da, dims=None, circular=False, grid_len=512, **kwargs):
def kde(self, da, dims=None, circular=False, grid_len=512, weights=None, **kwargs):
"""Compute kde on DataArray input."""
dims = validate_dims(dims)
grid, pdf, bw = apply_ufunc(
Expand All @@ -188,6 +188,7 @@ def kde(self, da, dims=None, circular=False, grid_len=512, **kwargs):
kwargs={
"circular": circular,
"grid_len": grid_len,
"weights": weights,
"axes": np.arange(-len(dims), 0, 1),
**kwargs,
},
Expand Down Expand Up @@ -240,5 +241,31 @@ def pareto_min_ss(self, da, dims=None):
kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
)

def power_scale_lw(self, da, alpha=0, dims=None):
    """Compute log weights for power-scaling component by alpha."""
    # Fall back to the globally configured sample dimensions.
    sample_dims = rcParams["data.sample_dims"] if dims is None else dims
    return apply_ufunc(
        self.array_class.power_scale_lw,
        da,
        alpha,
        input_core_dims=[sample_dims, []],
        output_core_dims=[sample_dims],
    )

def power_scale_sens(self, da, lower_w, upper_w, delta, dims=None):
    """Compute power-scaling sensitivity."""
    # Fall back to the globally configured sample dimensions; the weights
    # and delta are scalar (no core dims), the result reduces to a scalar.
    sample_dims = rcParams["data.sample_dims"] if dims is None else dims
    return apply_ufunc(
        self.array_class.power_scale_sens,
        da,
        lower_w,
        upper_w,
        delta,
        input_core_dims=[sample_dims, [], [], []],
        output_core_dims=[[]],
    )


# Module-level singleton used by the accessor layer to dispatch
# DataArray computations to the array-backend implementations.
dataarray_stats = BaseDataArray(array_class=array_stats)
3 changes: 2 additions & 1 deletion src/arviz_stats/base/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def kde_linear(
bw_fct=1,
custom_lims=None,
cumulative=False,
weights=None,
grid_len=512,
**kwargs, # pylint: disable=unused-argument
):
Expand Down Expand Up @@ -456,7 +457,7 @@ def kde_linear(
x_min, x_max, x_std, extend_fct, grid_len, custom_lims, extend, bound_correction
)
grid_counts, grid_edges = self._histogram(
x, bins=grid_len, range=(grid_min, grid_max), density=False
x, bins=grid_len, weights=weights, range=(grid_min, grid_max), density=False
)

# Bandwidth estimation
Expand Down
90 changes: 82 additions & 8 deletions src/arviz_stats/base/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,20 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):

n_draws = len(ary)

n_draws_tail = self._get_ps_tails(n_draws, r_eff, tail=tail)

if tail == "both":
khat = max(
self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
for t in ("left", "right")
)
else:
_, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)

return khat

@staticmethod
def _get_ps_tails(n_draws, r_eff, tail):
if n_draws > 255:
n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int)
else:
Expand All @@ -388,14 +402,7 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):
warnings.warn("Number of tail draws cannot be less than 5. Changing to 5")
n_draws_tail = 5

khat = max(
self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
for t in ("left", "right")
)
else:
_, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)

return khat
return n_draws_tail

def _ps_tail(
self, ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False
Expand Down Expand Up @@ -542,3 +549,70 @@ def _gpinv(probs, kappa, sigma, mu):
q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa

return q

def _power_scale_sens(self, ary, lower_w=None, upper_w=None, delta=0.01):
    """Compute power-scaling sensitivity by finite difference second derivative of CJS."""
    draws = np.ravel(ary)
    weight_sets = (np.ravel(lower_w), np.ravel(upper_w))
    # CJS is direction dependent, so evaluate both orientations of the
    # draws and keep the larger distance for each weight set.
    lower_cjs, upper_cjs = (
        max(self._cjs_dist(draws, w), self._cjs_dist(-draws, w)) for w in weight_sets
    )
    # Finite-difference estimate of the sensitivity gradient.
    return (lower_cjs + upper_cjs) / (2 * np.log2(1 + delta))

def _power_scale_lw(self, ary, alpha):
    """Compute log weights for power-scaling component by alpha."""
    flat = np.ravel(ary)
    # Raw power-scaling log weights: proportional to (alpha - 1).
    raw_lw = (alpha - 1) * flat
    n_draws = len(raw_lw)
    # Pareto-smooth both tails of the log weights to stabilize them,
    # sizing the tails from the relative tail effective sample size.
    rel_eff = self._ess_tail(flat, relative=True)
    tail_len = self._get_ps_tails(n_draws, rel_eff, tail="both")
    smoothed_lw, _ = self._ps_tail(
        raw_lw,
        n_draws,
        tail_len,
        smooth_draws=False,
        log_weights=True,
    )

    return smoothed_lw

@staticmethod
def _cjs_dist(ary, weights):
"""Calculate the cumulative Jensen-Shannon distance between original and weighted draws."""
# sort draws and weights
order = np.argsort(ary)
ary = ary[order]
weights = weights[order]

binwidth = np.diff(ary)

# ecdfs
cdf_p = np.linspace(1 / len(ary), 1 - 1 / len(ary), len(ary) - 1)
cdf_q = np.cumsum(weights / np.sum(weights))[:-1]

# integrals of ecdfs
cdf_p_int = np.dot(cdf_p, binwidth)
cdf_q_int = np.dot(cdf_q, binwidth)

# cjs calculation
pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)

denom = 0.5 * (cdf_p + cdf_q)
denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)

cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
cdf_q_int - cdf_p_int
)

cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
cdf_p_int - cdf_q_int
)

cjs_pq = max(0, cjs_pq)
cjs_qp = max(0, cjs_qp)

bound = cdf_p_int + cdf_q_int

return np.sqrt((cjs_pq + cjs_qp) / bound)
Loading

0 comments on commit b882593

Please sign in to comment.