diff --git a/.binder/environment.yml b/.binder/environment.yml index 053b12dfc86..fee5ed07cf7 100644 --- a/.binder/environment.yml +++ b/.binder/environment.yml @@ -28,7 +28,6 @@ dependencies: - pip - pooch - pydap - - pynio - rasterio - scipy - seaborn diff --git a/MANIFEST.in b/MANIFEST.in index 032b620f433..a119e7df1fd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ prune xarray/datatree_* +recursive-include xarray/datatree_/datatree *.py diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index 56af319f0bb..105e90ce109 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -12,5 +12,5 @@ dependencies: - pytest-xdist - pytest-timeout - numpy=1.23 - - packaging=22.0 - - pandas=1.5 + - packaging=23.1 + - pandas=2.0 diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index d2965fb3fc5..64f4327bbcb 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -9,13 +9,13 @@ dependencies: # doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py. - python=3.9 - array-api-strict=1.0 # dependency for testing the array api compat - - boto3=1.24 + - boto3=1.26 - bottleneck=1.3 - cartopy=0.21 - cftime=1.6 - coveralls - - dask-core=2022.12 - - distributed=2022.12 + - dask-core=2023.4 + - distributed=2023.4 # Flox > 0.8 has a bug with numbagg versions # It will require numbagg > 0.6 # so we should just skip that series eventually @@ -25,12 +25,12 @@ dependencies: # h5py and hdf5 tend to cause conflicts # for e.g. hdf5 1.12 conflicts with h5py=3.1 # prioritize bumping other packages instead - - h5py=3.7 + - h5py=3.8 - hdf5=1.12 - hypothesis - iris=3.4 - lxml=4.9 # Optional dep of pydap - - matplotlib-base=3.6 + - matplotlib-base=3.7 - nc-time-axis=1.4 # netcdf follows a 1.major.minor[.patch] convention # (see https://github.com/Unidata/netcdf4-python/issues/1090) @@ -38,11 +38,11 @@ dependencies: - numba=0.56 - numbagg=0.2.1 - numpy=1.23 - - packaging=22.0 - - pandas=1.5 + - packaging=23.1 + - pandas=2.0 - pint=0.22 - pip - - pydap=3.3 + - pydap=3.4 - pytest - pytest-cov - pytest-env @@ -51,7 +51,7 @@ dependencies: - rasterio=1.3 - scipy=1.10 - seaborn=0.12 - - sparse=0.13 + - sparse=0.14 - toolz=0.12 - - typing_extensions=4.4 - - zarr=2.13 + - typing_extensions=4.5 + - zarr=2.14 diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst index f7eaf92f9cf..ca12ae62440 100644 --- a/doc/getting-started-guide/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -31,9 +31,6 @@ For netCDF and IO - `pydap `__: used as a fallback for accessing OPeNDAP - `h5netcdf `__: an alternative library for reading and writing netCDF4 files that does not use the netCDF-C libraries -- `PyNIO `__: for reading GRIB and other - geoscience specific file formats. Note that PyNIO is not available for Windows and - that the PyNIO backend may be moved outside of xarray in the future. - `zarr `__: for chunked, compressed, N-dimensional arrays. - `cftime `__: recommended if you want to encode/decode datetimes for non-standard calendars or dates before diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 48751c5f299..63bf8b80d81 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -1294,27 +1294,6 @@ We recommend installing cfgrib via conda:: .. _cfgrib: https://github.com/ecmwf/cfgrib -.. _io.pynio: - -Formats supported by PyNIO --------------------------- - -.. warning:: - - The `PyNIO backend is deprecated`_. `PyNIO is no longer maintained`_. - -Xarray can also read GRIB, HDF4 and other file formats supported by PyNIO_, -if PyNIO is installed. To use PyNIO to read such files, supply -``engine='pynio'`` to :py:func:`open_dataset`. - -We recommend installing PyNIO via conda:: - - conda install -c conda-forge pynio - -.. _PyNIO: https://www.pyngl.ucar.edu/Nio.shtml -.. _PyNIO backend is deprecated: https://github.com/pydata/xarray/issues/4491 -.. _PyNIO is no longer maintained: https://github.com/NCAR/pynio/issues/53 - CSV and other formats supported by pandas ----------------------------------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2332f7f236b..46f9d3bbfc9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,9 +29,31 @@ New Features for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting. By `Ilan Gold `_. +- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg + `create_index=False`. (:pull:`8960`) + By `Tom Nicholas `_. Breaking changes ~~~~~~~~~~~~~~~~ +- The PyNIO backend has been deleted (:issue:`4491`, :pull:`7301`). + By `Deepak Cherian `_. + +- The minimum versions of some dependencies were changed, in particular our minimum supported pandas version is now Pandas 2. + + ===================== ========= ======= + Package Old New + ===================== ========= ======= + dask-core 2022.12 2023.4 + distributed 2022.12 2023.4 + h5py 3.7 3.8 + matplotlib-base 3.6 3.7 + packaging 22.0 23.1 + pandas 1.5 2.0 + pydap 3.3 3.4 + sparse 0.13 0.14 + typing_extensions 4.4 4.5 + zarr 2.13 2.14 + ===================== ========= ======= Bug fixes @@ -40,12 +62,17 @@ Bug fixes Internal Changes ~~~~~~~~~~~~~~~~ -- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull: `8930`) +- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull: `8930`) By `Eni Awowale `_, `Julia Signell `_ and `Tom Nicholas `_. - Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) By `Matt Savoie `_ `Owen Littlejohns - ` and `Tom Nicholas `_. + `_ and `Tom Nicholas `_. +- Migrates ``extensions``, ``formatting`` and ``datatree_render`` functionality for + ``DataTree`` into ``xarray/core``. Also migrates ``testing`` functionality into + ``xarray/testing/assertions`` for ``DataTree``. (:pull:`8967`) + By `Owen Littlejohns `_ and + `Tom Nicholas `_. .. _whats-new.2024.03.0: @@ -6806,8 +6833,7 @@ Enhancements datasets with a MultiIndex to a netCDF file. User contributions in this area would be greatly appreciated. -- Support for reading GRIB, HDF4 and other file formats via PyNIO_. See - :ref:`io.pynio` for more details. +- Support for reading GRIB, HDF4 and other file formats via PyNIO_. - Better error message when a variable is supplied with the same name as one of its dimensions. - Plotting: more control on colormap parameters (:issue:`642`). ``vmin`` and diff --git a/pyproject.toml b/pyproject.toml index 8cbd395b2a3..dc1b9d65de6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,8 @@ requires-python = ">=3.9" dependencies = [ "numpy>=1.23", - "packaging>=22", - "pandas>=1.5", + "packaging>=23.1", + "pandas>=2.0", ] [project.optional-dependencies] @@ -88,7 +88,7 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] enable_error_code = "redundant-self" exclude = [ 'xarray/util/generate_.*\.py', - 'xarray/datatree_/.*\.py', + 'xarray/datatree_/doc/.*\.py', ] files = "xarray" show_error_codes = true @@ -97,11 +97,6 @@ warn_redundant_casts = true warn_unused_configs = true warn_unused_ignores = true -# Ignore mypy errors for modules imported from datatree_. -[[tool.mypy.overrides]] -ignore_errors = true -module = "xarray.datatree_.*" - # Much of the numerical computing stack doesn't have type annotations yet. [[tool.mypy.overrides]] ignore_missing_imports = true diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 1c8d2d3a659..550b9e29e42 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -15,7 +15,6 @@ from xarray.backends.netCDF4_ import NetCDF4BackendEntrypoint, NetCDF4DataStore from xarray.backends.plugins import list_engines, refresh_engines from xarray.backends.pydap_ import PydapBackendEntrypoint, PydapDataStore -from xarray.backends.pynio_ import NioDataStore from xarray.backends.scipy_ import ScipyBackendEntrypoint, ScipyDataStore from xarray.backends.store import StoreBackendEntrypoint from xarray.backends.zarr import ZarrBackendEntrypoint, ZarrStore @@ -30,7 +29,6 @@ "InMemoryDataStore", "NetCDF4DataStore", "PydapDataStore", - "NioDataStore", "ScipyDataStore", "H5NetCDFStore", "ZarrStore", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2589ff196f9..62085fe5e2a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -61,7 +61,7 @@ T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ T_NetcdfEngine, - Literal["pydap", "pynio", "zarr"], + Literal["pydap", "zarr"], type[BackendEntrypoint], str, # no nice typing support for custom backends None, @@ -79,7 +79,6 @@ "scipy": backends.ScipyDataStore, "pydap": backends.PydapDataStore.open, "h5netcdf": backends.H5NetCDFStore.open, - "pynio": backends.NioDataStore, "zarr": backends.ZarrStore.open_group, } @@ -420,8 +419,8 @@ def open_dataset( ends with .gz, in which case the file is gunzipped and opened with scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). - engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ - "zarr", None}, installed backend \ + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ + , installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for @@ -523,7 +522,7 @@ def open_dataset( relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently active dask scheduler. Supported by "netcdf4", "h5netcdf", - "scipy", "pynio". + "scipy". See engine open function for kwargs accepted by each specific engine. @@ -627,8 +626,8 @@ def open_dataarray( ends with .gz, in which case the file is gunzipped and opened with scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). - engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ - "zarr", None}, installed backend \ + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ + , installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for @@ -728,7 +727,7 @@ def open_dataarray( relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently active dask scheduler. Supported by "netcdf4", "h5netcdf", - "scipy", "pynio". + "scipy". See engine open function for kwargs accepted by each specific engine. @@ -897,8 +896,8 @@ def open_mfdataset( If provided, call this function on each dataset prior to concatenation. You can find the file-name from which each dataset was loaded in ``ds.encoding["source"]``. - engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ - "zarr", None}, installed backend \ + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ + , installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py deleted file mode 100644 index 75e96ffdc0a..00000000000 --- a/xarray/backends/pynio_.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -import warnings -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any - -import numpy as np - -from xarray.backends.common import ( - BACKEND_ENTRYPOINTS, - AbstractDataStore, - BackendArray, - BackendEntrypoint, - _normalize_path, -) -from xarray.backends.file_manager import CachingFileManager -from xarray.backends.locks import ( - HDF5_LOCK, - NETCDFC_LOCK, - SerializableLock, - combine_locks, - ensure_lock, -) -from xarray.backends.store import StoreBackendEntrypoint -from xarray.core import indexing -from xarray.core.utils import Frozen, FrozenDict, close_on_error -from xarray.core.variable import Variable - -if TYPE_CHECKING: - import os - from io import BufferedIOBase - - from xarray.core.dataset import Dataset - -# PyNIO can invoke netCDF libraries internally -# Add a dedicated lock just in case NCL as well isn't thread-safe. -NCL_LOCK = SerializableLock() -PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) - - -class NioArrayWrapper(BackendArray): - def __init__(self, variable_name, datastore): - self.datastore = datastore - self.variable_name = variable_name - array = self.get_array() - self.shape = array.shape - self.dtype = np.dtype(array.typecode()) - - def get_array(self, needs_lock=True): - ds = self.datastore._manager.acquire(needs_lock) - return ds.variables[self.variable_name] - - def __getitem__(self, key): - return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.BASIC, self._getitem - ) - - def _getitem(self, key): - with self.datastore.lock: - array = self.get_array(needs_lock=False) - - if key == () and self.ndim == 0: - return array.get_value() - - return array[key] - - -class NioDataStore(AbstractDataStore): - """Store for accessing datasets via PyNIO""" - - def __init__(self, filename, mode="r", lock=None, **kwargs): - import Nio - - warnings.warn( - "The PyNIO backend is Deprecated and will be removed from Xarray in a future release. " - "See https://github.com/pydata/xarray/issues/4491 for more information", - DeprecationWarning, - ) - - if lock is None: - lock = PYNIO_LOCK - self.lock = ensure_lock(lock) - self._manager = CachingFileManager( - Nio.open_file, filename, lock=lock, mode=mode, kwargs=kwargs - ) - # xarray provides its own support for FillValue, - # so turn off PyNIO's support for the same. - self.ds.set_option("MaskedArrayMode", "MaskedNever") - - @property - def ds(self): - return self._manager.acquire() - - def open_store_variable(self, name, var): - data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self)) - return Variable(var.dimensions, data, var.attributes) - - def get_variables(self): - return FrozenDict( - (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() - ) - - def get_attrs(self): - return Frozen(self.ds.attributes) - - def get_dimensions(self): - return Frozen(self.ds.dimensions) - - def get_encoding(self): - return { - "unlimited_dims": {k for k in self.ds.dimensions if self.ds.unlimited(k)} - } - - def close(self): - self._manager.close() - - -class PynioBackendEntrypoint(BackendEntrypoint): - """ - PyNIO backend - - .. deprecated:: 0.20.0 - - Deprecated as PyNIO is no longer supported. See - https://github.com/pydata/xarray/issues/4491 for more information - """ - - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs - self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, - drop_variables: str | Iterable[str] | None = None, - use_cftime=None, - decode_timedelta=None, - mode="r", - lock=None, - ) -> Dataset: - filename_or_obj = _normalize_path(filename_or_obj) - store = NioDataStore( - filename_or_obj, - mode=mode, - lock=lock, - ) - - store_entrypoint = StoreBackendEntrypoint() - with close_on_error(store): - ds = store_entrypoint.open_dataset( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - return ds - - -BACKEND_ENTRYPOINTS["pynio"] = ("Nio", PynioBackendEntrypoint) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 509962ff80d..41c9af1bb10 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2557,6 +2557,7 @@ def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, + create_index: bool = True, **dim_kwargs: Any, ) -> Self: """Return a new object with an additional axis (or axes) inserted at @@ -2566,6 +2567,9 @@ def expand_dims( If dim is already a scalar coordinate, it will be promoted to a 1D coordinate consisting of a single value. + The automatic creation of indexes to back new 1D coordinate variables + controlled by the create_index kwarg. + Parameters ---------- dim : Hashable, sequence of Hashable, dict, or None, optional @@ -2581,6 +2585,8 @@ def expand_dims( multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. + create_index : bool, default is True + Whether to create new PandasIndex objects for any new 1D coordinate variables. **dim_kwargs : int or sequence or ndarray The keywords are arbitrary dimensions being inserted and the values are either the lengths of the new dims (if int is given), or their @@ -2644,7 +2650,7 @@ def expand_dims( dim = {dim: 1} dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") - ds = self._to_temp_dataset().expand_dims(dim, axis) + ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index) return self._from_temp_dataset(ds) def set_index( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 96f3be00995..4f9125a1ab0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4497,6 +4497,7 @@ def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, + create_index: bool = True, **dim_kwargs: Any, ) -> Self: """Return a new object with an additional axis (or axes) inserted at @@ -4506,6 +4507,9 @@ def expand_dims( If dim is already a scalar coordinate, it will be promoted to a 1D coordinate consisting of a single value. + The automatic creation of indexes to back new 1D coordinate variables + controlled by the create_index kwarg. + Parameters ---------- dim : hashable, sequence of hashable, mapping, or None @@ -4521,6 +4525,8 @@ def expand_dims( multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. + create_index : bool, default is True + Whether to create new PandasIndex objects for any new 1D coordinate variables. **dim_kwargs : int or sequence or ndarray The keywords are arbitrary dimensions being inserted and the values are either the lengths of the new dims (if int is given), or their @@ -4640,9 +4646,14 @@ def expand_dims( # save the coordinates to the variables dict, and set the # value within the dim dict to the length of the iterable # for later use. - index = PandasIndex(v, k) - indexes[k] = index - variables.update(index.create_variables()) + + if create_index: + index = PandasIndex(v, k) + indexes[k] = index + name_and_new_1d_var = index.create_variables() + else: + name_and_new_1d_var = {k: Variable(data=v, dims=k)} + variables.update(name_and_new_1d_var) coord_names.add(k) dim[k] = variables[k].size elif isinstance(v, int): @@ -4678,11 +4689,23 @@ def expand_dims( variables[k] = v.set_dims(dict(all_dims)) else: if k not in variables: - # If dims includes a label of a non-dimension coordinate, - # it will be promoted to a 1D coordinate with a single value. - index, index_vars = create_default_index_implicit(v.set_dims(k)) - indexes[k] = index - variables.update(index_vars) + if k in coord_names and create_index: + # If dims includes a label of a non-dimension coordinate, + # it will be promoted to a 1D coordinate with a single value. + index, index_vars = create_default_index_implicit(v.set_dims(k)) + indexes[k] = index + variables.update(index_vars) + else: + if create_index: + warnings.warn( + f"No index created for dimension {k} because variable {k} is not a coordinate. " + f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.", + UserWarning, + ) + + # create 1D variable without creating a new index + new_1d_var = v.set_dims(k) + variables.update({k: new_1d_var}) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 57fd7222898..48c714b697c 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -23,6 +23,8 @@ check_isomorphic, map_over_subtree, ) +from xarray.core.datatree_render import RenderDataTree +from xarray.core.formatting import datatree_repr from xarray.core.formatting_html import ( datatree_repr as datatree_repr_html, ) @@ -40,13 +42,11 @@ ) from xarray.core.variable import Variable from xarray.datatree_.datatree.common import TreeAttrAccessMixin -from xarray.datatree_.datatree.formatting import datatree_repr from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, ) -from xarray.datatree_.datatree.render import RenderTree try: from xarray.core.variable import calculate_dimensions @@ -1451,7 +1451,7 @@ def pipe( def render(self): """Print tree structure, including any data stored at each node.""" - for pre, fill, node in RenderTree(self): + for pre, fill, node in RenderDataTree(self): print(f"{pre}DataTree('{self.name}')") for ds_line in repr(node.ds)[1:]: print(f"{fill}{ds_line}") diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 714921d2a90..4da934f2085 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -3,11 +3,11 @@ import functools import sys from itertools import repeat -from textwrap import dedent from typing import TYPE_CHECKING, Callable -from xarray import DataArray, Dataset -from xarray.core.iterators import LevelOrderIter +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.formatting import diff_treestructure from xarray.core.treenode import NodePath, TreeNode if TYPE_CHECKING: @@ -71,37 +71,6 @@ def check_isomorphic( raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) -def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: - """ - Return a summary of why two trees are not isomorphic. - If they are isomorphic return an empty string. - """ - - # Walking nodes in "level-order" fashion means walking down from the root breadth-first. - # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree - # (which it is so long as children are stored in a tuple or list rather than in a set). - for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): - path_a, path_b = node_a.path, node_b.path - - if require_names_equal and node_a.name != node_b.name: - diff = dedent( - f"""\ - Node '{path_a}' in the left object has name '{node_a.name}' - Node '{path_b}' in the right object has name '{node_b.name}'""" - ) - return diff - - if len(node_a.children) != len(node_b.children): - diff = dedent( - f"""\ - Number of children on node '{path_a}' of the left object: {len(node_a.children)} - Number of children on node '{path_b}' of the right object: {len(node_b.children)}""" - ) - return diff - - return "" - - def map_over_subtree(func: Callable) -> Callable: """ Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py new file mode 100644 index 00000000000..d069071495e --- /dev/null +++ b/xarray/core/datatree_render.py @@ -0,0 +1,266 @@ +""" +String Tree Rendering. Copied from anytree. + +Minor changes to `RenderDataTree` include accessing `children.values()`, and +type hints. + +""" + +from __future__ import annotations + +from collections import namedtuple +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + +Row = namedtuple("Row", ("pre", "fill", "node")) + + +class AbstractStyle: + def __init__(self, vertical: str, cont: str, end: str): + """ + Tree Render Style. + Args: + vertical: Sign for vertical line. + cont: Chars for a continued branch. + end: Chars for the last branch. + """ + super().__init__() + self.vertical = vertical + self.cont = cont + self.end = end + assert ( + len(cont) == len(vertical) == len(end) + ), f"'{vertical}', '{cont}' and '{end}' need to have equal length" + + @property + def empty(self) -> str: + """Empty string as placeholder.""" + return " " * len(self.end) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class ContStyle(AbstractStyle): + def __init__(self): + """ + Continued style, without gaps. + + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root") + >>> s0 = DataTree(name="sub0", parent=root) + >>> s0b = DataTree(name="sub0B", parent=s0) + >>> s0a = DataTree(name="sub0A", parent=s0) + >>> s1 = DataTree(name="sub1", parent=root) + >>> print(RenderDataTree(root)) + DataTree('root', parent=None) + ├── DataTree('sub0') + │ ├── DataTree('sub0B') + │ └── DataTree('sub0A') + └── DataTree('sub1') + """ + super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ") + + +class RenderDataTree: + def __init__( + self, + node: DataTree, + style=ContStyle(), + childiter: type = list, + maxlevel: int | None = None, + ): + """ + Render tree starting at `node`. + Keyword Args: + style (AbstractStyle): Render Style. + childiter: Child iterator. Note, due to the use of node.children.values(), + Iterables that change the order of children cannot be used + (e.g., `reversed`). + maxlevel: Limit rendering to this depth. + :any:`RenderDataTree` is an iterator, returning a tuple with 3 items: + `pre` + tree prefix. + `fill` + filling for multiline entries. + `node` + :any:`NodeMixin` object. + It is up to the user to assemble these parts to a whole. + + Examples + -------- + + >>> from xarray import Dataset + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1})) + >>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3})) + >>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4})) + >>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6})) + >>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7})) + + # Simple one line: + + >>> for pre, _, node in RenderDataTree(root): + ... print(f"{pre}{node.name}") + ... + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + + # Multiline: + + >>> for pre, fill, node in RenderDataTree(root): + ... print(f"{pre}{node.name}") + ... for variable in node.variables: + ... print(f"{fill}{variable}") + ... + root + a + b + ├── sub0 + │ c + │ d + │ ├── sub0B + │ │ e + │ └── sub0A + │ f + │ g + └── sub1 + h + + :any:`by_attr` simplifies attribute rendering and supports multiline: + >>> print(RenderDataTree(root).by_attr()) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + + # `maxlevel` limits the depth of the tree: + + >>> print(RenderDataTree(root, maxlevel=2).by_attr("name")) + root + ├── sub0 + └── sub1 + """ + if not isinstance(style, AbstractStyle): + style = style() + self.node = node + self.style = style + self.childiter = childiter + self.maxlevel = maxlevel + + def __iter__(self) -> Iterator[Row]: + return self.__next(self.node, tuple()) + + def __next( + self, node: DataTree, continues: tuple[bool, ...], level: int = 0 + ) -> Iterator[Row]: + yield RenderDataTree.__item(node, continues, self.style) + children = node.children.values() + level += 1 + if children and (self.maxlevel is None or level < self.maxlevel): + children = self.childiter(children) + for child, is_last in _is_last(children): + yield from self.__next(child, continues + (not is_last,), level=level) + + @staticmethod + def __item( + node: DataTree, continues: tuple[bool, ...], style: AbstractStyle + ) -> Row: + if not continues: + return Row("", "", node) + else: + items = [style.vertical if cont else style.empty for cont in continues] + indent = "".join(items[:-1]) + branch = style.cont if continues[-1] else style.end + pre = indent + branch + fill = "".join(items) + return Row(pre, fill, node) + + def __str__(self) -> str: + return str(self.node) + + def __repr__(self) -> str: + classname = self.__class__.__name__ + args = [ + repr(self.node), + f"style={repr(self.style)}", + f"childiter={repr(self.childiter)}", + ] + return f"{classname}({', '.join(args)})" + + def by_attr(self, attrname: str = "name") -> str: + """ + Return rendered tree with node attribute `attrname`. + + Examples + -------- + + >>> from xarray import Dataset + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root") + >>> s0 = DataTree(name="sub0", parent=root) + >>> s0b = DataTree( + ... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109}) + ... ) + >>> s0a = DataTree(name="sub0A", parent=s0) + >>> s1 = DataTree(name="sub1", parent=root) + >>> s1a = DataTree(name="sub1A", parent=s1) + >>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8})) + >>> s1c = DataTree(name="sub1C", parent=s1) + >>> s1ca = DataTree(name="sub1Ca", parent=s1c) + >>> print(RenderDataTree(root).by_attr("name")) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + ├── sub1A + ├── sub1B + └── sub1C + └── sub1Ca + """ + + def get() -> Iterator[str]: + for pre, fill, node in self: + attr = ( + attrname(node) + if callable(attrname) + else getattr(node, attrname, "") + ) + if isinstance(attr, (list, tuple)): + lines = attr + else: + lines = str(attr).split("\n") + yield f"{pre}{lines[0]}" + for line in lines[1:]: + yield f"{fill}{line}" + + return "\n".join(get()) + + +def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]: + iter_ = iter(iterable) + try: + nextitem = next(iter_) + except StopIteration: + pass + else: + item = nextitem + while True: + try: + nextitem = next(iter_) + yield item, False + except StopIteration: + yield nextitem, True + break + item = nextitem diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index efe00718a79..9ebbd564f4f 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -4,6 +4,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree class AccessorRegistrationWarning(Warning): @@ -121,3 +122,20 @@ def register_dataset_accessor(name): register_dataarray_accessor """ return _register_accessor(name, Dataset) + + +def register_datatree_accessor(name): + """Register a custom accessor on DataTree objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + See Also + -------- + xarray.register_dataarray_accessor + xarray.register_dataset_accessor + """ + return _register_accessor(name, DataTree) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 3eed7d02a2e..ad65a44d7d5 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -11,20 +11,24 @@ from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr +from textwrap import dedent from typing import TYPE_CHECKING import numpy as np import pandas as pd from pandas.errors import OutOfBoundsDatetime +from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_equiv, astype from xarray.core.indexing import MemoryCachedArray +from xarray.core.iterators import LevelOrderIter from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.utils import is_duck_array from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy if TYPE_CHECKING: from xarray.core.coordinates import AbstractCoordinates + from xarray.core.datatree import DataTree UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") @@ -926,6 +930,37 @@ def diff_array_repr(a, b, compat): return "\n".join(summary) +def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: + """ + Return a summary of why two trees are not isomorphic. + If they are isomorphic return an empty string. + """ + + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. + # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree + # (which it is so long as children are stored in a tuple or list rather than in a set). + for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): + path_a, path_b = node_a.path, node_b.path + + if require_names_equal and node_a.name != node_b.name: + diff = dedent( + f"""\ + Node '{path_a}' in the left object has name '{node_a.name}' + Node '{path_b}' in the right object has name '{node_b.name}'""" + ) + return diff + + if len(node_a.children) != len(node_b.children): + diff = dedent( + f"""\ + Number of children on node '{path_a}' of the left object: {len(node_a.children)} + Number of children on node '{path_b}' of the right object: {len(node_b.children)}""" + ) + return diff + + return "" + + def diff_dataset_repr(a, b, compat): summary = [ f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" @@ -945,6 +980,86 @@ def diff_dataset_repr(a, b, compat): return "\n".join(summary) +def diff_nodewise_summary(a: DataTree, b: DataTree, compat): + """Iterates over all corresponding nodes, recording differences between data at each location.""" + + compat_str = _compat_to_str(compat) + + summary = [] + for node_a, node_b in zip(a.subtree, b.subtree): + a_ds, b_ds = node_a.ds, node_b.ds + + if not a_ds._all_compat(b_ds, compat): + dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str) + data_diff = "\n".join(dataset_diff.split("\n", 1)[1:]) + + nodediff = ( + f"\nData in nodes at position '{node_a.path}' do not match:" + f"{data_diff}" + ) + summary.append(nodediff) + + return "\n".join(summary) + + +def diff_datatree_repr(a: DataTree, b: DataTree, compat): + summary = [ + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" + ] + + strict_names = True if compat in ["equals", "identical"] else False + treestructure_diff = diff_treestructure(a, b, strict_names) + + # If the trees structures are different there is no point comparing each node + # TODO we could show any differences in nodes up to the first place that structure differs? + if treestructure_diff or compat == "isomorphic": + summary.append("\n" + treestructure_diff) + else: + nodewise_diff = diff_nodewise_summary(a, b, compat) + summary.append("\n" + nodewise_diff) + + return "\n".join(summary) + + +def _single_node_repr(node: DataTree) -> str: + """Information about this node, not including its relationships to other nodes.""" + node_info = f"DataTree('{node.name}')" + + if node.has_data or node.has_attrs: + ds_info = "\n" + repr(node.ds) + else: + ds_info = "" + return node_info + ds_info + + +def datatree_repr(dt: DataTree): + """A printable representation of the structure of this entire tree.""" + renderer = RenderDataTree(dt) + + lines = [] + for pre, fill, node in renderer: + node_repr = _single_node_repr(node) + + node_line = f"{pre}{node_repr.splitlines()[0]}" + lines.append(node_line) + + if node.has_data or node.has_attrs: + ds_repr = node_repr.splitlines()[2:] + for line in ds_repr: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + + # Tack on info about whether or not root node has a parent at the start + first_line = lines[0] + parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" + first_line_with_parent = first_line[:-1] + f", parent={parent})" + lines[0] = first_line_with_parent + + return "\n".join(lines) + + def shorten_list_repr(items: Sequence, max_items: int) -> str: if len(items) <= max_items: return repr(items) diff --git a/xarray/datatree_/datatree/extensions.py b/xarray/datatree_/datatree/extensions.py deleted file mode 100644 index bf888fc4484..00000000000 --- a/xarray/datatree_/datatree/extensions.py +++ /dev/null @@ -1,20 +0,0 @@ -from xarray.core.extensions import _register_accessor - -from xarray.core.datatree import DataTree - - -def register_datatree_accessor(name): - """Register a custom accessor on DataTree objects. - - Parameters - ---------- - name : str - Name under which the accessor should be registered. A warning is issued - if this name conflicts with a preexisting attribute. - - See Also - -------- - xarray.register_dataarray_accessor - xarray.register_dataset_accessor - """ - return _register_accessor(name, DataTree) diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py deleted file mode 100644 index fdd23933ae6..00000000000 --- a/xarray/datatree_/datatree/formatting.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import TYPE_CHECKING - -from xarray.core.formatting import _compat_to_str, diff_dataset_repr - -from xarray.core.datatree_mapping import diff_treestructure -from xarray.datatree_.datatree.render import RenderTree - -if TYPE_CHECKING: - from xarray.core.datatree import DataTree - - -def diff_nodewise_summary(a, b, compat): - """Iterates over all corresponding nodes, recording differences between data at each location.""" - - compat_str = _compat_to_str(compat) - - summary = [] - for node_a, node_b in zip(a.subtree, b.subtree): - a_ds, b_ds = node_a.ds, node_b.ds - - if not a_ds._all_compat(b_ds, compat): - dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str) - data_diff = "\n".join(dataset_diff.split("\n", 1)[1:]) - - nodediff = ( - f"\nData in nodes at position '{node_a.path}' do not match:" - f"{data_diff}" - ) - summary.append(nodediff) - - return "\n".join(summary) - - -def diff_tree_repr(a, b, compat): - summary = [ - f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" - ] - - # TODO check root parents? - - strict_names = True if compat in ["equals", "identical"] else False - treestructure_diff = diff_treestructure(a, b, strict_names) - - # If the trees structures are different there is no point comparing each node - # TODO we could show any differences in nodes up to the first place that structure differs? - if treestructure_diff or compat == "isomorphic": - summary.append("\n" + treestructure_diff) - else: - nodewise_diff = diff_nodewise_summary(a, b, compat) - summary.append("\n" + nodewise_diff) - - return "\n".join(summary) - - -def datatree_repr(dt): - """A printable representation of the structure of this entire tree.""" - renderer = RenderTree(dt) - - lines = [] - for pre, fill, node in renderer: - node_repr = _single_node_repr(node) - - node_line = f"{pre}{node_repr.splitlines()[0]}" - lines.append(node_line) - - if node.has_data or node.has_attrs: - ds_repr = node_repr.splitlines()[2:] - for line in ds_repr: - if len(node.children) > 0: - lines.append(f"{fill}{renderer.style.vertical}{line}") - else: - lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") - - # Tack on info about whether or not root node has a parent at the start - first_line = lines[0] - parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" - first_line_with_parent = first_line[:-1] + f", parent={parent})" - lines[0] = first_line_with_parent - - return "\n".join(lines) - - -def _single_node_repr(node: "DataTree") -> str: - """Information about this node, not including its relationships to other nodes.""" - node_info = f"DataTree('{node.name}')" - - if node.has_data or node.has_attrs: - ds_info = "\n" + repr(node.ds) - else: - ds_info = "" - return node_info + ds_info diff --git a/xarray/datatree_/datatree/io.py b/xarray/datatree_/datatree/io.py index 48335ddca70..6c8e9617da3 100644 --- a/xarray/datatree_/datatree/io.py +++ b/xarray/datatree_/datatree/io.py @@ -3,14 +3,14 @@ def _get_nc_dataset_class(engine): if engine == "netcdf4": - from netCDF4 import Dataset # type: ignore + from netCDF4 import Dataset elif engine == "h5netcdf": - from h5netcdf.legacyapi import Dataset # type: ignore + from h5netcdf.legacyapi import Dataset elif engine is None: try: from netCDF4 import Dataset except ImportError: - from h5netcdf.legacyapi import Dataset # type: ignore + from h5netcdf.legacyapi import Dataset else: raise ValueError(f"unsupported engine: {engine}") return Dataset @@ -78,7 +78,7 @@ def _datatree_to_netcdf( def _create_empty_zarr_group(store, group, mode): - import zarr # type: ignore + import zarr root = zarr.open_group(store, mode=mode) root.create_group(group, overwrite=True) @@ -92,7 +92,7 @@ def _datatree_to_zarr( consolidated: bool = True, **kwargs, ): - from zarr.convenience import consolidate_metadata # type: ignore + from zarr.convenience import consolidate_metadata if kwargs.get("group", None) is not None: raise NotImplementedError( diff --git a/xarray/datatree_/datatree/ops.py b/xarray/datatree_/datatree/ops.py index 83b9d1b275a..1ca8a7c1e01 100644 --- a/xarray/datatree_/datatree/ops.py +++ b/xarray/datatree_/datatree/ops.py @@ -1,6 +1,6 @@ import textwrap -from xarray import Dataset +from xarray.core.dataset import Dataset from xarray.core.datatree_mapping import map_over_subtree diff --git a/xarray/datatree_/datatree/render.py b/xarray/datatree_/datatree/render.py deleted file mode 100644 index e6af9c85ee8..00000000000 --- a/xarray/datatree_/datatree/render.py +++ /dev/null @@ -1,271 +0,0 @@ -""" -String Tree Rendering. Copied from anytree. -""" - -import collections -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from xarray.core.datatree import DataTree - -Row = collections.namedtuple("Row", ("pre", "fill", "node")) - - -class AbstractStyle(object): - def __init__(self, vertical, cont, end): - """ - Tree Render Style. - Args: - vertical: Sign for vertical line. - cont: Chars for a continued branch. - end: Chars for the last branch. - """ - super(AbstractStyle, self).__init__() - self.vertical = vertical - self.cont = cont - self.end = end - assert ( - len(cont) == len(vertical) == len(end) - ), f"'{vertical}', '{cont}' and '{end}' need to have equal length" - - @property - def empty(self): - """Empty string as placeholder.""" - return " " * len(self.end) - - def __repr__(self): - return f"{self.__class__.__name__}()" - - -class ContStyle(AbstractStyle): - def __init__(self): - """ - Continued style, without gaps. - - >>> from anytree import Node, RenderTree - >>> root = Node("root") - >>> s0 = Node("sub0", parent=root) - >>> s0b = Node("sub0B", parent=s0) - >>> s0a = Node("sub0A", parent=s0) - >>> s1 = Node("sub1", parent=root) - >>> print(RenderTree(root, style=ContStyle())) - - Node('/root') - ├── Node('/root/sub0') - │ ├── Node('/root/sub0/sub0B') - │ └── Node('/root/sub0/sub0A') - └── Node('/root/sub1') - """ - super(ContStyle, self).__init__( - "\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 " - ) - - -class RenderTree(object): - def __init__( - self, node: "DataTree", style=ContStyle(), childiter=list, maxlevel=None - ): - """ - Render tree starting at `node`. - Keyword Args: - style (AbstractStyle): Render Style. - childiter: Child iterator. - maxlevel: Limit rendering to this depth. - :any:`RenderTree` is an iterator, returning a tuple with 3 items: - `pre` - tree prefix. - `fill` - filling for multiline entries. - `node` - :any:`NodeMixin` object. - It is up to the user to assemble these parts to a whole. - >>> from anytree import Node, RenderTree - >>> root = Node("root", lines=["c0fe", "c0de"]) - >>> s0 = Node("sub0", parent=root, lines=["ha", "ba"]) - >>> s0b = Node("sub0B", parent=s0, lines=["1", "2", "3"]) - >>> s0a = Node("sub0A", parent=s0, lines=["a", "b"]) - >>> s1 = Node("sub1", parent=root, lines=["Z"]) - Simple one line: - >>> for pre, _, node in RenderTree(root): - ... print("%s%s" % (pre, node.name)) - ... - root - ├── sub0 - │ ├── sub0B - │ └── sub0A - └── sub1 - Multiline: - >>> for pre, fill, node in RenderTree(root): - ... print("%s%s" % (pre, node.lines[0])) - ... for line in node.lines[1:]: - ... print("%s%s" % (fill, line)) - ... - c0fe - c0de - ├── ha - │ ba - │ ├── 1 - │ │ 2 - │ │ 3 - │ └── a - │ b - └── Z - `maxlevel` limits the depth of the tree: - >>> print(RenderTree(root, maxlevel=2)) - Node('/root', lines=['c0fe', 'c0de']) - ├── Node('/root/sub0', lines=['ha', 'ba']) - └── Node('/root/sub1', lines=['Z']) - The `childiter` is responsible for iterating over child nodes at the - same level. An reversed order can be achived by using `reversed`. - >>> for row in RenderTree(root, childiter=reversed): - ... print("%s%s" % (row.pre, row.node.name)) - ... - root - ├── sub1 - └── sub0 - ├── sub0A - └── sub0B - Or writing your own sort function: - >>> def mysort(items): - ... return sorted(items, key=lambda item: item.name) - ... - >>> for row in RenderTree(root, childiter=mysort): - ... print("%s%s" % (row.pre, row.node.name)) - ... - root - ├── sub0 - │ ├── sub0A - │ └── sub0B - └── sub1 - :any:`by_attr` simplifies attribute rendering and supports multiline: - >>> print(RenderTree(root).by_attr()) - root - ├── sub0 - │ ├── sub0B - │ └── sub0A - └── sub1 - >>> print(RenderTree(root).by_attr("lines")) - c0fe - c0de - ├── ha - │ ba - │ ├── 1 - │ │ 2 - │ │ 3 - │ └── a - │ b - └── Z - And can be a function: - >>> print(RenderTree(root).by_attr(lambda n: " ".join(n.lines))) - c0fe c0de - ├── ha ba - │ ├── 1 2 3 - │ └── a b - └── Z - """ - if not isinstance(style, AbstractStyle): - style = style() - self.node = node - self.style = style - self.childiter = childiter - self.maxlevel = maxlevel - - def __iter__(self): - return self.__next(self.node, tuple()) - - def __next(self, node, continues, level=0): - yield RenderTree.__item(node, continues, self.style) - children = node.children.values() - level += 1 - if children and (self.maxlevel is None or level < self.maxlevel): - children = self.childiter(children) - for child, is_last in _is_last(children): - for grandchild in self.__next( - child, continues + (not is_last,), level=level - ): - yield grandchild - - @staticmethod - def __item(node, continues, style): - if not continues: - return Row("", "", node) - else: - items = [style.vertical if cont else style.empty for cont in continues] - indent = "".join(items[:-1]) - branch = style.cont if continues[-1] else style.end - pre = indent + branch - fill = "".join(items) - return Row(pre, fill, node) - - def __str__(self): - lines = ["%s%r" % (pre, node) for pre, _, node in self] - return "\n".join(lines) - - def __repr__(self): - classname = self.__class__.__name__ - args = [ - repr(self.node), - "style=%s" % repr(self.style), - "childiter=%s" % repr(self.childiter), - ] - return "%s(%s)" % (classname, ", ".join(args)) - - def by_attr(self, attrname="name"): - """ - Return rendered tree with node attribute `attrname`. - >>> from anytree import AnyNode, RenderTree - >>> root = AnyNode(id="root") - >>> s0 = AnyNode(id="sub0", parent=root) - >>> s0b = AnyNode(id="sub0B", parent=s0, foo=4, bar=109) - >>> s0a = AnyNode(id="sub0A", parent=s0) - >>> s1 = AnyNode(id="sub1", parent=root) - >>> s1a = AnyNode(id="sub1A", parent=s1) - >>> s1b = AnyNode(id="sub1B", parent=s1, bar=8) - >>> s1c = AnyNode(id="sub1C", parent=s1) - >>> s1ca = AnyNode(id="sub1Ca", parent=s1c) - >>> print(RenderTree(root).by_attr("id")) - root - ├── sub0 - │ ├── sub0B - │ └── sub0A - └── sub1 - ├── sub1A - ├── sub1B - └── sub1C - └── sub1Ca - """ - - def get(): - for pre, fill, node in self: - attr = ( - attrname(node) - if callable(attrname) - else getattr(node, attrname, "") - ) - if isinstance(attr, (list, tuple)): - lines = attr - else: - lines = str(attr).split("\n") - yield "%s%s" % (pre, lines[0]) - for line in lines[1:]: - yield "%s%s" % (fill, line) - - return "\n".join(get()) - - -def _is_last(iterable): - iter_ = iter(iterable) - try: - nextitem = next(iter_) - except StopIteration: - pass - else: - item = nextitem - while True: - try: - nextitem = next(iter_) - yield item, False - except StopIteration: - yield nextitem, True - break - item = nextitem diff --git a/xarray/datatree_/datatree/testing.py b/xarray/datatree_/datatree/testing.py deleted file mode 100644 index bf54116725a..00000000000 --- a/xarray/datatree_/datatree/testing.py +++ /dev/null @@ -1,120 +0,0 @@ -from xarray.testing.assertions import ensure_warnings - -from xarray.core.datatree import DataTree -from .formatting import diff_tree_repr - - -@ensure_warnings -def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): - """ - Two DataTrees are considered isomorphic if every node has the same number of children. - - Nothing about the data in each node is checked. - - Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, - such as tree1 + tree2. - - By default this function does not check any part of the tree above the given node. - Therefore this function can be used as default to check that two subtrees are isomorphic. - - Parameters - ---------- - a : DataTree - The first object to compare. - b : DataTree - The second object to compare. - from_root : bool, optional, default is False - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - - See Also - -------- - DataTree.isomorphic - assert_equals - assert_identical - """ - __tracebackhide__ = True - assert isinstance(a, type(b)) - - if isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.isomorphic(b, from_root=from_root), diff_tree_repr(a, b, "isomorphic") - else: - raise TypeError(f"{type(a)} not of type DataTree") - - -@ensure_warnings -def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): - """ - Two DataTrees are equal if they have isomorphic node structures, with matching node names, - and if they have matching variables and coordinates, all of which are equal. - - By default this method will check the whole tree above the given node. - - Parameters - ---------- - a : DataTree - The first object to compare. - b : DataTree - The second object to compare. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - - See Also - -------- - DataTree.equals - assert_isomorphic - assert_identical - """ - __tracebackhide__ = True - assert isinstance(a, type(b)) - - if isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.equals(b, from_root=from_root), diff_tree_repr(a, b, "equals") - else: - raise TypeError(f"{type(a)} not of type DataTree") - - -@ensure_warnings -def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): - """ - Like assert_equals, but will also check all dataset attributes and the attributes on - all variables and coordinates. - - By default this method will check the whole tree above the given node. - - Parameters - ---------- - a : xarray.DataTree - The first object to compare. - b : xarray.DataTree - The second object to compare. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - - See Also - -------- - DataTree.identical - assert_isomorphic - assert_equal - """ - - __tracebackhide__ = True - assert isinstance(a, type(b)) - if isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.identical(b, from_root=from_root), diff_tree_repr(a, b, "identical") - else: - raise TypeError(f"{type(a)} not of type DataTree") diff --git a/xarray/datatree_/datatree/tests/__init__.py b/xarray/datatree_/datatree/tests/__init__.py deleted file mode 100644 index 64961158b13..00000000000 --- a/xarray/datatree_/datatree/tests/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -import importlib - -import pytest -from packaging import version - - -def _importorskip(modname, minversion=None): - try: - mod = importlib.import_module(modname) - has = True - if minversion is not None: - if LooseVersion(mod.__version__) < LooseVersion(minversion): - raise ImportError("Minimum version not satisfied") - except ImportError: - has = False - func = pytest.mark.skipif(not has, reason=f"requires {modname}") - return has, func - - -def LooseVersion(vstring): - # Our development version is something like '0.10.9+aac7bfc' - # This function just ignores the git commit id. - vstring = vstring.split("+")[0] - return version.parse(vstring) - - -has_zarr, requires_zarr = _importorskip("zarr") -has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") -has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py deleted file mode 100644 index 53a9a72239d..00000000000 --- a/xarray/datatree_/datatree/tests/conftest.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -import xarray as xr - -from xarray.core.datatree import DataTree - - -@pytest.fixture(scope="module") -def create_test_datatree(): - """ - Create a test datatree with this structure: - - - |-- set1 - | |-- - | | Dimensions: () - | | Data variables: - | | a int64 0 - | | b int64 1 - | |-- set1 - | |-- set2 - |-- set2 - | |-- - | | Dimensions: (x: 2) - | | Data variables: - | | a (x) int64 2, 3 - | | b (x) int64 0.1, 0.2 - | |-- set1 - |-- set3 - |-- - | Dimensions: (x: 2, y: 3) - | Data variables: - | a (y) int64 6, 7, 8 - | set0 (x) int64 9, 10 - - The structure has deliberately repeated names of tags, variables, and - dimensions in order to better check for bugs caused by name conflicts. - """ - - def _create_test_datatree(modify=lambda ds: ds): - set1_data = modify(xr.Dataset({"a": 0, "b": 1})) - set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) - root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) - - # Avoid using __init__ so we can independently test it - root = DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2 = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) - - return root - - return _create_test_datatree - - -@pytest.fixture(scope="module") -def simple_datatree(create_test_datatree): - """ - Invoke create_test_datatree fixture (callback). - - Returns a DataTree. - """ - return create_test_datatree() diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py deleted file mode 100644 index 4ca532ebba4..00000000000 --- a/xarray/datatree_/datatree/tests/test_dataset_api.py +++ /dev/null @@ -1,98 +0,0 @@ -import numpy as np -import xarray as xr - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.testing import assert_equal - - -class TestDSMethodInheritance: - def test_dataset_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.isel(x=1)) - DataTree(name="results", parent=expected, data=ds.isel(x=1)) - - result = dt.isel(x=1) - assert_equal(result, expected) - - def test_reduce_method(self): - ds = xr.Dataset({"a": ("x", [False, True, False])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.any()) - DataTree(name="results", parent=expected, data=ds.any()) - - result = dt.any() - assert_equal(result, expected) - - def test_nan_reduce_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.mean()) - DataTree(name="results", parent=expected, data=ds.mean()) - - result = dt.mean() - assert_equal(result, expected) - - def test_cum_method(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) - - expected = DataTree(data=ds.cumsum()) - DataTree(name="results", parent=expected, data=ds.cumsum()) - - result = dt.cumsum() - assert_equal(result, expected) - - -class TestOps: - def test_binary_op_on_int(self): - ds1 = xr.Dataset({"a": [5], "b": [3]}) - ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - - expected = DataTree(data=ds1 * 5) - DataTree(name="subnode", data=ds2 * 5, parent=expected) - - result = dt * 5 - assert_equal(result, expected) - - def test_binary_op_on_dataset(self): - ds1 = xr.Dataset({"a": [5], "b": [3]}) - ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) - - expected = DataTree(data=ds1 * other_ds) - DataTree(name="subnode", data=ds2 * other_ds, parent=expected) - - result = dt * other_ds - assert_equal(result, expected) - - def test_binary_op_on_datatree(self): - ds1 = xr.Dataset({"a": [5], "b": [3]}) - ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - - expected = DataTree(data=ds1 * ds1) - DataTree(name="subnode", data=ds2 * ds2, parent=expected) - - result = dt * dt - assert_equal(result, expected) - - -class TestUFuncs: - def test_tree(self, create_test_datatree): - dt = create_test_datatree() - expected = create_test_datatree(modify=lambda ds: np.sin(ds)) - result_tree = np.sin(dt) - assert_equal(result_tree, expected) diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py deleted file mode 100644 index fb2e82453ec..00000000000 --- a/xarray/datatree_/datatree/tests/test_extensions.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.extensions import register_datatree_accessor - - -class TestAccessor: - def test_register(self) -> None: - @register_datatree_accessor("demo") - class DemoAccessor: - """Demo accessor.""" - - def __init__(self, xarray_obj): - self._obj = xarray_obj - - @property - def foo(self): - return "bar" - - dt: DataTree = DataTree() - assert dt.demo.foo == "bar" # type: ignore - - # accessor is cached - assert dt.demo is dt.demo # type: ignore - - # check descriptor - assert dt.demo.__doc__ == "Demo accessor." # type: ignore - # TODO: typing doesn't seem to work with accessors - assert DataTree.demo.__doc__ == "Demo accessor." # type: ignore - assert isinstance(dt.demo, DemoAccessor) # type: ignore - assert DataTree.demo is DemoAccessor # type: ignore - - with pytest.warns(Warning, match="overriding a preexisting attribute"): - - @register_datatree_accessor("demo") - class Foo: - pass - - # ensure we can remove it - del DataTree.demo # type: ignore - assert not hasattr(DataTree, "demo") diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py deleted file mode 100644 index 77f8346ae72..00000000000 --- a/xarray/datatree_/datatree/tests/test_formatting.py +++ /dev/null @@ -1,120 +0,0 @@ -from textwrap import dedent - -from xarray import Dataset - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.formatting import diff_tree_repr - - -class TestRepr: - def test_print_empty_node(self): - dt = DataTree(name="root") - printout = dt.__str__() - assert printout == "DataTree('root', parent=None)" - - def test_print_empty_node_with_attrs(self): - dat = Dataset(attrs={"note": "has attrs"}) - dt = DataTree(name="root", data=dat) - printout = dt.__str__() - assert printout == dedent( - """\ - DataTree('root', parent=None) - Dimensions: () - Data variables: - *empty* - Attributes: - note: has attrs""" - ) - - def test_print_node_with_data(self): - dat = Dataset({"a": [0, 2]}) - dt = DataTree(name="root", data=dat) - printout = dt.__str__() - expected = [ - "DataTree('root', parent=None)", - "Dimensions", - "Coordinates", - "a", - "Data variables", - "*empty*", - ] - for expected_line, printed_line in zip(expected, printout.splitlines()): - assert expected_line in printed_line - - def test_nested_node(self): - dat = Dataset({"a": [0, 2]}) - root = DataTree(name="root") - DataTree(name="results", data=dat, parent=root) - printout = root.__str__() - assert printout.splitlines()[2].startswith(" ") - - def test_print_datatree(self, simple_datatree): - dt = simple_datatree - print(dt) - - # TODO work out how to test something complex like this - - def test_repr_of_node_with_data(self): - dat = Dataset({"a": [0, 2]}) - dt = DataTree(name="root", data=dat) - assert "Coordinates" in repr(dt) - - -class TestDiffFormatting: - def test_diff_structure(self): - dt_1 = DataTree.from_dict({"a": None, "a/b": None, "a/c": None}) - dt_2 = DataTree.from_dict({"d": None, "d/e": None}) - - expected = dedent( - """\ - Left and right DataTree objects are not isomorphic - - Number of children on node '/a' of the left object: 2 - Number of children on node '/d' of the right object: 1""" - ) - actual = diff_tree_repr(dt_1, dt_2, "isomorphic") - assert actual == expected - - def test_diff_node_names(self): - dt_1 = DataTree.from_dict({"a": None}) - dt_2 = DataTree.from_dict({"b": None}) - - expected = dedent( - """\ - Left and right DataTree objects are not identical - - Node '/a' in the left object has name 'a' - Node '/b' in the right object has name 'b'""" - ) - actual = diff_tree_repr(dt_1, dt_2, "identical") - assert actual == expected - - def test_diff_node_data(self): - import numpy as np - - # casting to int64 explicitly ensures that int64s are created on all architectures - ds1 = Dataset({"u": np.int64(0), "v": np.int64(1)}) - ds3 = Dataset({"w": np.int64(5)}) - dt_1 = DataTree.from_dict({"a": ds1, "a/b": ds3}) - ds2 = Dataset({"u": np.int64(0)}) - ds4 = Dataset({"w": np.int64(6)}) - dt_2 = DataTree.from_dict({"a": ds2, "a/b": ds4}) - - expected = dedent( - """\ - Left and right DataTree objects are not equal - - - Data in nodes at position '/a' do not match: - - Data variables only on the left object: - v int64 8B 1 - - Data in nodes at position '/a/b' do not match: - - Differing data variables: - L w int64 8B 5 - R w int64 8B 6""" - ) - actual = diff_tree_repr(dt_1, dt_2, "equals") - assert actual == expected diff --git a/xarray/datatree_/docs/source/conf.py b/xarray/datatree_/docs/source/conf.py index 8a9224def5b..430dbb5bf6d 100644 --- a/xarray/datatree_/docs/source/conf.py +++ b/xarray/datatree_/docs/source/conf.py @@ -17,9 +17,9 @@ import os import sys -import sphinx_autosummary_accessors +import sphinx_autosummary_accessors # type: ignore -import datatree +import datatree # type: ignore # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -286,7 +286,7 @@ # -- Options for LaTeX output -------------------------------------------------- -latex_elements = { +latex_elements: dict = { # The paper size ('letterpaper' or 'a4paper'). # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). diff --git a/xarray/testing/__init__.py b/xarray/testing/__init__.py index ab2f8ba4357..316b0ea5252 100644 --- a/xarray/testing/__init__.py +++ b/xarray/testing/__init__.py @@ -1,3 +1,4 @@ +# TODO: Add assert_isomorphic when making DataTree API public from xarray.testing.assertions import ( # noqa: F401 _assert_dataarray_invariants, _assert_dataset_invariants, diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 6418eb79b8b..018874c169e 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -3,7 +3,7 @@ import functools import warnings from collections.abc import Hashable -from typing import Union +from typing import Union, overload import numpy as np import pandas as pd @@ -12,6 +12,8 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree +from xarray.core.formatting import diff_datatree_repr from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable @@ -50,7 +52,59 @@ def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=Tru @ensure_warnings -def assert_equal(a, b): +def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): + """ + Two DataTrees are considered isomorphic if every node has the same number of children. + + Nothing about the data or attrs in each node is checked. + + Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, + such as tree1 + tree2. + + By default this function does not check any part of the tree above the given node. + Therefore this function can be used as default to check that two subtrees are isomorphic. + + Parameters + ---------- + a : DataTree + The first object to compare. + b : DataTree + The second object to compare. + from_root : bool, optional, default is False + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + See Also + -------- + DataTree.isomorphic + assert_equal + assert_identical + """ + __tracebackhide__ = True + assert isinstance(a, type(b)) + + if isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.isomorphic(b, from_root=from_root), diff_datatree_repr( + a, b, "isomorphic" + ) + else: + raise TypeError(f"{type(a)} not of type DataTree") + + +@overload +def assert_equal(a, b): ... + + +@overload +def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... + + +@ensure_warnings +def assert_equal(a, b, from_root=True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -59,12 +113,20 @@ def assert_equal(a, b): (except for Dataset objects for which the variable names must match). Arrays with NaN in the same location are considered equal. + For DataTree objects, assert_equal is mapped over all Datasets on each node, + with the DataTrees being equal if both are isomorphic and the corresponding + Datasets at each node are themselves equal. + Parameters ---------- - a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates - The first object to compare. - b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates - The second object to compare. + a : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates + or xarray.core.datatree.DataTree. The first object to compare. + b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates + or xarray.core.datatree.DataTree. The second object to compare. + from_root : bool, optional, default is True + Only used when comparing DataTree objects. Indicates whether or not to + first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. See Also -------- @@ -81,23 +143,45 @@ def assert_equal(a, b): assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") elif isinstance(a, Coordinates): assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") + elif isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") +@overload +def assert_identical(a, b): ... + + +@overload +def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ... + + @ensure_warnings -def assert_identical(a, b): +def assert_identical(a, b, from_root=True): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. Raises an AssertionError if two objects are not identical. + For DataTree objects, assert_identical is mapped over all Datasets on each + node, with the DataTrees being identical if both are isomorphic and the + corresponding Datasets at each node are themselves identical. + Parameters ---------- a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. + from_root : bool, optional, default is True + Only used when comparing DataTree objects. Indicates whether or not to + first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. See Also -------- @@ -116,6 +200,14 @@ def assert_identical(a, b): assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") elif isinstance(a, Coordinates): assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") + elif isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.identical(b, from_root=from_root), diff_datatree_repr( + a, b, "identical" + ) else: raise TypeError(f"{type(a)} not supported by assertion comparison") diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 281bd92d05c..26232471aaf 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -99,7 +99,6 @@ def _importorskip( ) has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") -has_pynio, requires_pynio = _importorskip("Nio") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") with warnings.catch_warnings(): @@ -147,6 +146,7 @@ def _importorskip( requires_pandas_version_two = pytest.mark.skipif( not has_pandas_version_two, reason="requires pandas 2.0.0" ) +has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0") has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0") has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip( diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index be9b3ef0422..bfa26025fd8 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -75,7 +75,6 @@ requires_netCDF4, requires_netCDF4_1_6_2_or_above, requires_pydap, - requires_pynio, requires_scipy, requires_scipy_or_netCDF4, requires_zarr, @@ -3769,7 +3768,7 @@ def test_get_variable_list_empty_driver_kwds(self) -> None: assert "Temperature" in list(actual) -@pytest.fixture(params=["scipy", "netcdf4", "h5netcdf", "pynio", "zarr"]) +@pytest.fixture(params=["scipy", "netcdf4", "h5netcdf", "zarr"]) def readengine(request): return request.param @@ -3818,8 +3817,6 @@ def tmp_store(request, tmp_path): def skip_if_not_engine(engine): if engine == "netcdf4": pytest.importorskip("netCDF4") - elif engine == "pynio": - pytest.importorskip("Nio") else: pytest.importorskip(engine) @@ -3827,25 +3824,22 @@ def skip_if_not_engine(engine): @requires_dask @pytest.mark.filterwarnings("ignore:use make_scale(name) instead") @pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this") +@pytest.mark.skipif(ON_WINDOWS, reason="Skipping on Windows") def test_open_mfdataset_manyfiles( readengine, nfiles, parallel, chunks, file_cache_maxsize ): # skip certain combinations skip_if_not_engine(readengine) - if ON_WINDOWS: - pytest.skip("Skipping on Windows") - randdata = np.random.randn(nfiles) original = Dataset({"foo": ("x", randdata)}) # test standard open_mfdataset approach with too many files with create_tmp_files(nfiles) as tmpfiles: - writeengine = readengine if readengine != "pynio" else "netcdf4" # split into multiple sets of temp files for ii in original.x.values: subds = original.isel(x=slice(ii, ii + 1)) - if writeengine != "zarr": - subds.to_netcdf(tmpfiles[ii], engine=writeengine) + if readengine != "zarr": + subds.to_netcdf(tmpfiles[ii], engine=readengine) else: # if writeengine == "zarr": subds.to_zarr(store=tmpfiles[ii]) @@ -4734,39 +4728,6 @@ def test_session(self) -> None: ) -@requires_scipy -@requires_pynio -class TestPyNio(CFEncodedBase, NetCDF3Only): - def test_write_store(self) -> None: - # pynio is read-only for now - pass - - @contextlib.contextmanager - def open(self, path, **kwargs): - with open_dataset(path, engine="pynio", **kwargs) as ds: - yield ds - - def test_kwargs(self) -> None: - kwargs = {"format": "grib"} - path = os.path.join(os.path.dirname(__file__), "data", "example") - with backends.NioDataStore(path, **kwargs) as store: - assert store._manager._kwargs["format"] == "grib" - - def save(self, dataset, path, **kwargs): - return dataset.to_netcdf(path, engine="scipy", **kwargs) - - def test_weakrefs(self) -> None: - example = Dataset({"foo": ("x", np.arange(5.0))}) - expected = example.rename({"foo": "bar", "x": "y"}) - - with create_tmp_file() as tmp_file: - example.to_netcdf(tmp_file, engine="scipy") - on_disk = open_dataset(tmp_file, engine="pynio") - actual = on_disk.rename({"foo": "bar", "x": "y"}) - del on_disk # trigger garbage collection - assert_identical(actual, expected) - - class TestEncodingInvalid: def test_extract_nc4_variable_encoding(self) -> None: var = xr.Variable(("x",), [1, 2, 3], {}, {"foo": "bar"}) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 7bdb2b532d9..4e819eec0b5 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -5,7 +5,7 @@ import pytest from xarray.backends.api import open_datatree -from xarray.datatree_.datatree.testing import assert_equal +from xarray.testing import assert_equal from xarray.tests import ( requires_h5netcdf, requires_netCDF4, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index a948fafc815..40cf85484da 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -61,7 +61,6 @@ requires_cupy, requires_dask, requires_numexpr, - requires_pandas_version_two, requires_pint, requires_scipy, requires_sparse, @@ -3431,7 +3430,52 @@ def test_expand_dims_kwargs_python36plus(self) -> None: ) assert_identical(other_way_expected, other_way) - @requires_pandas_version_two + @pytest.mark.parametrize("create_index_flag", [True, False]) + def test_expand_dims_create_index_data_variable(self, create_index_flag): + # data variables should not gain an index ever + ds = Dataset({"x": 0}) + + if create_index_flag: + with pytest.warns(UserWarning, match="No index created"): + expanded = ds.expand_dims("x", create_index=create_index_flag) + else: + expanded = ds.expand_dims("x", create_index=create_index_flag) + + # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 + expected = Dataset({"x": ("x", [0])}).drop_indexes("x").reset_coords("x") + + assert_identical(expanded, expected, check_default_indexes=False) + assert expanded.indexes == {} + + def test_expand_dims_create_index_coordinate_variable(self): + # coordinate variables should gain an index only if create_index is True (the default) + ds = Dataset(coords={"x": 0}) + expanded = ds.expand_dims("x") + expected = Dataset({"x": ("x", [0])}) + assert_identical(expanded, expected) + + expanded_no_index = ds.expand_dims("x", create_index=False) + + # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 + expected = Dataset(coords={"x": ("x", [0])}).drop_indexes("x") + + assert_identical(expanded_no_index, expected, check_default_indexes=False) + assert expanded_no_index.indexes == {} + + def test_expand_dims_create_index_from_iterable(self): + ds = Dataset(coords={"x": 0}) + expanded = ds.expand_dims(x=[0, 1]) + expected = Dataset({"x": ("x", [0, 1])}) + assert_identical(expanded, expected) + + expanded_no_index = ds.expand_dims(x=[0, 1], create_index=False) + + # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 + expected = Dataset(coords={"x": ("x", [0, 1])}).drop_indexes("x") + + assert_identical(expanded, expected, check_default_indexes=False) + assert expanded_no_index.indexes == {} + def test_expand_dims_non_nanosecond_conversion(self) -> None: # Regression test for https://github.com/pydata/xarray/issues/7493#issuecomment-1953091000 with pytest.warns(UserWarning, match="non-nanosecond precision"): diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3de2fa62dc6..e667c8670c7 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -4,10 +4,9 @@ import pytest import xarray as xr -import xarray.datatree_.datatree.testing as dtt -import xarray.testing as xrt from xarray.core.datatree import DataTree from xarray.core.treenode import NotFoundInTreeError +from xarray.testing import assert_equal, assert_identical from xarray.tests import create_test_data, source_ndarray @@ -17,7 +16,7 @@ def test_empty(self): assert dt.name == "root" assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.to_dataset(), xr.Dataset()) + assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): dt: DataTree = DataTree() @@ -115,7 +114,7 @@ def test_create_with_data(self): dat = xr.Dataset({"a": 0}) john: DataTree = DataTree(name="john", data=dat) - xrt.assert_identical(john.to_dataset(), dat) + assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] @@ -125,7 +124,7 @@ def test_set_data(self): dat = xr.Dataset({"a": 0}) john.ds = dat # type: ignore[assignment] - xrt.assert_identical(john.to_dataset(), dat) + assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): john.ds = "junk" # type: ignore[assignment] @@ -185,14 +184,14 @@ def test_getitem_self(self): def test_getitem_single_data_variable(self): data = xr.Dataset({"temp": [0, 50]}) results: DataTree = DataTree(name="results", data=data) - xrt.assert_identical(results["temp"], data["temp"]) + assert_identical(results["temp"], data["temp"]) def test_getitem_single_data_variable_from_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1: DataTree = DataTree(name="folder1") results: DataTree = DataTree(name="results", parent=folder1) DataTree(name="highres", parent=results, data=data) - xrt.assert_identical(folder1["results/highres/temp"], data["temp"]) + assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): folder1: DataTree = DataTree(name="folder1") @@ -210,7 +209,7 @@ def test_getitem_nonexistent_variable(self): def test_getitem_multiple_data_variables(self): data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) results: DataTree = DataTree(name="results", data=data) - xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] + assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] @pytest.mark.xfail( reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" @@ -218,7 +217,7 @@ def test_getitem_multiple_data_variables(self): def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) results: DataTree = DataTree(name="results", data=data) - xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] + assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] class TestUpdate: @@ -231,14 +230,14 @@ def test_update(self): print(dt._children) print(dt["a"]) print(expected) - dtt.assert_equal(dt, expected) + assert_equal(dt, expected) def test_update_new_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) folder1: DataTree = DataTree(name="folder1") folder1.update({"results": da}) expected = da.rename("results") - xrt.assert_equal(folder1["results"], expected) + assert_equal(folder1["results"], expected) def test_update_doesnt_alter_child_name(self): dt: DataTree = DataTree() @@ -256,7 +255,7 @@ def test_update_overwrite(self): print(actual) print(expected) - dtt.assert_equal(actual, expected) + assert_equal(actual, expected) class TestCopy: @@ -267,7 +266,7 @@ def test_copy(self, create_test_datatree): node.attrs["Test"] = [1, 2, 3] for copied in [dt.copy(deep=False), copy(dt)]: - dtt.assert_identical(dt, copied) + assert_identical(dt, copied) for node, copied_node in zip(dt.root.subtree, copied.root.subtree): assert node.encoding == copied_node.encoding @@ -291,7 +290,7 @@ def test_copy_subtree(self): actual = dt["/level1/level2"].copy() expected = DataTree.from_dict({"/level3": xr.Dataset()}, name="level2") - dtt.assert_identical(actual, expected) + assert_identical(actual, expected) def test_deepcopy(self, create_test_datatree): dt = create_test_datatree() @@ -300,7 +299,7 @@ def test_deepcopy(self, create_test_datatree): node.attrs["Test"] = [1, 2, 3] for copied in [dt.copy(deep=True), deepcopy(dt)]: - dtt.assert_identical(dt, copied) + assert_identical(dt, copied) for node, copied_node in zip(dt.root.subtree, copied.root.subtree): assert node.encoding == copied_node.encoding @@ -331,7 +330,7 @@ def test_copy_with_data(self, create_test_datatree): expected = orig.copy() for k, v in new_data.items(): expected[k].data = v - dtt.assert_identical(expected, actual) + assert_identical(expected, actual) # TODO test parents and children? @@ -372,13 +371,13 @@ def test_setitem_new_empty_node(self): john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) - xrt.assert_identical(mary.to_dataset(), xr.Dataset()) + assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): john: DataTree = DataTree(name="john") mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) john["mary"] = DataTree() - xrt.assert_identical(mary.to_dataset(), xr.Dataset()) + assert_identical(mary.to_dataset(), xr.Dataset()) john.ds = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): @@ -389,45 +388,45 @@ def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) results: DataTree = DataTree(name="results") results["."] = data - xrt.assert_identical(results.to_dataset(), data) + assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1: DataTree = DataTree(name="folder1") folder1["results"] = data - xrt.assert_identical(folder1["results"].to_dataset(), data) + assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) folder1: DataTree = DataTree(name="folder1") folder1["results/highres"] = data - xrt.assert_identical(folder1["results/highres"].to_dataset(), data) + assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) folder1: DataTree = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") - xrt.assert_equal(folder1["results"], expected) + assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): data = xr.DataArray([0, 50]) folder1: DataTree = DataTree(name="folder1") folder1["results"] = data - xrt.assert_equal(folder1["results"], data) + assert_equal(folder1["results"], data) def test_setitem_variable(self): var = xr.Variable(data=[0, 50], dims="x") folder1: DataTree = DataTree(name="folder1") folder1["results"] = var - xrt.assert_equal(folder1["results"], xr.DataArray(var)) + assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self): folder1: DataTree = DataTree(name="folder1") folder1["results"] = 0 - xrt.assert_equal(folder1["results"], xr.DataArray(0)) + assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self): results: DataTree = DataTree(name="results") @@ -449,7 +448,7 @@ def test_setitem_dataarray_replace_existing_node(self): p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) - xrt.assert_identical(results.to_dataset(), expected) + assert_identical(results.to_dataset(), expected) class TestDictionaryInterface: ... @@ -462,16 +461,16 @@ def test_data_in_root(self): assert dt.name is None assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.to_dataset(), dat) + assert_identical(dt.to_dataset(), dat) def test_one_layer(self): dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) - xrt.assert_identical(dt.to_dataset(), xr.Dataset()) + assert_identical(dt.to_dataset(), xr.Dataset()) assert dt.name is None - xrt.assert_identical(dt["run1"].to_dataset(), dat1) + assert_identical(dt["run1"].to_dataset(), dat1) assert dt["run1"].children == {} - xrt.assert_identical(dt["run2"].to_dataset(), dat2) + assert_identical(dt["run2"].to_dataset(), dat2) assert dt["run2"].children == {} def test_two_layers(self): @@ -480,13 +479,13 @@ def test_two_layers(self): assert "highres" in dt.children assert "lowres" in dt.children highres_run = dt["highres/run"] - xrt.assert_identical(highres_run.to_dataset(), dat1) + assert_identical(highres_run.to_dataset(), dat1) def test_nones(self): dt = DataTree.from_dict({"d": None, "d/e": None}) assert [node.name for node in dt.subtree] == [None, "d", "e"] assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] - xrt.assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) + assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) def test_full(self, simple_datatree): dt = simple_datatree @@ -508,7 +507,7 @@ def test_datatree_values(self): actual = DataTree.from_dict({"a": dat1}) - dtt.assert_identical(actual, expected) + assert_identical(actual, expected) def test_roundtrip(self, simple_datatree): dt = simple_datatree @@ -587,16 +586,16 @@ def test_attribute_access(self, create_test_datatree): # vars / coords for key in ["a", "set0"]: - xrt.assert_equal(dt[key], getattr(dt, key)) + assert_equal(dt[key], getattr(dt, key)) assert key in dir(dt) # dims - xrt.assert_equal(dt["a"]["y"], getattr(dt.a, "y")) + assert_equal(dt["a"]["y"], getattr(dt.a, "y")) assert "y" in dir(dt["a"]) # children for key in ["set1", "set2", "set3"]: - dtt.assert_equal(dt[key], getattr(dt, key)) + assert_equal(dt[key], getattr(dt, key)) assert key in dir(dt) # attrs @@ -649,11 +648,11 @@ def test_assign(self): # kwargs form result = dt.assign(foo=xr.DataArray(0), a=DataTree()) - dtt.assert_equal(result, expected) + assert_equal(result, expected) # dict form result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()}) - dtt.assert_equal(result, expected) + assert_equal(result, expected) class TestPipe: @@ -691,7 +690,7 @@ def f(x, tree, y): class TestSubset: def test_match(self): # TODO is this example going to cause problems with case sensitivity? - dt = DataTree.from_dict( + dt: DataTree = DataTree.from_dict( { "/a/A": None, "/a/B": None, @@ -706,10 +705,10 @@ def test_match(self): "/b/B": None, } ) - dtt.assert_identical(result, expected) + assert_identical(result, expected) def test_filter(self): - simpsons = DataTree.from_dict( + simpsons: DataTree = DataTree.from_dict( d={ "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), @@ -729,4 +728,99 @@ def test_filter(self): name="Abe", ) elders = simpsons.filter(lambda node: node["age"].item() > 18) - dtt.assert_identical(elders, expected) + assert_identical(elders, expected) + + +class TestDSMethodInheritance: + def test_dataset_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.isel(x=1)) + DataTree(name="results", parent=expected, data=ds.isel(x=1)) + + result = dt.isel(x=1) + assert_equal(result, expected) + + def test_reduce_method(self): + ds = xr.Dataset({"a": ("x", [False, True, False])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.any()) + DataTree(name="results", parent=expected, data=ds.any()) + + result = dt.any() + assert_equal(result, expected) + + def test_nan_reduce_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.mean()) + DataTree(name="results", parent=expected, data=ds.mean()) + + result = dt.mean() + assert_equal(result, expected) + + def test_cum_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.cumsum()) + DataTree(name="results", parent=expected, data=ds.cumsum()) + + result = dt.cumsum() + assert_equal(result, expected) + + +class TestOps: + def test_binary_op_on_int(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected: DataTree = DataTree(data=ds1 * 5) + DataTree(name="subnode", data=ds2 * 5, parent=expected) + + # TODO: Remove ignore when ops.py is migrated? + result: DataTree = dt * 5 # type: ignore[assignment,operator] + assert_equal(result, expected) + + def test_binary_op_on_dataset(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) + + expected: DataTree = DataTree(data=ds1 * other_ds) + DataTree(name="subnode", data=ds2 * other_ds, parent=expected) + + result = dt * other_ds + assert_equal(result, expected) + + def test_binary_op_on_datatree(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected: DataTree = DataTree(data=ds1 * ds1) + DataTree(name="subnode", data=ds2 * ds2, parent=expected) + + # TODO: Remove ignore when ops.py is migrated? + result: DataTree = dt * dt # type: ignore[operator] + assert_equal(result, expected) + + +class TestUFuncs: + def test_tree(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: np.sin(ds)) + result_tree = np.sin(dt) + assert_equal(result_tree, expected) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 16ca726759d..b8b55613c4a 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -8,7 +8,7 @@ check_isomorphic, map_over_subtree, ) -from xarray.datatree_.datatree.testing import assert_equal +from xarray.testing import assert_equal empty = xr.Dataset() diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 7e1802246c7..7cfffd68620 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -5,9 +5,14 @@ import pytest import xarray as xr + +# TODO: Remove imports in favour of xr.DataTree etc, once part of public API +from xarray.core.datatree import DataTree +from xarray.core.extensions import register_datatree_accessor from xarray.tests import assert_identical +@register_datatree_accessor("example_accessor") @xr.register_dataset_accessor("example_accessor") @xr.register_dataarray_accessor("example_accessor") class ExampleAccessor: @@ -19,6 +24,7 @@ def __init__(self, xarray_obj): class TestAccessor: def test_register(self) -> None: + @register_datatree_accessor("demo") @xr.register_dataset_accessor("demo") @xr.register_dataarray_accessor("demo") class DemoAccessor: @@ -31,6 +37,9 @@ def __init__(self, xarray_obj): def foo(self): return "bar" + dt: DataTree = DataTree() + assert dt.demo.foo == "bar" + ds = xr.Dataset() assert ds.demo.foo == "bar" diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 6923d26b79a..256a02d49e2 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -10,6 +10,7 @@ import xarray as xr from xarray.core import formatting +from xarray.core.datatree import DataTree # TODO: Remove when can do xr.DataTree from xarray.tests import requires_cftime, requires_dask, requires_netCDF4 ON_WINDOWS = sys.platform == "win32" @@ -555,6 +556,108 @@ def test_array_scalar_format(self) -> None: format(var, ".2f") assert "Using format_spec is only supported" in str(excinfo.value) + def test_datatree_print_empty_node(self): + dt: DataTree = DataTree(name="root") + printout = dt.__str__() + assert printout == "DataTree('root', parent=None)" + + def test_datatree_print_empty_node_with_attrs(self): + dat = xr.Dataset(attrs={"note": "has attrs"}) + dt: DataTree = DataTree(name="root", data=dat) + printout = dt.__str__() + assert printout == dedent( + """\ + DataTree('root', parent=None) + Dimensions: () + Data variables: + *empty* + Attributes: + note: has attrs""" + ) + + def test_datatree_print_node_with_data(self): + dat = xr.Dataset({"a": [0, 2]}) + dt: DataTree = DataTree(name="root", data=dat) + printout = dt.__str__() + expected = [ + "DataTree('root', parent=None)", + "Dimensions", + "Coordinates", + "a", + "Data variables", + "*empty*", + ] + for expected_line, printed_line in zip(expected, printout.splitlines()): + assert expected_line in printed_line + + def test_datatree_printout_nested_node(self): + dat = xr.Dataset({"a": [0, 2]}) + root: DataTree = DataTree(name="root") + DataTree(name="results", data=dat, parent=root) + printout = root.__str__() + assert printout.splitlines()[2].startswith(" ") + + def test_datatree_repr_of_node_with_data(self): + dat = xr.Dataset({"a": [0, 2]}) + dt: DataTree = DataTree(name="root", data=dat) + assert "Coordinates" in repr(dt) + + def test_diff_datatree_repr_structure(self): + dt_1: DataTree = DataTree.from_dict({"a": None, "a/b": None, "a/c": None}) + dt_2: DataTree = DataTree.from_dict({"d": None, "d/e": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not isomorphic + + Number of children on node '/a' of the left object: 2 + Number of children on node '/d' of the right object: 1""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "isomorphic") + assert actual == expected + + def test_diff_datatree_repr_node_names(self): + dt_1: DataTree = DataTree.from_dict({"a": None}) + dt_2: DataTree = DataTree.from_dict({"b": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not identical + + Node '/a' in the left object has name 'a' + Node '/b' in the right object has name 'b'""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "identical") + assert actual == expected + + def test_diff_datatree_repr_node_data(self): + # casting to int64 explicitly ensures that int64s are created on all architectures + ds1 = xr.Dataset({"u": np.int64(0), "v": np.int64(1)}) + ds3 = xr.Dataset({"w": np.int64(5)}) + dt_1: DataTree = DataTree.from_dict({"a": ds1, "a/b": ds3}) + ds2 = xr.Dataset({"u": np.int64(0)}) + ds4 = xr.Dataset({"w": np.int64(6)}) + dt_2: DataTree = DataTree.from_dict({"a": ds2, "a/b": ds4}) + + expected = dedent( + """\ + Left and right DataTree objects are not equal + + + Data in nodes at position '/a' do not match: + + Data variables only on the left object: + v int64 8B 1 + + Data in nodes at position '/a/b' do not match: + + Differing data variables: + L w int64 8B 5 + R w int64 8B 6""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "equals") + assert actual == expected + def test_inline_variable_array_repr_custom_repr() -> None: class CustomArray: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index afe4d669628..e9e4eb1364c 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -22,7 +22,6 @@ create_test_data, has_cftime, has_flox, - has_pandas_version_two, requires_dask, requires_flox, requires_scipy, @@ -93,7 +92,7 @@ def test_groupby_sizes_property(dataset) -> None: assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes - dataset = dataset.drop("cat") + dataset = dataset.drop_vars("cat") stacked = dataset.stack({"xy": ("x", "y")}) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes @@ -2172,7 +2171,6 @@ def test_upsample_interpolate_dask(self, chunked_time: bool) -> None: # done here due to floating point arithmetic assert_allclose(expected, actual, rtol=1e-16) - @pytest.mark.skipif(has_pandas_version_two, reason="requires pandas < 2.0.0") def test_resample_base(self) -> None: times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) @@ -2204,11 +2202,10 @@ def test_resample_origin(self) -> None: expected = DataArray(array.to_series().resample("24h", origin=origin).mean()) assert_identical(expected, actual) - @pytest.mark.skipif(has_pandas_version_two, reason="requires pandas < 2.0.0") @pytest.mark.parametrize( "loffset", [ - "-12H", + "-12h", datetime.timedelta(hours=-12), pd.Timedelta(hours=-12), pd.DateOffset(hours=-12), diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index b518c973d3a..8e1eb616cca 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -16,7 +16,6 @@ has_h5netcdf, has_netCDF4, has_pydap, - has_pynio, has_scipy, has_zarr, ) @@ -280,7 +279,6 @@ def test_list_engines() -> None: assert ("netcdf4" in engines) == has_netCDF4 assert ("pydap" in engines) == has_pydap assert ("zarr" in engines) == has_zarr - assert ("pynio" in engines) == has_pynio assert "store" in engines diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 8a9345e74d4..3167de2e2f0 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -36,12 +36,10 @@ assert_equal, assert_identical, assert_no_warnings, - has_pandas_version_two, raise_if_dask_computes, requires_bottleneck, requires_cupy, requires_dask, - requires_pandas_version_two, requires_pint, requires_sparse, source_ndarray, @@ -2645,7 +2643,6 @@ def test_datetime(self): assert np.ndarray == type(actual) assert np.dtype("datetime64[ns]") == actual.dtype - @requires_pandas_version_two def test_tz_datetime(self) -> None: tz = pytz.timezone("America/New_York") times_ns = pd.date_range("2000", periods=1, tz=tz) @@ -2938,7 +2935,7 @@ def test_from_pint_wrapping_dask(self, Var): @pytest.mark.parametrize( - ("values", "warns_under_pandas_version_two"), + ("values", "warns"), [ (np.datetime64("2000-01-01", "ns"), False), (np.datetime64("2000-01-01", "s"), True), @@ -2957,9 +2954,9 @@ def test_from_pint_wrapping_dask(self, Var): ], ids=lambda x: f"{x}", ) -def test_datetime_conversion_warning(values, warns_under_pandas_version_two) -> None: +def test_datetime_conversion_warning(values, warns) -> None: dims = ["time"] if isinstance(values, (np.ndarray, pd.Index, pd.Series)) else [] - if warns_under_pandas_version_two and has_pandas_version_two: + if warns: with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): var = Variable(dims, values) else: @@ -2979,7 +2976,6 @@ def test_datetime_conversion_warning(values, warns_under_pandas_version_two) -> ) -@requires_pandas_version_two def test_pandas_two_only_datetime_conversion_warnings() -> None: # Note these tests rely on pandas features that are only present in pandas # 2.0.0 and above, and so for now cannot be parametrized. @@ -3014,7 +3010,7 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: @pytest.mark.parametrize( - ("values", "warns_under_pandas_version_two"), + ("values", "warns"), [ (np.timedelta64(10, "ns"), False), (np.timedelta64(10, "s"), True), @@ -3026,9 +3022,9 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: ], ids=lambda x: f"{x}", ) -def test_timedelta_conversion_warning(values, warns_under_pandas_version_two) -> None: +def test_timedelta_conversion_warning(values, warns) -> None: dims = ["time"] if isinstance(values, (np.ndarray, pd.Index)) else [] - if warns_under_pandas_version_two and has_pandas_version_two: + if warns: with pytest.warns(UserWarning, match="non-nanosecond precision timedelta"): var = Variable(dims, values) else: @@ -3039,7 +3035,6 @@ def test_timedelta_conversion_warning(values, warns_under_pandas_version_two) -> assert var.dtype == np.dtype("timedelta64[ns]") -@requires_pandas_version_two def test_pandas_two_only_timedelta_conversion_warning() -> None: # Note this test relies on a pandas feature that is only present in pandas # 2.0.0 and above, and so for now cannot be parametrized. @@ -3050,7 +3045,6 @@ def test_pandas_two_only_timedelta_conversion_warning() -> None: assert var.dtype == np.dtype("timedelta64[ns]") -@requires_pandas_version_two @pytest.mark.parametrize( ("index", "dtype"), [