From a832d35ff2cdcc3824385efaae82c70370111873 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Thu, 7 Nov 2024 10:17:44 -0300 Subject: [PATCH] move in power_scale_dataset from arviz-plots (#38) --- src/arviz_stats/psense.py | 49 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/arviz_stats/psense.py b/src/arviz_stats/psense.py index 88479ff..18e347f 100644 --- a/src/arviz_stats/psense.py +++ b/src/arviz_stats/psense.py @@ -227,6 +227,55 @@ def _diagnose(row): return psense_df.round(round_to) +def power_scale_dataset(dt, group, alphas, sample_dims): + """Resample the dataset with the power scale weights. + + Parameters + ---------- + dt : DataSet + group : str + Group to resample. Either "prior" or "likelihood" + alphas : tuple of float + Lower and upper alpha values for power scaling. + sample_dims : str or sequence of hashable + Dimensions to reduce unless mapped to an aesthetic. + + Returns + ------- + DataSet with resampled data. + """ + lower_w, upper_w = _get_power_scale_weights(dt, alphas, group=group, sample_dims=sample_dims) + lower_w = lower_w.values.flatten() + upper_w = upper_w.values.flatten() + s_size = len(lower_w) + + idxs_to_drop = sample_dims if len(sample_dims) == 1 else ["sample"] + sample_dims + idxs_to_drop = set(idxs_to_drop).union( + [ + idx + for idx in dt["posterior"].xindexes + if any(dim in dt["posterior"][idx].dims for dim in sample_dims) + ] + ) + resampled = [ + extract( + dt, + group="posterior", + sample_dims=sample_dims, + num_samples=s_size, + weights=weights, + random_seed=42, + resampling_method="stratified", + ).drop_indexes(idxs_to_drop) + for weights in (lower_w, upper_w) + ] + resampled.insert( + 1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop) + ) + + return xr.concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]]) + + def _get_power_scale_weights( dt, alphas=None, group=None, sample_dims=None, group_var_names=None, group_coords=None ):