diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index 33a43ec..11c74ff 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -62,7 +62,7 @@ def hdi( "multimodal": ( self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous ), - "multimodal_nearest": ( + "multimodal_sample": ( self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous ), }[method] @@ -87,8 +87,8 @@ def hdi( func_kwargs["bw"] = "isj" if not circular else "taylor" func_kwargs.update(kwargs) - if method == "multimodal_nearest": - func_kwargs["nearest"] = True + if method == "multimodal_sample": + func_kwargs["from_sample"] = True result = hdi_array(ary, **func_kwargs) if is_multimodal: diff --git a/src/arviz_stats/base/core.py b/src/arviz_stats/base/core.py index 24de035..422c297 100644 --- a/src/arviz_stats/base/core.py +++ b/src/arviz_stats/base/core.py @@ -247,7 +247,7 @@ def _hdi_nearest(self, ary, prob, circular, skipna): return hdi_interval def _hdi_multimodal_continuous( - self, ary, prob, skipna, max_modes, circular, nearest=False, **kwargs + self, ary, prob, skipna, max_modes, circular, from_sample=False, **kwargs ): """Compute HDI if the distribution is multimodal.""" ary = ary.flatten() @@ -255,7 +255,7 @@ def _hdi_multimodal_continuous( ary = ary[~np.isnan(ary)] bins, density, _ = self.kde(ary, circular=circular, **kwargs) - if nearest: + if from_sample: ary_density = np.interp(ary, bins, density) hdi_intervals, interval_probs = self._hdi_from_point_densities( ary, ary_density, prob, circular diff --git a/tests/base/test_stats.py b/tests/base/test_stats.py index 00d04d6..925ba1f 100644 --- a/tests/base/test_stats.py +++ b/tests/base/test_stats.py @@ -92,7 +92,7 @@ def test_hdi_coords(centered_eight): @pytest.mark.parametrize("prob", [0.56, 0.83]) @pytest.mark.parametrize("nearest", [True, False]) def test_hdi_multimodal_continuous(prob, nearest): - method = "multimodal_nearest" if nearest else "multimodal" + method = "multimodal_sample" if nearest else "multimodal" rng = np.random.default_rng(43) dist1 = norm(loc=-30, scale=0.5) dist2 = norm(loc=30, scale=0.5) @@ -176,7 +176,8 @@ def test_hdi_multimodal_max_modes(): sample = ndarray_to_dataarray(x, "x", sample_dims=["sample"]) intervals = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9) assert intervals.sizes["mode"] == 2 - intervals2 = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9, max_modes=1) + with pytest.warns(UserWarning, match="found more modes"): + intervals2 = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9, max_modes=1) assert intervals2.sizes["mode"] == 1 assert intervals2.equals(intervals.isel(mode=[1])) @@ -194,7 +195,7 @@ def test_hdi_multimodal_circular(nearest): "x", sample_dims=["sample"], ) - method = "multimodal_nearest" if nearest else "multimodal" + method = "multimodal_sample" if nearest else "multimodal" interval = normal_sample.azstats.hdi(circular=True, method=method, prob=0.83, dims="sample") assert interval.sizes["mode"] == 2 assert interval.sel(mode=0)[0] <= np.pi / 2 <= interval.sel(mode=0)[1]