Skip to content

Commit

Permalink
fix n term in sbm mcse method
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Mar 11, 2022
1 parent 1055d59 commit 2a0aabb
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar
if isinstance(data, np.ndarray):
data = np.atleast_2d(data)
if len(data.shape) < 3:
if data.size < 1000 and method == "func":
warnings.warn(
"Not enough samples for reliable estimate of MCSE for arbitrary functions"
)
if prob is not None:
return mcse_func(data, prob=prob) # pylint: disable=unexpected-keyword-arg
else:
Expand All @@ -429,6 +433,10 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar
raise TypeError(msg)

dataset = convert_to_dataset(data, group="posterior")
if (dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func":
warnings.warn(
"Not enough samples for reliable estimate of MCSE for arbitrary functions"
)
var_names = _var_names(var_names, dataset)

dataset = dataset if var_names is None else dataset[var_names]
Expand Down Expand Up @@ -833,11 +841,11 @@ def _mcse_mean(ary):
def _mcse_func_sbm(ary, func):
"""Compute the Markov Chain error on an arbitrary function."""
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)):
return np.nan
ess = _ess_mean(ary)
n = ary.size
func_estimate_sd = _sbm(ary, func)
mcse_func_value = func_estimate_sd / np.sqrt(ess)
mcse_func_value = func_estimate_sd / np.sqrt(n)
return mcse_func_value

def _sbm(ary, func):
Expand Down

0 comments on commit 2a0aabb

Please sign in to comment.