-
-
Notifications
You must be signed in to change notification settings - Fork 410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] draft subsampling bootstrap for mcse #1974
base: main
Are you sure you want to change the base?
Changes from all commits
f64d942
7373179
f123d25
697d9a5
00bdf7d
13dd989
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -341,40 +341,70 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None): | |
) | ||
|
||
|
||
def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): | ||
"""Calculate Markov Chain Standard Error statistic. | ||
def mcse( | ||
data, | ||
*, | ||
var_names=None, | ||
method="mean", | ||
prob=None, | ||
func=None, | ||
mcse_kwargs=None, | ||
func_kwargs=None, | ||
dask_kwargs=None, | ||
): | ||
r"""Calculate Markov Chain Standard Error statistic. | ||
|
||
Parameters | ||
---------- | ||
data : obj | ||
data : InferenceData-like or 2D array-like | ||
Any object that can be converted to an :class:`arviz.InferenceData` object | ||
Refer to documentation of :func:`arviz.convert_to_dataset` for details | ||
For ndarray: shape = (chain, draw). | ||
For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``. | ||
var_names : list | ||
var_names : list of str, optional | ||
Names of variables to include in the MCSE report | ||
method : str | ||
Select mcse method. Valid methods are: | ||
method : {'mean', 'sd', 'median', 'quantile', 'func'}, optional | ||
The method to use when estimating the MCSE. | ||
- "mean" | ||
- "sd" | ||
- "median" | ||
- "quantile" | ||
- "func" | ||
|
||
prob : float | ||
Methods "mean", "sd", "median" and "quantile" are described in [1]_. | ||
|
||
prob : float, optional | ||
Quantile information. | ||
func : callable, optional | ||
Summary function whose MCSE should be calculated. Only used when | ||
method is "func". | ||
TODO: add call signature info, something like ``func(ary, **func_kwargs)`` | ||
func_kwargs : dict, optional | ||
Keyword arguments passed to *func* when calling it. | ||
dask_kwargs : dict, optional | ||
Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`. | ||
|
||
Returns | ||
------- | ||
xarray.Dataset | ||
Return the msce dataset | ||
Dataset with the MCSE results | ||
|
||
Other Parameters | ||
---------------- | ||
mcse_kwargs : dict, optional | ||
Extra keyword arguments passed to the MCSE estimation method. | ||
|
||
See Also | ||
-------- | ||
ess : Compute autocovariance estimates for every lag for the input array. | ||
summary : Create a data frame with summary statistics. | ||
plot_mcse : Plot quantile or local Monte Carlo Standard Error. | ||
ess : Calculate estimate of the effective sample size (ESS). | ||
|
||
References | ||
---------- | ||
.. [1] Vehtari, Aki, et al. "Rank-normalization, folding, and localization: an improved | ||
$\hat{R}$ for assessing convergence of MCMC (with discussion)." | ||
Bayesian analysis 16.2 (2021): 667-718. https://doi.org/10.1214/20-BA1221 | ||
|
||
Examples | ||
-------- | ||
|
@@ -398,6 +428,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): | |
"sd": _mcse_sd, | ||
"median": _mcse_median, | ||
"quantile": _mcse_quantile, | ||
"func": _mcse_func_sbm, | ||
} | ||
if method not in methods: | ||
raise TypeError( | ||
|
@@ -410,32 +441,44 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): | |
if method == "quantile" and prob is None: | ||
raise TypeError("Quantile (prob) information needs to be defined.") | ||
|
||
if method == "func" and func is None: | ||
raise TypeError("func argument needs to be defined.") | ||
|
||
mcse_kwargs = {} if mcse_kwargs is None else mcse_kwargs | ||
if prob is not None: | ||
mcse_kwargs.setdefault("prob", prob) | ||
elif func is not None: | ||
mcse_kwargs.setdefault("func", func) | ||
mcse_kwargs.setdefault("func_kwargs", func_kwargs) | ||
|
||
if isinstance(data, np.ndarray): | ||
data = np.atleast_2d(data) | ||
if len(data.shape) < 3: | ||
if prob is not None: | ||
return mcse_func(data, prob=prob) # pylint: disable=unexpected-keyword-arg | ||
|
||
return mcse_func(data) | ||
|
||
msg = ( | ||
"Only uni-dimensional ndarray variables are supported." | ||
" Please transform first to dataset with `az.convert_to_dataset`." | ||
) | ||
raise TypeError(msg) | ||
if data.size < 1000 and method == "func": | ||
warnings.warn( | ||
"Not enough samples for reliable estimate of MCSE for arbitrary functions" | ||
) | ||
return mcse_func(data, **mcse_kwargs) | ||
else: | ||
msg = ( | ||
"Only uni-dimensional ndarray variables are supported." | ||
" Please transform first to dataset with `az.convert_to_dataset`." | ||
) | ||
raise TypeError(msg) | ||
|
||
dataset = convert_to_dataset(data, group="posterior") | ||
if (dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func": | ||
warnings.warn("Not enough samples for reliable estimate of MCSE for arbitrary functions") | ||
var_names = _var_names(var_names, dataset) | ||
|
||
dataset = dataset if var_names is None else dataset[var_names] | ||
|
||
ufunc_kwargs = {"ravel": False} | ||
func_kwargs = {} if prob is None else {"prob": prob} | ||
return _wrap_xarray_ufunc( | ||
mcse_func, | ||
dataset, | ||
ufunc_kwargs=ufunc_kwargs, | ||
func_kwargs=func_kwargs, | ||
func_kwargs=mcse_kwargs, | ||
dask_kwargs=dask_kwargs, | ||
) | ||
|
||
|
@@ -813,13 +856,48 @@ def _mcse_mean(ary): | |
return np.nan | ||
ess = _ess_mean(ary) | ||
if _numba_flag: | ||
sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)) | ||
sd = _sqrt(svar(np.ravel(ary), ddof=1), 0) | ||
else: | ||
sd = np.std(ary, ddof=1) | ||
mcse_mean_value = sd / np.sqrt(ess) | ||
return mcse_mean_value | ||
|
||
|
||
def _mcse_func_sbm(ary, func, b=None, var_func=np.var, func_kwargs=None): | ||
"""Compute the Markov Chain error on an arbitrary function.""" | ||
ary = np.asarray(ary) | ||
if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)): | ||
return np.nan | ||
n = ary.size | ||
if b is None: | ||
b = int(np.sqrt(n)) | ||
if func_kwargs is None: | ||
func_kwargs = {} | ||
func_estimate_sd = _sbm(ary, func, b=b, var_func=var_func, func_kwargs=func_kwargs) | ||
mcse_func_value = func_estimate_sd / np.sqrt(n) | ||
return mcse_func_value | ||
|
||
|
||
def _sbm(ary, func, b, var_func, func_kwargs): | ||
"""Subsampling bootstrap method. | ||
|
||
References | ||
---------- | ||
.. [1] Doss, Charles R., et al. "Markov chain Monte Carlo estimation of quantiles." | ||
*Electronic Journal of Statistics* 8.2 (2014): 2448-2478. | ||
https://doi.org/10.1214/14-EJS957 | ||
|
||
""" | ||
flat_ary = np.ravel(ary) | ||
Is this sensitive to the order in which the ravel is done?
I think it technically is, but there should be no difference (hopefully) if the model has converged. It is also not clear to me how multiple chains should be handled when implementing this algorithm; I started with this flatten approach, but I can test a couple of options.
@OriolAbril have you tested alternative approaches for handling multiple chains yet?
I benchmarked 3 different approaches for estimating the
In all cases SBM underestimates the MCSE; this is particularly severe when autocorrelation is high and sample sizes are low. |
||
n = len(flat_ary) | ||
func_estimates = np.empty(n - b) | ||
for i in range(n - b): | ||
sub_ary = flat_ary[i : i + b] | ||
func_estimates[i] = func(sub_ary, **func_kwargs) | ||
func_estimate_sd = np.sqrt(b * var_func(func_estimates)) | ||
We should probably decide API-wise if we want to keep this or instead move to |
||
return func_estimate_sd | ||
|
||
|
||
def _mcse_sd(ary): | ||
"""Compute the Markov Chain sd error.""" | ||
_numba_flag = Numba.numba_flag | ||
|
We could also consider allowing some strings here, e.g. using `"circmean"` expands to `stats.circmean` as `func` here and also fills the `mcse_kwargs` with `{"var_func": stats.circvar}`.