From a8b8cbb97259c2b441b46f19f2ce5411297d04bf Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Fri, 25 Oct 2024 22:41:42 +0200 Subject: [PATCH] update pyproject to use github arviz-base and datatree from xarray --- pyproject.toml | 2 +- src/arviz_stats/accessors.py | 5 ++--- src/arviz_stats/numba/diagnostics.py | 5 ++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e87d5f1..dec8072 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ funding = "https://opencollective.com/arviz" [project.optional-dependencies] xarray = [ - "arviz-base @ git+https://github.com/arviz-devs/arviz-base", + "arviz-base @ git+https://github.com/arviz-devs/arviz-base@xarray_datatree", "xarray-einstats", "numba", ] diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index a4cb198..f9df7c3 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -6,7 +6,6 @@ import numpy as np import xarray as xr from arviz_base.utils import _var_names -from datatree import DataTree, register_datatree_accessor from xarray_einstats.numba import ecdf from arviz_stats.utils import get_function @@ -228,7 +227,7 @@ def power_scale_sense(self, dims=None, **kwargs): return self._apply("power_scale_sense", dims=dims, **kwargs) -@register_datatree_accessor("azstats") +@xr.register_datatree_accessor("azstats") class AzStatsDtAccessor(_BaseAccessor): """ArviZ stats accessor class for DataTrees.""" @@ -253,7 +252,7 @@ def _apply(self, func_name, group, **kwargs): if isinstance(group, Hashable): group = [group] hashable_group = True - out_dt = DataTree.from_dict( + out_dt = xr.DataTree.from_dict( { group_i: apply_function_to_dataset( get_function(func_name), diff --git a/src/arviz_stats/numba/diagnostics.py b/src/arviz_stats/numba/diagnostics.py index 94b365c..d993f97 100644 --- a/src/arviz_stats/numba/diagnostics.py +++ b/src/arviz_stats/numba/diagnostics.py @@ -4,7 +4,6 @@ import numpy as np import scipy import xarray as xr -from datatree import DataTree from scipy.fftpack import next_fast_len from xarray_einstats import stats from xarray_einstats.einops import raw_rearrange @@ -118,7 +117,7 @@ def rhat(ds, group="posterior", method="rank", **kwargs): if method not in func_map: raise ValueError("method not recognized") rhat_func = func_map[method] - if isinstance(ds, DataTree): + if isinstance(ds, xr.DataTree): ds = ds[group] return ds.map(rhat_func, **kwargs) if isinstance(ds, xr.Dataset): @@ -283,7 +282,7 @@ def ess(ds, group="posterior", method="bulk", **kwargs): if method not in func_map: raise ValueError("method not recognized") ess_func = func_map[method] - if isinstance(ds, DataTree): + if isinstance(ds, xr.DataTree): ds = ds[group] return ds.map(ess_func, **kwargs) if isinstance(ds, xr.Dataset):