Skip to content

Commit

Permalink
update pyproject to use github arviz-base and datatree from xarray
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Nov 22, 2024
1 parent 81d401c commit a8b8cbb
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ funding = "https://opencollective.com/arviz"

[project.optional-dependencies]
xarray = [
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"arviz-base @ git+https://github.com/arviz-devs/arviz-base@xarray_datatree",
"xarray-einstats",
"numba",
]
Expand Down
5 changes: 2 additions & 3 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import xarray as xr
from arviz_base.utils import _var_names
from datatree import DataTree, register_datatree_accessor
from xarray_einstats.numba import ecdf

from arviz_stats.utils import get_function
Expand Down Expand Up @@ -228,7 +227,7 @@ def power_scale_sense(self, dims=None, **kwargs):
return self._apply("power_scale_sense", dims=dims, **kwargs)


@register_datatree_accessor("azstats")
@xr.register_datatree_accessor("azstats")
class AzStatsDtAccessor(_BaseAccessor):
"""ArviZ stats accessor class for DataTrees."""

Expand All @@ -253,7 +252,7 @@ def _apply(self, func_name, group, **kwargs):
if isinstance(group, Hashable):
group = [group]
hashable_group = True
out_dt = DataTree.from_dict(
out_dt = xr.DataTree.from_dict(
{
group_i: apply_function_to_dataset(
get_function(func_name),
Expand Down
5 changes: 2 additions & 3 deletions src/arviz_stats/numba/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import scipy
import xarray as xr
from datatree import DataTree
from scipy.fftpack import next_fast_len
from xarray_einstats import stats
from xarray_einstats.einops import raw_rearrange
Expand Down Expand Up @@ -118,7 +117,7 @@ def rhat(ds, group="posterior", method="rank", **kwargs):
if method not in func_map:
raise ValueError("method not recognized")
rhat_func = func_map[method]
if isinstance(ds, DataTree):
if isinstance(ds, xr.DataTree):
ds = ds[group]
return ds.map(rhat_func, **kwargs)
if isinstance(ds, xr.Dataset):
Expand Down Expand Up @@ -283,7 +282,7 @@ def ess(ds, group="posterior", method="bulk", **kwargs):
if method not in func_map:
raise ValueError("method not recognized")
ess_func = func_map[method]
if isinstance(ds, DataTree):
if isinstance(ds, xr.DataTree):
ds = ds[group]
return ds.map(ess_func, **kwargs)
if isinstance(ds, xr.Dataset):
Expand Down

0 comments on commit a8b8cbb

Please sign in to comment.