Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mcse_sd calculation to not use normality assumption. #2167

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Maintenance and fixes
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
- Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))

### Documentation

Expand Down
16 changes: 9 additions & 7 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ def _ess_sd(ary, relative=False):
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ary = _split_chains(ary)
return min(_ess(ary, relative=relative), _ess(ary**2, relative=relative))
ary = np.absolute(ary - ary.mean())
return _ess(_split_chains(ary), relative=relative)


def _ess_quantile(ary, prob, relative=False):
Expand Down Expand Up @@ -838,13 +838,15 @@ def _mcse_sd(ary):
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ess = _ess_sd(ary)
sims_c2 = (ary - ary.mean())**2
ess = _ess_mean(sims_c2)
evar = (sims_c2).mean()
varvar = ((sims_c2**2).mean() - evar**2) / ess
varsd = varvar / evar / 4
if _numba_flag:
sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
else:
sd = np.std(ary, ddof=1)
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
mcse_sd_value = sd * fac_mcse_sd
mcse_sd_value = np.sqrt(varsd)
return mcse_sd_value


Expand Down
8 changes: 4 additions & 4 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def test_deterministic(self):
```
Reference file:

Created: 2020-08-31
System: Ubuntu 18.04.5 LTS
R version 4.0.2 (2020-06-22)
posterior 0.1.2
Created: 2024-12-19
System: Ubuntu 24.04.1 LTS
R version 4.4.2 (2024-10-31)
posterior 1.6.0
"""
# download input files
here = os.path.dirname(os.path.abspath(__file__))
Expand Down
Loading
Loading