sort out dimensions/axes/broadcasting/batching
OriolAbril committed Oct 6, 2024
1 parent b882593 commit 11eab3c
Showing 6 changed files with 87 additions and 49 deletions.
3 changes: 1 addition & 2 deletions src/arviz_stats/__init__.py
@@ -4,8 +4,7 @@
 try:
     from arviz_stats.utils import *
     from arviz_stats.accessors import *
-    from arviz_stats.psens import psens, psens_summary
-
 except ModuleNotFoundError:
     pass
 
+from arviz_stats.psens import *
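
With psens now imported unconditionally via the star import, the public entry points no longer hide behind the optional-dependency guard. A minimal check of the resulting behaviour:

# hedged sketch: these imports should now work at the top level,
# independently of the guarded utils/accessors imports
from arviz_stats import psens, psens_summary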
22 changes: 20 additions & 2 deletions src/arviz_stats/accessors.py
@@ -1,6 +1,7 @@
 """ArviZ stats accessors."""
 
 import warnings
+from collections.abc import Hashable
 
 import xarray as xr
 from arviz_base.utils import _var_names
@@ -159,6 +160,14 @@ def pareto_min_ss(self, dims=None):
         """Compute the min sample size for all variables in the dataset."""
         return self._apply("pareto_min_ss", dims=dims)
 
+    def power_scale_lw(self, dims=None, **kwargs):
+        """Compute log weights for power-scaling of the Dataset."""
+        return self._apply("power_scale_lw", dims=dims, **kwargs)
+
+    def power_scale_sens(self, dims=None, **kwargs):
+        """Compute power-scaling sensitivity."""
+        return self._apply("power_scale_sens", dims=dims, **kwargs)
+
 
 @register_datatree_accessor("azstats")
 class AzStatsDtAccessor(_BaseAccessor):
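
For context, a sketch of how the new Dataset-accessor methods compose, mirroring their use in psens.py further down. Here ds is a hypothetical xarray Dataset of draws with ("chain", "draw") dims, and arviz_stats has been imported so the azstats accessor is registered:

import numpy as np

# log weights have the same dims as ds; exponentiating and normalizing
# over the sample dims yields importance weights that sum to one
lw = ds.azstats.power_scale_lw(alpha=1.01, dims=("chain", "draw"))
weights = np.exp(lw)
weights = weights / weights.sum(("chain", "draw"))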
@@ -176,9 +185,11 @@ def _process_input(self, group, method):
         return self._obj
 
     def _apply(self, fun_name, dims, group, **kwargs):
-        if isinstance(group, str):
+        hashable_group = False
+        if isinstance(group, Hashable):
             group = [group]
-        return DataTree.from_dict(
+            hashable_group = True
+        out_dt = DataTree.from_dict(
             {
                 group_i: xr.Dataset(
                     {
@@ -189,6 +200,13 @@ def _apply(self, fun_name, dims, group, **kwargs):
                 for group_i in group
             }
         )
+        if hashable_group:
+            # group was a single string/hashable: return a DataTree whose
+            # root is the node for that group
+            return out_dt[group[0]]
+        # group was a sequence: return a DataTree with one first-level child
+        # per requested group
+        return out_dt
 
     def filter_vars(self, group="posterior", var_names=None, filter_vars=None):
         """Access and filter variables of the provided group."""
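
The practical effect of the hashable_group flag, sketched under the assumption of a DataTree dt with "posterior" and "prior" groups and an ess method routed through _apply (any such method behaves the same):

single = dt.azstats.ess(group="posterior")
# -> DataTree whose root is the posterior result itself
multi = dt.azstats.ess(group=["posterior", "prior"])
# -> DataTree with one first-level child per requested group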
9 changes: 7 additions & 2 deletions src/arviz_stats/base/array.py
@@ -149,17 +149,22 @@ def power_scale_lw(self, ary, alpha=0, axes=-1):
             n_output=1,
             n_input=1,
             n_dims=len(axes),
+            ravel=False,
         )
         return psl_ufunc(ary, out_shape=(ary.shape[i] for i in axes), alpha=alpha)
 
     def power_scale_sens(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
         """Compute power-scaling sensitivity."""
         if chain_axis is None:
             ary = np.expand_dims(ary, axis=0)
+            lower_w = np.expand_dims(lower_w, axis=0)
+            upper_w = np.expand_dims(upper_w, axis=0)
             chain_axis = 0
         ary, _ = process_ary_axes(ary, [chain_axis, draw_axis])
-        pss_array = make_ufunc(self._power_scale_sens, n_output=1, n_input=1, n_dims=2, ravel=False)
-        return pss_array(ary, lower_w=lower_w, upper_w=upper_w, delta=delta)
+        lower_w, _ = process_ary_axes(lower_w, [chain_axis, draw_axis])
+        upper_w, _ = process_ary_axes(upper_w, [chain_axis, draw_axis])
+        pss_array = make_ufunc(self._power_scale_sens, n_output=1, n_input=3, n_dims=2, ravel=False)
+        return pss_array(ary, lower_w, upper_w, delta=delta)
 
     def compute_ranks(self, ary, axes=-1, relative=False):
         """Compute ranks of MCMC samples."""
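
The new contract is that ary, lower_w and upper_w all carry matching (chain, draw) axes, since each is now routed through process_ary_axes and the ufunc takes three inputs. A toy sketch of a call (the Dirichlet draws are stand-ins for real PSIS-smoothed weights):

import numpy as np
from arviz_stats.base.array import array_stats

rng = np.random.default_rng(0)
ary = rng.normal(size=(4, 100))                        # (chain, draw)
lower_w = rng.dirichlet(np.ones(400)).reshape(4, 100)  # normalized weights
upper_w = rng.dirichlet(np.ones(400)).reshape(4, 100)
sens = array_stats.power_scale_sens(ary, lower_w, upper_w, delta=0.01)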
16 changes: 7 additions & 9 deletions src/arviz_stats/base/dataarray.py
@@ -7,7 +7,7 @@
 
 import numpy as np
 from arviz_base import rcParams
-from xarray import DataArray, apply_ufunc, concat
+from xarray import DataArray, apply_ufunc, broadcast, concat
 from xarray_einstats.stats import _apply_nonreduce_func
 
 from arviz_stats.base.array import array_stats
@@ -243,28 +243,26 @@ def pareto_min_ss(self, da, dims=None):
 
     def power_scale_lw(self, da, alpha=0, dims=None):
         """Compute log weights for power-scaling component by alpha."""
-        if dims is None:
-            dims = rcParams["data.sample_dims"]
+        dims = validate_dims(dims)
         return apply_ufunc(
             self.array_class.power_scale_lw,
             da,
             alpha,
             input_core_dims=[dims, []],
             output_core_dims=[dims],
+            kwargs={"axes": np.arange(-len(dims), 0, 1)},
         )
 
     def power_scale_sens(self, da, lower_w, upper_w, delta, dims=None):
         """Compute power-scaling sensitivity."""
-        if dims is None:
-            dims = rcParams["data.sample_dims"]
+        dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(dims)
         return apply_ufunc(
             self.array_class.power_scale_sens,
-            da,
-            lower_w,
-            upper_w,
+            *broadcast(da, lower_w, upper_w),
             delta,
-            input_core_dims=[dims, [], [], []],
+            input_core_dims=[dims, dims, dims, []],
             output_core_dims=[[]],
             kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
         )
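
Why broadcast: apply_ufunc expects every input to actually carry the core dims listed for it in input_core_dims, but the weights are computed from the component draws and lack the parameter dims of da. A self-contained toy illustration (plain xarray, no arviz involved):

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((4, 100, 3)), dims=["chain", "draw", "x"])
w = xr.DataArray(np.full((4, 100), 1 / 400), dims=["chain", "draw"])
da_b, w_b = xr.broadcast(da, w)
# the weights gain the missing "x" dim and can now share core dims with da
assert w_b.dims == ("chain", "draw", "x")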
5 changes: 3 additions & 2 deletions src/arviz_stats/base/diagnostics.py
@@ -550,7 +550,7 @@ def _gpinv(probs, kappa, sigma, mu):
 
         return q
 
-    def _power_scale_sens(self, ary, lower_w=None, upper_w=None, delta=0.01):
+    def _power_scale_sens(self, ary, lower_w, upper_w, delta=0.01):
         """Compute power-scaling sensitivity by finite difference second derivative of CJS."""
         ary = np.ravel(ary)
         lower_w = np.ravel(lower_w)
@@ -562,6 +562,7 @@ def _power_scale_sens(self, ary, lower_w, upper_w, delta=0.01):
 
     def _power_scale_lw(self, ary, alpha):
         """Compute log weights for power-scaling component by alpha."""
+        shape = ary.shape
         ary = np.ravel(ary)
         log_weights = (alpha - 1) * ary
         n_draws = len(log_weights)
@@ -575,7 +576,7 @@ def _power_scale_lw(self, ary, alpha):
             log_weights=True,
         )
 
-        return log_weights
+        return log_weights.reshape(shape)
 
     @staticmethod
     def _cjs_dist(ary, weights):
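
The ravel/reshape pair keeps the returned log weights aligned with the input: the computation still flattens the draws, but callers (now using make_ufunc with ravel=False) get the original shape back. A minimal numpy sketch of the round trip, with the PSIS smoothing step omitted:

import numpy as np

ary = np.random.default_rng(0).normal(size=(4, 100))  # (chain, draw)
shape = ary.shape
log_weights = (0.5 - 1) * np.ravel(ary)  # alpha = 0.5
assert log_weights.reshape(shape).shape == (4, 100)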
81 changes: 49 additions & 32 deletions src/arviz_stats/psens.py
@@ -1,6 +1,7 @@
 """Power-scaling sensitivity diagnostics."""
 
 import warnings
+from collections.abc import Hashable
 from typing import cast
 
 import numpy as np
@@ -10,12 +11,17 @@
 from arviz_base.labels import BaseLabeller
 from arviz_base.sel_utils import xarray_var_iter
 
+from arviz_stats.validate import validate_dims
+
 labeller = BaseLabeller()
 
+__all__ = ["psens", "psens_summary"]
+
 
 def psens(
     dt,
     group="log_prior",
+    sample_dims=None,
     component_var_names=None,
     component_coords=None,
     var_names=None,
@@ -75,44 +81,45 @@
     .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
        power-scaling*, 2022, https://arxiv.org/abs/2107.14054
     """
-    dataset = extract(dt, var_names=var_names, filter_vars=filter_vars, group="posterior")
-    if coords is None:
+    dataset = extract(
+        dt, var_names=var_names, filter_vars=filter_vars, group="posterior", combined=False
+    )
+    sample_dims = validate_dims(sample_dims)
+    if coords is not None:
         dataset = dataset.sel(coords)
 
     if group == "log_likelihood":
-        component_draws = get_log_likelihood(dt, var_name=component_var_names, single_var=False)
+        component_draws = get_log_likelihood_dataset(dt, var_names=component_var_names)
     elif group == "log_prior":
         component_draws = get_log_prior(dt, var_names=component_var_names)
     else:
         raise ValueError("Value for `group` argument not recognized")
 
-    component_draws = component_draws.stack(__sample__=("chain", "draw"))
-    if component_coords is None:
+    if component_coords is not None:
         component_draws = component_draws.sel(component_coords)
-    if isinstance(component_draws, xr.DataArray):
-        component_draws = component_draws.to_dataset()
-    if len(component_draws.dims):
-        component_draws = component_draws.to_stacked_array(
-            "latent-obs_var", sample_dims=("__sample__",)
-        ).sum("latent-obs_var")
-
-    component_draws = component_draws.unstack()
+    # stack all variables (if any) and the dimensions within each variable (if any)
+    # into a flat "latent-obs_var" dimension and sum over it; afterwards
+    # component_draws is a DataArray whose only dims are sample_dims
+    component_draws = component_draws.to_stacked_array(
+        "latent-obs_var", sample_dims=sample_dims
+    ).sum("latent-obs_var")
 
     # calculate lower and upper alpha values
     lower_alpha = 1 / (1 + delta)
     upper_alpha = 1 + delta
 
     # calculate importance sampling weights for lower and upper alpha power-scaling
-    lower_w = np.exp(component_draws.azstats.power_scale_lw(alpha=lower_alpha)).values.flatten()
-    lower_w = lower_w / np.sum(lower_w)
+    lower_w = np.exp(component_draws.azstats.power_scale_lw(alpha=lower_alpha, dims=sample_dims))
+    lower_w = lower_w / lower_w.sum(sample_dims)
 
-    upper_w = np.exp(component_draws.azstats.power_scale_lw(alpha=upper_alpha)).values.flatten()
-    upper_w = upper_w / np.sum(upper_w)
+    upper_w = np.exp(component_draws.azstats.power_scale_lw(alpha=upper_alpha, dims=sample_dims))
+    upper_w = upper_w / upper_w.sum(sample_dims)
 
-    return dt.azstats.power_scale_sens(
+    return dataset.azstats.power_scale_sens(
         lower_w=lower_w,
         upper_w=upper_w,
         delta=delta,
+        dims=sample_dims,
     )
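
The to_stacked_array step replaces the old stack/unstack round trip: every variable and every non-sample dimension is folded into one "latent-obs_var" dimension and summed away, leaving only sample_dims. A self-contained toy version:

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "a": (("chain", "draw"), np.zeros((4, 10))),
        "b": (("chain", "draw", "x"), np.ones((4, 10, 3))),
    }
)
total = ds.to_stacked_array(
    "latent-obs_var", sample_dims=("chain", "draw")
).sum("latent-obs_var")
assert set(total.dims) == {"chain", "draw"}  # each value is 0 + 1 * 3 = 3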


@@ -138,8 +145,8 @@ def psens_summary(data, threshold=0.05, round_to=3):
       and the likelihood sensitivity is below the threshold
     - "-" otherwise
     """
-    pssdp = psens(data, group="log_prior")["posterior"].to_dataset()
-    pssdl = psens(data, group="log_likelihood")["posterior"].to_dataset()
+    pssdp = psens(data, group="log_prior")
+    pssdl = psens(data, group="log_likelihood")
 
     joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
         component=["prior", "likelihood"]
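
Since psens now returns plain Datasets, the summary can concatenate them directly instead of digging into a "posterior" group first. A toy version of the concat pattern with scalar stand-in values:

import xarray as xr

pssdp = xr.Dataset({"mu": 0.1})  # toy prior sensitivity
pssdl = xr.Dataset({"mu": 0.4})  # toy likelihood sensitivity
joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
    component=["prior", "likelihood"]
)
assert float(joined["mu"].sel(component="likelihood")) == 0.4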
@@ -173,7 +180,7 @@ def _diagnose(row):
 
 
 # get_log_likelihood and get_log_prior functions should be somewhere else
-def get_log_likelihood(idata, var_name=None, single_var=True):
+def get_log_likelihood_dataset(idata, var_names=None):
     """Retrieve the log likelihood dataarray of a given variable."""
     if (
         not hasattr(idata, "log_likelihood")
@@ -184,21 +191,29 @@ def get_log_likelihood_dataset(idata, var_names=None):
             "Storing the log_likelihood in sample_stats groups has been deprecated",
             DeprecationWarning,
         )
-        return idata.sample_stats.log_likelihood
+        log_lik_ds = idata.sample_stats.ds[["log_likelihood"]]
+    if not hasattr(idata, "log_likelihood"):
+        raise TypeError("log likelihood not found in inference data object")
+    log_lik_ds = idata.log_likelihood.ds
+    if var_names is None:
+        return log_lik_ds
+    if isinstance(var_names, Hashable):
+        return log_lik_ds[[var_names]]
+    return log_lik_ds[var_names]
+
+
+def get_log_likelihood_dataarray(data, var_name=None):
+    log_lik_ds = get_log_likelihood_dataset(data)
     if var_name is None:
-        var_names = list(idata.log_likelihood.data_vars)
+        var_names = list(log_lik_ds.data_vars)
         if len(var_names) > 1:
-            if single_var:
-                raise TypeError(
-                    f"Found several log likelihood arrays {var_names}, var_name cannot be None"
-                )
-            return idata.log_likelihood[var_names]
-        return idata.log_likelihood[var_names[0]]
+            raise TypeError(
+                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
+            )
+        return log_lik_ds[var_names[0]]
 
     try:
-        log_likelihood = idata.log_likelihood[var_name]
+        log_likelihood = log_lik_ds[var_name]
     except KeyError as err:
         raise TypeError(f"No log likelihood data named {var_name} found") from err
     return log_likelihood
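
A hedged sketch of the split: the dataset helper always returns a Dataset (a single hashable name is wrapped in a list), while the new dataarray helper resolves exactly one variable. idata and the "obs" variable are hypothetical:

ds_all = get_log_likelihood_dataset(idata)                    # every variable
ds_one = get_log_likelihood_dataset(idata, var_names="obs")   # still a Dataset
da_one = get_log_likelihood_dataarray(idata, var_name="obs")  # a DataArray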
@@ -209,5 +224,7 @@ def get_log_prior(idata, var_names=None):
     if not hasattr(idata, "log_prior"):
         raise TypeError("log prior not found in inference data object")
     if var_names is None:
-        var_names = list(idata.log_prior.data_vars)
-    return idata.log_prior.to_dataset()[var_names]
+        return idata.log_prior.ds
+    if isinstance(var_names, Hashable):
+        return idata.log_prior.ds[[var_names]]
+    return idata.log_prior.ds[var_names]
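
get_log_prior follows the same convention; "mu" and "sigma" below are hypothetical variables in the log_prior group:

get_log_prior(idata)                             # full Dataset
get_log_prior(idata, var_names="mu")             # Dataset with just "mu"
get_log_prior(idata, var_names=["mu", "sigma"])  # Dataset with both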
