Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mcse_sd calculation to not use normality assumption. #2167

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Maintenance and fixes
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
- Update `method="sd"` of `mcse` so that it no longer relies on a normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))

### Documentation

Expand Down
32 changes: 18 additions & 14 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ def _ess_sd(ary, relative=False):
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ary = _split_chains(ary)
return min(_ess(ary, relative=relative), _ess(ary**2, relative=relative))
ary = (ary - ary.mean()) ** 2
return _ess(_split_chains(ary), relative=relative)


def _ess_quantile(ary, prob, relative=False):
Expand Down Expand Up @@ -838,13 +838,15 @@ def _mcse_sd(ary):
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ess = _ess_sd(ary)
sims_c2 = (ary - ary.mean()) ** 2
ess = _ess_mean(sims_c2)
evar = (sims_c2).mean()
varvar = ((sims_c2**2).mean() - evar**2) / ess
varsd = varvar / evar / 4
if _numba_flag:
sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
else:
sd = np.std(ary, ddof=1)
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
mcse_sd_value = sd * fac_mcse_sd
mcse_sd_value = np.sqrt(varsd)
return mcse_sd_value


Expand Down Expand Up @@ -973,19 +975,21 @@ def _multichain_statistics(ary, focus="mean"):
# ess mean
ess_mean_value = _ess_mean(ary)

# ess sd
ess_sd_value = _ess_sd(ary)

# mcse_mean
sd = np.std(ary, ddof=1)
mcse_mean_value = sd / np.sqrt(ess_mean_value)
sims_c2 = (ary - ary.mean()) ** 2
sims_c2_sum = sims_c2.sum()
var = sims_c2_sum / (sims_c2.size - 1)
mcse_mean_value = np.sqrt(var / ess_mean_value)

# ess bulk
ess_bulk_value = _ess(z_split)

# mcse_sd
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
mcse_sd_value = sd * fac_mcse_sd
evar = sims_c2_sum / sims_c2.size
ess_mean_sims = _ess_mean(sims_c2)
varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
varsd = varvar / evar / 4
mcse_sd_value = np.sqrt(varsd)

return (
mcse_mean_value,
Expand Down
9 changes: 5 additions & 4 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,11 @@ def test_deterministic(self):
```
Reference file:

Created: 2020-08-31
System: Ubuntu 18.04.5 LTS
R version 4.0.2 (2020-06-22)
posterior 0.1.2
Created: 2024-12-20
System: Ubuntu 24.04.1 LTS
R version 4.4.2 (2024-10-31)
posterior version from https://github.com/stan-dev/posterior/pull/388
(after release 1.6.0 but before the fixes in the PR were released).
"""
# download input files
here = os.path.dirname(os.path.abspath(__file__))
Expand Down
Loading