Skip to content

Commit

Permalink
rename and add check for warning in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Oct 25, 2024
1 parent ff42224 commit 4739a01
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/arviz_stats/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def hdi(
"multimodal": (
self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous
),
"multimodal_nearest": (
"multimodal_sample": (
self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous
),
}[method]
Expand All @@ -87,8 +87,8 @@ def hdi(
func_kwargs["bw"] = "isj" if not circular else "taylor"
func_kwargs.update(kwargs)

if method == "multimodal_nearest":
func_kwargs["nearest"] = True
if method == "multimodal_sample":
func_kwargs["from_sample"] = True

result = hdi_array(ary, **func_kwargs)
if is_multimodal:
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_stats/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,15 @@ def _hdi_nearest(self, ary, prob, circular, skipna):
return hdi_interval

def _hdi_multimodal_continuous(
self, ary, prob, skipna, max_modes, circular, nearest=False, **kwargs
self, ary, prob, skipna, max_modes, circular, from_sample=False, **kwargs
):
"""Compute HDI if the distribution is multimodal."""
ary = ary.flatten()
if skipna:
ary = ary[~np.isnan(ary)]

bins, density, _ = self.kde(ary, circular=circular, **kwargs)
if nearest:
if from_sample:
ary_density = np.interp(ary, bins, density)
hdi_intervals, interval_probs = self._hdi_from_point_densities(
ary, ary_density, prob, circular
Expand Down
7 changes: 4 additions & 3 deletions tests/base/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_hdi_coords(centered_eight):
@pytest.mark.parametrize("prob", [0.56, 0.83])
@pytest.mark.parametrize("nearest", [True, False])
def test_hdi_multimodal_continuous(prob, nearest):
method = "multimodal_nearest" if nearest else "multimodal"
method = "multimodal_sample" if nearest else "multimodal"
rng = np.random.default_rng(43)
dist1 = norm(loc=-30, scale=0.5)
dist2 = norm(loc=30, scale=0.5)
Expand Down Expand Up @@ -176,7 +176,8 @@ def test_hdi_multimodal_max_modes():
sample = ndarray_to_dataarray(x, "x", sample_dims=["sample"])
intervals = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9)
assert intervals.sizes["mode"] == 2
intervals2 = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9, max_modes=1)
with pytest.warns(UserWarning, match="found more modes"):
intervals2 = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9, max_modes=1)
assert intervals2.sizes["mode"] == 1
assert intervals2.equals(intervals.isel(mode=[1]))

Expand All @@ -194,7 +195,7 @@ def test_hdi_multimodal_circular(nearest):
"x",
sample_dims=["sample"],
)
method = "multimodal_nearest" if nearest else "multimodal"
method = "multimodal_sample" if nearest else "multimodal"
interval = normal_sample.azstats.hdi(circular=True, method=method, prob=0.83, dims="sample")
assert interval.sizes["mode"] == 2
assert interval.sel(mode=0)[0] <= np.pi / 2 <= interval.sel(mode=0)[1]
Expand Down

0 comments on commit 4739a01

Please sign in to comment.