diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b341df5..7f81b5b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-toml @@ -11,7 +11,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.9 + rev: v0.6.9 hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index 7745f56..e22fee9 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -1,6 +1,7 @@ """ArviZ stats accessors.""" import warnings +from collections.abc import Hashable import numpy as np import xarray as xr @@ -214,27 +215,40 @@ def _process_input(self, group, method, allow_non_matching=True): ) return self._obj raise ValueError( - f"Group {group} not available in DataTree. Present groups are {self._obj.children}" + f"Group {group} not available in DataTree. Present groups are {self._obj.children} " + f"and the DataTree itself is named {self._obs.name}" ) def _apply(self, fun_name, group, **kwargs): - allow_non_matching = False - if isinstance(group, str): + hashable_group = False + if isinstance(group, Hashable): group = [group] - allow_non_matching = True - return DataTree.from_dict( + hashable_group = True + out_dt = DataTree.from_dict( { group_i: xr.Dataset( { var_name: get_function(fun_name)(da, **update_kwargs_with_dims(da, kwargs)) for var_name, da in self._process_input( - group_i, fun_name, allow_non_matching=allow_non_matching + # if group is a single str/hashable that doesn't match the group + # name, still allow it and apply the function to the top level of + # the provided datatree + group_i, + fun_name, + allow_non_matching=hashable_group, ).items() } ) for group_i in group } ) + if hashable_group: + # if group was a string/hashable, return a datatree with a single node + # (from the provided group) as the root of the DataTree + return out_dt[group[0]] + # if group was a sequence, return a DataTree with multiple groups in the 1st level, + # as many groups as requested + return out_dt def filter_vars(self, group="posterior", var_names=None, filter_vars=None): """Access and filter variables of the provided group.""" diff --git a/tests/base/test_diagnostics.py b/tests/base/test_diagnostics.py index 30683e8..44b2310 100644 --- a/tests/base/test_diagnostics.py +++ b/tests/base/test_diagnostics.py @@ -7,6 +7,7 @@ import pandas as pd import pytest from arviz_base import load_arviz_data, xarray_var_iter + from arviz_stats.base import array_stats # For tests only, recommended value should be closer to 1.01-1.05 diff --git a/tests/base/test_stats_utils.py b/tests/base/test_stats_utils.py index 8e51243..e556069 100644 --- a/tests/base/test_stats_utils.py +++ b/tests/base/test_stats_utils.py @@ -4,11 +4,12 @@ # pylint: disable=no-member,unnecessary-lambda-assignment import numpy as np import pytest -from arviz_stats.base.stats_utils import logsumexp as _logsumexp -from arviz_stats.base.stats_utils import make_ufunc, not_valid from numpy.testing import assert_array_almost_equal from scipy.special import logsumexp +from arviz_stats.base.stats_utils import logsumexp as _logsumexp +from arviz_stats.base.stats_utils import make_ufunc, not_valid + @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) @pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)]) diff --git a/tests/test_accessors.py b/tests/test_accessors.py new file mode 100644 index 0000000..44a0606 --- /dev/null +++ b/tests/test_accessors.py @@ -0,0 +1,48 @@ +# pylint: disable=redefined-outer-name +"""Test accessors. + +Accessor methods are very short, with the bulk of the computation/processing +handled by private methods. Testing this shared infrastructural methods +is the main goal of this module even if it does so via specific "regular" methods. +""" + +import numpy as np +import pytest +from arviz_base import from_dict +from datatree import DataTree + + +@pytest.fixture(scope="module") +def idata(): + return from_dict( + { + "posterior": { + "a": np.random.normal(size=(4, 100)), + "b": np.random.normal(size=(4, 100, 3)), + }, + "posterior_predictive": { + "y": np.random.normal(size=(4, 100, 7)), + }, + } + ) + + +def test_accessors_available(idata): + assert hasattr(idata, "azstats") + assert hasattr(idata.posterior.ds, "azstats") + assert hasattr(idata.posterior["a"], "azstats") + + +def test_datatree_single_group(idata): + out = idata.azstats.ess(group="posterior") + assert isinstance(out, DataTree) + assert not out.children + assert out.name == "posterior" + + +def test_datatree_multiple_groups(idata): + out = idata.azstats.ess(group=["posterior", "posterior_predictive"]) + assert isinstance(out, DataTree) + assert len(out.children) == 2 + assert "posterior" in out.children + assert "posterior_predictive" in out.children diff --git a/tests/test_psense.py b/tests/test_psense.py index c5d4dec..cb37e92 100644 --- a/tests/test_psense.py +++ b/tests/test_psense.py @@ -1,10 +1,11 @@ import os from arviz_base import convert_to_datatree -from arviz_stats import psense, psense_summary from numpy import isclose from numpy.testing import assert_almost_equal +from arviz_stats import psense, psense_summary + file_path = os.path.join(os.path.dirname(__file__), "univariate_normal.nc") uni_dt = convert_to_datatree(file_path) diff --git a/tests/test_utils.py b/tests/test_utils.py index 998fc1c..859e469 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ import numpy as np import pytest from arviz_base import from_dict, rcParams + from arviz_stats.base.dataarray import dataarray_stats from arviz_stats.utils import ELPDData, get_function, get_log_likelihood diff --git a/tox.ini b/tox.ini index 7243f13..051de9d 100644 --- a/tox.ini +++ b/tox.ini @@ -9,9 +9,10 @@ isolated_build_env = build [gh-actions] python = - 3.10: check, py310 + 3.10: py310 3.11: py311 - 3.12: py312 + 3.12: py312, check + 3.13: py313 [testenv] basepython =