Skip to content

Commit

Permalink
move in power_scale_dataset from arviz-plots (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Nov 7, 2024
1 parent ec68b49 commit a832d35
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/arviz_stats/psense.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,55 @@ def _diagnose(row):
return psense_df.round(round_to)


def power_scale_dataset(dt, group, alphas, sample_dims):
"""Resample the dataset with the power scale weights.
Parameters
----------
dt : DataSet
group : str
Group to resample. Either "prior" or "likelihood"
alphas : tuple of float
Lower and upper alpha values for power scaling.
sample_dims : str or sequence of hashable
Dimensions to reduce unless mapped to an aesthetic.
Returns
-------
DataSet with resampled data.
"""
lower_w, upper_w = _get_power_scale_weights(dt, alphas, group=group, sample_dims=sample_dims)
lower_w = lower_w.values.flatten()
upper_w = upper_w.values.flatten()
s_size = len(lower_w)

idxs_to_drop = sample_dims if len(sample_dims) == 1 else ["sample"] + sample_dims
idxs_to_drop = set(idxs_to_drop).union(
[
idx
for idx in dt["posterior"].xindexes
if any(dim in dt["posterior"][idx].dims for dim in sample_dims)
]
)
resampled = [
extract(
dt,
group="posterior",
sample_dims=sample_dims,
num_samples=s_size,
weights=weights,
random_seed=42,
resampling_method="stratified",
).drop_indexes(idxs_to_drop)
for weights in (lower_w, upper_w)
]
resampled.insert(
1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop)
)

return xr.concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]])


def _get_power_scale_weights(
dt, alphas=None, group=None, sample_dims=None, group_var_names=None, group_coords=None
):
Expand Down

0 comments on commit a832d35

Please sign in to comment.