diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index d55fac4f6c9..4c36adbc121 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -425,18 +425,22 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         return ds

     # TODO [MHS, 01/23/2024] This is duplicative of the netcdf4 code in an ugly way.
-    def open_datatree(self, filename: str, **kwargs) -> DataTree:
+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs,
+    ) -> DataTree:
         from h5netcdf.legacyapi import Dataset as ncDataset

         from xarray.backends.api import open_dataset
         from xarray.datatree_.datatree import DataTree
         from xarray.datatree_.datatree.treenode import NodePath

-        ds = open_dataset(filename, **kwargs)
+        ds = open_dataset(filename_or_obj, **kwargs)
         tree_root = DataTree.from_dict({"/": ds})
-        with ncDataset(filename, mode="r") as ncds:
+        with ncDataset(filename_or_obj, mode="r") as ncds:
             for path in _iter_nc_groups(ncds):
-                subgroup_ds = open_dataset(filename, group=path, **kwargs)
+                subgroup_ds = open_dataset(filename_or_obj, group=path, **kwargs)

                 # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
                 node_name = NodePath(path).name
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index d7879a57264..8c3bf0ff77c 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -668,18 +668,22 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         )
         return ds

-    def open_datatree(self, filename: str, **kwargs) -> DataTree:
+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs,
+    ) -> DataTree:
         from netCDF4 import Dataset as ncDataset

         from xarray.backends.api import open_dataset
         from xarray.datatree_.datatree import DataTree
         from xarray.datatree_.datatree.treenode import NodePath

-        ds = open_dataset(filename, **kwargs)
+        ds = open_dataset(filename_or_obj, **kwargs)
         tree_root = DataTree.from_dict({"/": ds})
-        with ncDataset(filename, mode="r") as ncds:
+        with ncDataset(filename_or_obj, mode="r") as ncds:
             for path in _iter_nc_groups(ncds):
-                subgroup_ds = open_dataset(filename, group=path, **kwargs)
+                subgroup_ds = open_dataset(filename_or_obj, group=path, **kwargs)

                 # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
                 node_name = NodePath(path).name
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index bb46ab68de5..a276103b18b 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -1036,19 +1036,25 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         )
         return ds

-    def open_datatree(self, store, **kwargs) -> DataTree:
-        import zarr  # type: ignore
+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs,
+    ) -> DataTree:
+        import zarr

         from xarray.backends.api import open_dataset
         from xarray.datatree_.datatree import DataTree
         from xarray.datatree_.datatree.treenode import NodePath

-        zds = zarr.open_group(store, mode="r")
-        ds = open_dataset(store, engine="zarr", **kwargs)
+        zds = zarr.open_group(filename_or_obj, mode="r")
+        ds = open_dataset(filename_or_obj, engine="zarr", **kwargs)
         tree_root = DataTree.from_dict({"/": ds})
         for path in _iter_zarr_groups(zds):
             try:
-                subgroup_ds = open_dataset(store, engine="zarr", group=path, **kwargs)
+                subgroup_ds = open_dataset(
+                    filename_or_obj, engine="zarr", group=path, **kwargs
+                )
             except zarr.errors.PathNotFoundError:
                 subgroup_ds = Dataset()