diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py
index cace5a2ec6..5f3bcd4e40 100644
--- a/arviz/stats/diagnostics.py
+++ b/arviz/stats/diagnostics.py
@@ -341,40 +341,70 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     )
 
 
-def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
-    """Calculate Markov Chain Standard Error statistic.
+def mcse(
+    data,
+    *,
+    var_names=None,
+    method="mean",
+    prob=None,
+    func=None,
+    mcse_kwargs=None,
+    func_kwargs=None,
+    dask_kwargs=None,
+):
+    r"""Calculate Markov Chain Standard Error statistic.
 
     Parameters
     ----------
-    data : obj
+    data : InferenceData-like or 2D array-like
         Any object that can be converted to an :class:`arviz.InferenceData` object
         Refer to documentation of :func:`arviz.convert_to_dataset` for details
         For ndarray: shape = (chain, draw).
         For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
-    var_names : list
-        Names of variables to include in the rhat report
+    var_names : list of str, optional
+        Names of variables to include in the MCSE report
-    method : str
-        Select mcse method. Valid methods are:
+    method : {'mean', 'sd', 'median', 'quantile', 'func'}, optional
+        The method to use when estimating the MCSE.
         - "mean"
         - "sd"
         - "median"
         - "quantile"
+        - "func"
 
-    prob : float
+        Methods "mean", "sd", "median" and "quantile" are described in [1]_.
+
+    prob : float, optional
         Quantile information.
+    func : callable, optional
+        Summary function whose MCSE should be calculated. Only used when
+        method is "func". It is called as ``func(ary, **func_kwargs)`` with ``ary``
+        a 1D array of consecutive draws and must return a scalar.
+    func_kwargs : dict, optional
+        Keyword arguments passed to *func* when calling it.
     dask_kwargs : dict, optional
         Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
 
     Returns
     -------
     xarray.Dataset
-        Return the msce dataset
+        Dataset with the MCSE results
+
+    Other Parameters
+    ----------------
+    mcse_kwargs : dict, optional
+        Extra keyword arguments passed to the MCSE estimation method.
 
     See Also
     --------
-    ess : Compute autocovariance estimates for every lag for the input array.
     summary : Create a data frame with summary statistics.
     plot_mcse : Plot quantile or local Monte Carlo Standard Error.
+    ess : Calculate estimate of the effective sample size (ess).
+
+    References
+    ----------
+    .. [1] Vehtari, Aki, et al. "Rank-normalization, folding, and localization: an improved
+       $\hat{R}$ for assessing convergence of MCMC (with discussion)."
+       Bayesian analysis 16.2 (2021): 667-718. https://doi.org/10.1214/20-BA1221
 
     Examples
     --------
@@ -398,6 +428,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
         "sd": _mcse_sd,
         "median": _mcse_median,
         "quantile": _mcse_quantile,
+        "func": _mcse_func_sbm,
     }
     if method not in methods:
         raise TypeError(
@@ -410,32 +441,47 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
     if method == "quantile" and prob is None:
         raise TypeError("Quantile (prob) information needs to be defined.")
 
+    if method == "func" and func is None:
+        raise TypeError("func argument needs to be defined.")
+
+    # Work on a copy so the ``setdefault`` calls never mutate the caller's dict.
+    mcse_kwargs = {} if mcse_kwargs is None else dict(mcse_kwargs)
+    # Route arguments by *method*: ``func`` is only forwarded for "func" (even when
+    # ``prob`` was also given); every other method keeps receiving ``prob``.
+    if method == "func":
+        mcse_kwargs.setdefault("func", func)
+        mcse_kwargs.setdefault("func_kwargs", func_kwargs)
+    elif prob is not None:
+        mcse_kwargs.setdefault("prob", prob)
+
     if isinstance(data, np.ndarray):
         data = np.atleast_2d(data)
         if len(data.shape) < 3:
-            if prob is not None:
-                return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
-
-            return mcse_func(data)
-
-        msg = (
-            "Only uni-dimensional ndarray variables are supported."
-            " Please transform first to dataset with `az.convert_to_dataset`."
-        )
-        raise TypeError(msg)
+            if data.size < 1000 and method == "func":
+                warnings.warn(
+                    "Not enough samples for reliable estimate of MCSE for arbitrary functions"
+                )
+            return mcse_func(data, **mcse_kwargs)
+        else:
+            msg = (
+                "Only uni-dimensional ndarray variables are supported."
+                " Please transform first to dataset with `az.convert_to_dataset`."
+            )
+            raise TypeError(msg)
 
     dataset = convert_to_dataset(data, group="posterior")
+    if (dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func":
+        warnings.warn("Not enough samples for reliable estimate of MCSE for arbitrary functions")
     var_names = _var_names(var_names, dataset)
 
     dataset = dataset if var_names is None else dataset[var_names]
 
     ufunc_kwargs = {"ravel": False}
-    func_kwargs = {} if prob is None else {"prob": prob}
     return _wrap_xarray_ufunc(
         mcse_func,
         dataset,
         ufunc_kwargs=ufunc_kwargs,
-        func_kwargs=func_kwargs,
+        func_kwargs=mcse_kwargs,
         dask_kwargs=dask_kwargs,
     )
 
@@ -813,13 +859,79 @@ def _mcse_mean(ary):
         return np.nan
     ess = _ess_mean(ary)
     if _numba_flag:
-        sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
+        sd = _sqrt(svar(np.ravel(ary), ddof=1), 0)
     else:
         sd = np.std(ary, ddof=1)
     mcse_mean_value = sd / np.sqrt(ess)
     return mcse_mean_value
 
 
+def _mcse_func_sbm(ary, func, b=None, var_func=np.var, func_kwargs=None):
+    """Compute the Markov Chain error on an arbitrary function.
+
+    The standard deviation of ``func`` is estimated with the subsampling bootstrap
+    method (see ``_sbm``) and divided by the square root of the number of draws.
+
+    Parameters
+    ----------
+    ary : array-like
+        Draws to evaluate ``func`` on.
+    func : callable
+        Summary function, called as ``func(sub_ary, **func_kwargs)``.
+    b : int, optional
+        Subsample length. Defaults to ``int(sqrt(ary.size))``.
+    var_func : callable, default numpy.var
+        Function used to compute the variance of the subsample estimates.
+    func_kwargs : dict, optional
+        Keyword arguments passed to ``func``.
+
+    Returns
+    -------
+    float
+        MCSE estimate, or ``nan`` when ``ary`` does not pass ``_not_valid``.
+    """
+    ary = np.asarray(ary)
+    if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)):
+        return np.nan
+    n = ary.size
+    if b is None:
+        b = int(np.sqrt(n))
+    if func_kwargs is None:
+        func_kwargs = {}
+    func_estimate_sd = _sbm(ary, func, b=b, var_func=var_func, func_kwargs=func_kwargs)
+    return func_estimate_sd / np.sqrt(n)
+
+
+def _sbm(ary, func, b, var_func, func_kwargs):
+    """Estimate the standard deviation of ``func`` with the subsampling bootstrap method.
+
+    ``func`` is evaluated on every window of ``b`` consecutive draws of the flattened
+    array and the variance of those estimates is rescaled by ``b``.
+
+    Raises
+    ------
+    ValueError
+        If ``b`` is not between 1 and the number of draws.
+
+    References
+    ----------
+    .. [1] Doss, Charles R., et al. "Markov chain Monte Carlo estimation of quantiles."
+       *Electronic Journal of Statistics* 8.2 (2014): 2448-2478.
+       https://doi.org/10.1214/14-EJS957
+
+    """
+    flat_ary = np.ravel(ary)
+    n = len(flat_ary)
+    if not 1 <= b <= n:
+        raise ValueError("Subsample length b must be between 1 and {}, got {}.".format(n, b))
+    # A series of length n holds n - b + 1 overlapping windows of length b.
+    n_subsamples = n - b + 1
+    func_estimates = np.empty(n_subsamples)
+    for i in range(n_subsamples):
+        func_estimates[i] = func(flat_ary[i : i + b], **func_kwargs)
+    return np.sqrt(b * var_func(func_estimates))
+
+
 def _mcse_sd(ary):
     """Compute the Markov Chain sd error."""
     _numba_flag = Numba.numba_flag