Skip to content

Commit

Permalink
Use datatree from xarray (#35)
Browse files Browse the repository at this point in the history
* update pyproject to use github arviz-base and datatree from xarray

* update leftover datatree.DataTree references

* use xarray.DataTree in tests
  • Loading branch information
OriolAbril authored Dec 5, 2024
1 parent 81d401c commit 77e1358
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 11 deletions.
3 changes: 1 addition & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@
numpydoc_xref_aliases = {
"DataArray": ":class:`xarray.DataArray`",
"Dataset": ":class:`xarray.Dataset`",
"DataTree": ":class:`datatree.DataTree`",
"DataTree": ":class:`xarray.DataTree`",
**{f"{singular}s": f":any:`{singular}s <{singular}>`" for singular in singulars},
}

intersphinx_mapping = {
"arviz_org": ("https://www.arviz.org/en/latest/", None),
"dask": ("https://docs.dask.org/en/latest/", None),
"datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"python": ("https://docs.python.org/3/", None),
"xarray": ("https://docs.xarray.dev/en/stable/", None),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ funding = "https://opencollective.com/arviz"

[project.optional-dependencies]
xarray = [
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"arviz-base @ git+https://github.com/arviz-devs/arviz-base@xarray_datatree",
"xarray-einstats",
"numba",
]
Expand Down
7 changes: 3 additions & 4 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import xarray as xr
from arviz_base.utils import _var_names
from datatree import DataTree, register_datatree_accessor
from xarray_einstats.numba import ecdf

from arviz_stats.utils import get_function
Expand Down Expand Up @@ -38,7 +37,7 @@ def update_kwargs_with_dims(da, kwargs):
def check_var_name_subset(obj, var_name):
if isinstance(obj, xr.Dataset):
return obj[var_name]
if isinstance(obj, DataTree):
if isinstance(obj, xr.DataTree):
return obj.ds[var_name]
return obj

Expand Down Expand Up @@ -228,7 +227,7 @@ def power_scale_sense(self, dims=None, **kwargs):
return self._apply("power_scale_sense", dims=dims, **kwargs)


@register_datatree_accessor("azstats")
@xr.register_datatree_accessor("azstats")
class AzStatsDtAccessor(_BaseAccessor):
"""ArviZ stats accessor class for DataTrees."""

Expand All @@ -253,7 +252,7 @@ def _apply(self, func_name, group, **kwargs):
if isinstance(group, Hashable):
group = [group]
hashable_group = True
out_dt = DataTree.from_dict(
out_dt = xr.DataTree.from_dict(
{
group_i: apply_function_to_dataset(
get_function(func_name),
Expand Down
5 changes: 2 additions & 3 deletions src/arviz_stats/numba/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import scipy
import xarray as xr
from datatree import DataTree
from scipy.fftpack import next_fast_len
from xarray_einstats import stats
from xarray_einstats.einops import raw_rearrange
Expand Down Expand Up @@ -118,7 +117,7 @@ def rhat(ds, group="posterior", method="rank", **kwargs):
if method not in func_map:
raise ValueError("method not recognized")
rhat_func = func_map[method]
if isinstance(ds, DataTree):
if isinstance(ds, xr.DataTree):
ds = ds[group]
return ds.map(rhat_func, **kwargs)
if isinstance(ds, xr.Dataset):
Expand Down Expand Up @@ -283,7 +282,7 @@ def ess(ds, group="posterior", method="bulk", **kwargs):
if method not in func_map:
raise ValueError("method not recognized")
ess_func = func_map[method]
if isinstance(ds, DataTree):
if isinstance(ds, xr.DataTree):
ds = ds[group]
return ds.map(ess_func, **kwargs)
if isinstance(ds, xr.Dataset):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import pytest
from arviz_base import from_dict
from datatree import DataTree
from xarray import DataTree


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 77e1358

Please sign in to comment.