Skip to content

Commit

Permalink
draft subsampling bootstrap for mcse
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Mar 11, 2022
1 parent 9deaa52 commit 1055d59
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
)


def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwargs=None):
"""Calculate Markov Chain Standard Error statistic.
Parameters
Expand Down Expand Up @@ -398,6 +398,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
"sd": _mcse_sd,
"median": _mcse_median,
"quantile": _mcse_quantile,
"func": _mcse_func_sbm,
}
if method not in methods:
raise TypeError(
Expand All @@ -410,6 +411,9 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
if method == "quantile" and prob is None:
raise TypeError("Quantile (prob) information needs to be defined.")

if method == "func" and func is None:
raise TypeError("func argument needs to be defined.")

if isinstance(data, np.ndarray):
data = np.atleast_2d(data)
if len(data.shape) < 3:
Expand All @@ -430,7 +434,11 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
dataset = dataset if var_names is None else dataset[var_names]

ufunc_kwargs = {"ravel": False}
func_kwargs = {} if prob is None else {"prob": prob}
func_kwargs = {}
if prob is not None:
func_kwargs["prob"] = prob
elif func is not None:
func_kwargs["func"] = func
return _wrap_xarray_ufunc(
mcse_func,
dataset,
Expand Down Expand Up @@ -822,6 +830,36 @@ def _mcse_mean(ary):
return mcse_mean_value


def _mcse_func_sbm(ary, func):
"""Compute the Markov Chain error on an arbitrary function."""
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ess = _ess_mean(ary)
func_estimate_sd = _sbm(ary, func)
mcse_func_value = func_estimate_sd / np.sqrt(ess)
return mcse_func_value

def _sbm(ary, func):
"""Subsampling bootstrap method.
References
----------
.. [1] Doss, Charles R., et al. "Markov chain Monte Carlo estimation of quantiles."
*Electronic Journal of Statistics* 8.2 (2014): 2448-2478.
https://doi.org/10.1214/14-EJS957
"""
flat_ary = np.ravel(ary)
n = len(flat_ary)
b = int(np.sqrt(n))
func_estimates = np.empty(n-b)
for i in range(n-b):
sub_ary = flat_ary[i:i+b]
func_estimates[i] = func(sub_ary)
func_estimate_sd = np.sqrt(b * np.var(func_estimates, ddof=0))
return func_estimate_sd

def _mcse_sd(ary):
"""Compute the Markov Chain sd error."""
_numba_flag = Numba.numba_flag
Expand Down

0 comments on commit 1055d59

Please sign in to comment.