From af722f01c91ade023f0c828b886fcb1b0c3d3355 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:34:30 -0700 Subject: [PATCH 01/33] Improve `to_zarr` docs (#9139) Promoted `region='auto'`, refine the langague slightly --- xarray/core/dataset.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0923c5d4822..0b8be674675 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2458,24 +2458,26 @@ def to_zarr( If set, the dimension along which the data will be appended. All other dimensions on overridden variables must remain the same size. region : dict or "auto", optional - Optional mapping from dimension names to integer slices along - dataset dimensions to indicate the region of existing zarr array(s) - in which to write this dataset's data. For example, - ``{'x': slice(0, 1000), 'y': slice(10000, 11000)}`` would indicate - that values should be written to the region ``0:1000`` along ``x`` - and ``10000:11000`` along ``y``. - - Can also specify ``"auto"``, in which case the existing store will be - opened and the region inferred by matching the new data's coordinates. - ``"auto"`` can be used as a single string, which will automatically infer - the region for all dimensions, or as dictionary values for specific - dimensions mixed together with explicit slices for other dimensions. + Optional mapping from dimension names to either a) ``"auto"``, or b) integer + slices, indicating the region of existing zarr array(s) in which to write + this dataset's data. + + If ``"auto"`` is provided the existing store will be opened and the region + inferred by matching indexes. ``"auto"`` can be used as a single string, + which will automatically infer the region for all dimensions, or as + dictionary values for specific dimensions mixed together with explicit + slices for other dimensions. 
+ + Alternatively integer slices can be provided; for example, ``{'x': slice(0, + 1000), 'y': slice(10000, 11000)}`` would indicate that values should be + written to the region ``0:1000`` along ``x`` and ``10000:11000`` along + ``y``. Two restrictions apply to the use of ``region``: - If ``region`` is set, _all_ variables in a dataset must have at least one dimension in common with the region. Other variables - should be written in a separate call to ``to_zarr()``. + should be written in a separate single call to ``to_zarr()``. - Dimensions cannot be included in both ``region`` and ``append_dim`` at the same time. To create empty arrays to fill in with ``region``, use a separate call to ``to_zarr()`` with From 2645d7f6d95abffb04393c7ef1692125ee4ba869 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 21 Jun 2024 16:35:13 -0600 Subject: [PATCH 02/33] groupby: remove some internal use of IndexVariable (#9123) * Remove internal use of IndexVariable * cleanup * cleanup more * cleanup --- xarray/core/groupby.py | 63 +++++++++++++++++++++++++++-------------- xarray/core/groupers.py | 37 +++++++++++++++++------- 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index eceb4e62199..42e7f01a526 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -19,8 +19,10 @@ from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.concat import concat +from xarray.core.coordinates import Coordinates from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( + PandasIndex, create_default_index_implicit, filter_indexes_from_coords, ) @@ -246,7 +248,7 @@ def to_array(self) -> DataArray: return self.to_dataarray() -T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] +T_Group = Union["T_DataArray", _DummyGroup] def _ensure_1d(group: T_Group, obj: 
T_DataWithCoords) -> tuple[ @@ -256,7 +258,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ list[Hashable], ]: # 1D cases: do nothing - if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: + if isinstance(group, _DummyGroup) or group.ndim == 1: return group, obj, None, [] from xarray.core.dataarray import DataArray @@ -271,9 +273,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ newobj = obj.stack({stacked_dim: orig_dims}) return newgroup, newobj, stacked_dim, inserted_dims - raise TypeError( - f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." - ) + raise TypeError(f"group must be DataArray or _DummyGroup, got {type(group)!r}.") @dataclass @@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): codes: DataArray = field(init=False) full_index: pd.Index = field(init=False) group_indices: T_GroupIndices = field(init=False) - unique_coord: IndexVariable | _DummyGroup = field(init=False) + unique_coord: Variable | _DummyGroup = field(init=False) # _ensure_1d: group1d: T_Group = field(init=False) @@ -315,7 +315,7 @@ def __post_init__(self) -> None: # might be used multiple times. 
self.grouper = copy.deepcopy(self.grouper) - self.group: T_Group = _resolve_group(self.obj, self.group) + self.group = _resolve_group(self.obj, self.group) ( self.group1d, @@ -328,14 +328,18 @@ def __post_init__(self) -> None: @property def name(self) -> Hashable: + """Name for the grouped coordinate after reduction.""" # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper - return self.unique_coord.name + (name,) = self.unique_coord.dims + return name @property def size(self) -> int: + """Number of groups.""" return len(self) def __len__(self) -> int: + """Number of groups.""" return len(self.full_index) @property @@ -358,8 +362,8 @@ def factorize(self) -> None: ] if encoded.unique_coord is None: unique_values = self.full_index[np.unique(encoded.codes)] - self.unique_coord = IndexVariable( - self.codes.name, unique_values, attrs=self.group.attrs + self.unique_coord = Variable( + dims=self.codes.name, data=unique_values, attrs=self.group.attrs ) else: self.unique_coord = encoded.unique_coord @@ -378,7 +382,9 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None: ) -def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: +def _resolve_group( + obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable +) -> T_Group: from xarray.core.dataarray import DataArray error_msg = ( @@ -620,6 +626,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]: yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): + from xarray.core.groupers import BinGrouper + (grouper,) = self.groupers if self._group_dim in applied_example.dims: coord = grouper.group1d @@ -628,7 +636,10 @@ def _infer_concat_args(self, applied_example): coord = grouper.unique_coord positions = None (dim,) = coord.dims - if isinstance(coord, _DummyGroup): + if isinstance(grouper.group, _DummyGroup) and not isinstance( + grouper.grouper, BinGrouper + ): + # When binning we actually do set 
the index coord = None coord = getattr(coord, "variable", coord) return coord, dim, positions @@ -641,6 +652,7 @@ def _binary_op(self, other, f, reflexive=False): (grouper,) = self.groupers obj = self._original_obj + name = grouper.name group = grouper.group codes = self._codes dims = group.dims @@ -649,9 +661,11 @@ def _binary_op(self, other, f, reflexive=False): group = coord = group.to_dataarray() else: coord = grouper.unique_coord - if not isinstance(coord, DataArray): - coord = DataArray(grouper.unique_coord) - name = grouper.name + if isinstance(coord, Variable): + assert coord.ndim == 1 + (coord_dim,) = coord.dims + # TODO: explicitly create Index here + coord = DataArray(coord, coords={coord_dim: coord.data}) if not isinstance(other, (Dataset, DataArray)): raise TypeError( @@ -766,6 +780,7 @@ def _flox_reduce( obj = self._original_obj (grouper,) = self.groupers + name = grouper.name isbin = isinstance(grouper.grouper, BinGrouper) if keep_attrs is None: @@ -797,14 +812,14 @@ def _flox_reduce( # weird backcompat # reducing along a unique indexed dimension with squeeze=True # should raise an error - if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes: - index = obj.indexes[grouper.name] + if (dim is None or dim == name) and name in obj.xindexes: + index = obj.indexes[name] if index.is_unique and self._squeeze: - raise ValueError(f"cannot reduce over dimensions {grouper.name!r}") + raise ValueError(f"cannot reduce over dimensions {name!r}") unindexed_dims: tuple[Hashable, ...] = tuple() if isinstance(grouper.group, _DummyGroup) and not isbin: - unindexed_dims = (grouper.name,) + unindexed_dims = (name,) parsed_dim: tuple[Hashable, ...] 
if isinstance(dim, str): @@ -848,15 +863,19 @@ def _flox_reduce( # in the grouped variable group_dims = grouper.group.dims if set(group_dims).issubset(set(parsed_dim)): - result[grouper.name] = output_index + result = result.assign_coords( + Coordinates( + coords={name: (name, np.array(output_index))}, + indexes={name: PandasIndex(output_index, dim=name)}, + ) + ) result = result.drop_vars(unindexed_dims) # broadcast and restore non-numeric data variables (backcompat) for name, var in non_numeric.items(): if all(d not in var.dims for d in parsed_dim): result[name] = var.variable.set_dims( - (grouper.name,) + var.dims, - (result.sizes[grouper.name],) + var.shape, + (name,) + var.dims, (result.sizes[name],) + var.shape ) if not isinstance(result, Dataset): diff --git a/xarray/core/groupers.py b/xarray/core/groupers.py index e33cd3ad99f..075afd9f62f 100644 --- a/xarray/core/groupers.py +++ b/xarray/core/groupers.py @@ -23,7 +23,7 @@ from xarray.core.resample_cftime import CFTimeGrouper from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices from xarray.core.utils import emit_user_level_warning -from xarray.core.variable import IndexVariable +from xarray.core.variable import Variable __all__ = [ "EncodedGroups", @@ -55,7 +55,17 @@ class EncodedGroups: codes: DataArray full_index: pd.Index group_indices: T_GroupIndices | None = field(default=None) - unique_coord: IndexVariable | _DummyGroup | None = field(default=None) + unique_coord: Variable | _DummyGroup | None = field(default=None) + + def __post_init__(self): + assert isinstance(self.codes, DataArray) + if self.codes.name is None: + raise ValueError("Please set a name on the array you are grouping by.") + assert isinstance(self.full_index, pd.Index) + assert ( + isinstance(self.unique_coord, (Variable, _DummyGroup)) + or self.unique_coord is None + ) class Grouper(ABC): @@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups: "Failed to group data. 
Are you grouping by a variable that is all NaN?" ) codes = self.group.copy(data=codes_) - unique_coord = IndexVariable( - self.group.name, unique_values, attrs=self.group.attrs + unique_coord = Variable( + dims=codes.name, data=unique_values, attrs=self.group.attrs ) - full_index = unique_coord + full_index = pd.Index(unique_values) return EncodedGroups( codes=codes, full_index=full_index, unique_coord=unique_coord @@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups: size_range = np.arange(size) if isinstance(self.group, _DummyGroup): codes = self.group.to_dataarray().copy(data=size_range) + unique_coord = self.group + full_index = pd.RangeIndex(self.group.size) else: codes = self.group.copy(data=size_range) - unique_coord = self.group - full_index = IndexVariable( - self.group.name, unique_coord.values, self.group.attrs - ) + unique_coord = self.group.variable.to_base_variable() + full_index = pd.Index(unique_coord.data) + return EncodedGroups( codes=codes, group_indices=group_indices, @@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups: codes = DataArray( binned_codes, getattr(group, "coords", None), name=new_dim_name ) - unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) + unique_coord = Variable( + dims=new_dim_name, data=unique_values, attrs=group.attrs + ) return EncodedGroups( codes=codes, full_index=full_index, unique_coord=unique_coord ) @@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups: ] group_indices += [slice(sbins[-1], None)] - unique_coord = IndexVariable(group.name, first_items.index, group.attrs) + unique_coord = Variable( + dims=group.name, data=first_items.index, attrs=group.attrs + ) codes = group.copy(data=codes_) return EncodedGroups( From deb2082ab6e648b7e87cd26a74d084262bd1cfdf Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 22 Jun 2024 11:03:37 -0700 Subject: [PATCH 03/33] Improve zarr chunks docs (#9140) * Improve 
zarr chunks docs Makes them more structure, consistent. I think removes a mistake re the default chunks arg in `open_zarr` (it's not `None`, it's `auto`). Adds a comment re performance with `chunks=None`, closing https://github.com/pydata/xarray/issues/9111 --- doc/whats-new.rst | 2 ++ xarray/backends/api.py | 43 +++++++++++++++++++++++++---------------- xarray/backends/zarr.py | 18 +++++++++++------ 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e7a48458ae2..51a2c98fb9c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,8 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Improvements to Zarr & chunking docs (:pull:`9139`, :pull:`9140`, :pull:`9132`) + By `Maximilian Roos `_ Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ea3639db5c4..7054c62126e 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -425,15 +425,19 @@ def open_dataset( is chosen based on available dependencies, with a preference for "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) can also be used. - chunks : int, dict, 'auto' or None, optional - If chunks is provided, it is used to load the new dataset into dask - arrays. ``chunks=-1`` loads the dataset with dask using a single - chunk for all arrays. ``chunks={}`` loads the dataset with dask using - engine preferred chunks if exposed by the backend, otherwise with - a single chunk for all arrays. In order to reproduce the default behavior - of ``xr.open_zarr(...)`` use ``xr.open_dataset(..., engine='zarr', chunks={})``. - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the - engine preferred chunks. See dask chunking for more details. + chunks : int, dict, 'auto' or None, default: None + If provided, used to load the data into dask arrays. + + - ``chunks="auto"`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. 
+ - ``chunks=None`` skips using dask, which is generally faster for + small arrays. + - ``chunks=-1`` loads the data with dask using a single chunk for all arrays. + - ``chunks={}`` loads the data with dask using the engine's preferred chunk + size, generally identical to the format's chunk size. If not available, a + single chunk for all arrays. + + See dask chunking for more details. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -631,14 +635,19 @@ def open_dataarray( Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for "netcdf4". - chunks : int, dict, 'auto' or None, optional - If chunks is provided, it is used to load the new dataset into dask - arrays. ``chunks=-1`` loads the dataset with dask using a single - chunk for all arrays. `chunks={}`` loads the dataset with dask using - engine preferred chunks if exposed by the backend, otherwise with - a single chunk for all arrays. - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the - engine preferred chunks. See dask chunking for more details. + chunks : int, dict, 'auto' or None, default: None + If provided, used to load the data into dask arrays. + + - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. + - ``chunks=None`` skips using dask, which is generally faster for + small arrays. + - ``chunks=-1`` loads the data with dask using a single chunk for all arrays. + - ``chunks={}`` loads the data with dask using engine preferred chunks if + exposed by the backend, otherwise with a single chunk for all arrays. + + See dask chunking for more details. 
+ cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 5f6aa0f119c..9796fcbf9e2 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -973,12 +973,18 @@ def open_zarr( Array synchronizer provided to zarr group : str, optional Group path. (a.k.a. `path` in zarr terminology.) - chunks : int or dict or tuple or {None, 'auto'}, optional - Chunk sizes along each dimension, e.g., ``5`` or - ``{'x': 5, 'y': 5}``. If `chunks='auto'`, dask chunks are created - based on the variable's zarr chunks. If `chunks=None`, zarr array - data will lazily convert to numpy arrays upon access. This accepts - all the chunk specifications as Dask does. + chunks : int, dict, 'auto' or None, default: 'auto' + If provided, used to load the data into dask arrays. + + - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. + - ``chunks=None`` skips using dask, which is generally faster for + small arrays. + - ``chunks=-1`` loads the data with dask using a single chunk for all arrays. + - ``chunks={}`` loads the data with dask using engine preferred chunks if + exposed by the backend, otherwise with a single chunk for all arrays. + + See dask chunking for more details. 
overwrite_encoded_chunks : bool, optional Whether to drop the zarr chunks encoded for each variable when a dataset is loaded with specified chunk sizes (default: False) From fe4fb061499f77681dd330cffb116c24388fe3d9 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 23 Jun 2024 20:39:25 -0700 Subject: [PATCH 04/33] Include numbagg in type checks (#9159) * Include numbagg in type checks --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db64d7a18c5..2081f7f87bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,6 @@ module = [ "matplotlib.*", "mpl_toolkits.*", "nc_time_axis.*", - "numbagg.*", "netCDF4.*", "netcdftime.*", "opt_einsum.*", @@ -329,8 +328,7 @@ filterwarnings = [ "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable", "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", - "default:::xarray.tests.test_strategies", - # TODO: remove once we know how to deal with a changed signature in protocols + "default:::xarray.tests.test_strategies", # TODO: remove once we know how to deal with a changed signature in protocols "ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed.", ] From c8ff731aa83b5b555b1c75bf72120e9f1ca043d9 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 23 Jun 2024 20:41:41 -0700 Subject: [PATCH 05/33] Remove mypy exclusions for a couple more libraries (#9160) * Remove mypy exclusions for a couple more libraries Also (unrelated) allow mypy passing without `array_api_strict` installed, which isn't in our dev dependencies... 
--- pyproject.toml | 2 -- xarray/tests/test_dtypes.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2081f7f87bc..1815fa6dd5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,6 @@ module = [ "cloudpickle.*", "cubed.*", "cupy.*", - "dask.types.*", "fsspec.*", "h5netcdf.*", "h5py.*", @@ -126,7 +125,6 @@ module = [ "pooch.*", "pyarrow.*", "pydap.*", - "pytest.*", "scipy.*", "seaborn.*", "setuptools", diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index e817bfdb330..498ba2ce59f 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -11,9 +11,9 @@ except ImportError: class DummyArrayAPINamespace: - bool = None - int32 = None - float64 = None + bool = None # type: ignore[unused-ignore,var-annotated] + int32 = None # type: ignore[unused-ignore,var-annotated] + float64 = None # type: ignore[unused-ignore,var-annotated] array_api_strict = DummyArrayAPINamespace From 872c1c576dc4bc1724e1c526ddc45cb420394ce3 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 23 Jun 2024 21:48:59 -0700 Subject: [PATCH 06/33] Add test for #9155 (#9161) * Add test for #9155 I can't get this to fail locally, so adding a test to assess what's going on. 
Alos excludes matplotlib from type exclusions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 - xarray/tests/test_plot.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1815fa6dd5d..2ada0c1c171 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,6 @@ module = [ "h5netcdf.*", "h5py.*", "iris.*", - "matplotlib.*", "mpl_toolkits.*", "nc_time_axis.*", "netCDF4.*", diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a44b621a981..b302ad3af93 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -3406,3 +3406,13 @@ def test_plot1d_filtered_nulls() -> None: actual = pc.get_offsets().shape[0] assert expected == actual + + +@requires_matplotlib +def test_9155() -> None: + # A test for types from issue #9155 + + with figure_context(): + data = xr.DataArray([1, 2, 3], dims=["x"]) + fig, ax = plt.subplots(ncols=1, nrows=1) + data.plot(ax=ax) From 56209bd9a3192e4f1e82c21e5ffcf4c3bacaaae3 Mon Sep 17 00:00:00 2001 From: Jessica Scheick Date: Mon, 24 Jun 2024 11:31:30 -0400 Subject: [PATCH 07/33] Docs: Add page with figure for navigating help resources (#9147) * add config to build mermaid diagrams in docs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian --- ci/requirements/doc.yml | 1 + doc/conf.py | 5 +++ doc/help-diagram.rst | 75 +++++++++++++++++++++++++++++++++++++++++ doc/index.rst | 4 ++- doc/whats-new.rst | 3 ++ 5 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 doc/help-diagram.rst diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 066d085ec53..39c2d4d6e88 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -42,5 +42,6 @@ dependencies: - 
sphinxext-rediraffe - zarr>=2.10 - pip: + - sphinxcontrib-mermaid # relative to this file. Needs to be editable to be accepted. - -e ../.. diff --git a/doc/conf.py b/doc/conf.py index 80b24445f71..91bcdf8b8f8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -59,6 +59,7 @@ ) nbsphinx_allow_errors = False +nbsphinx_requirejs_path = "" # -- General configuration ------------------------------------------------ @@ -68,7 +69,9 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. + extensions = [ + "sphinxcontrib.mermaid", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.intersphinx", @@ -175,6 +178,8 @@ "pd.NaT": "~pandas.NaT", } +# mermaid config +mermaid_version = "10.9.1" # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] diff --git a/doc/help-diagram.rst b/doc/help-diagram.rst new file mode 100644 index 00000000000..a42a2f0936a --- /dev/null +++ b/doc/help-diagram.rst @@ -0,0 +1,75 @@ +Getting Help +============ + +Navigating the wealth of resources available for Xarray can be overwhelming. +We've created this flow chart to help guide you towards the best way to get help, depending on what you're working towards. +The links to each resource are provided below the diagram. +Regardless of how you interact with us, we're always thrilled to hear from you! + +.. mermaid:: + :alt: Flowchart illustrating the different ways to access help using or contributing to Xarray. + + flowchart TD + intro[Welcome to Xarray! 
How can we help?]:::quesNodefmt + usage(["fa:fa-chalkboard-user Xarray Tutorials + fab:fa-readme Xarray Docs + fab:fa-google Google/fab:fa-stack-overflow Stack Exchange + fa:fa-robot Ask AI/a Language Learning Model (LLM)"]):::ansNodefmt + API([fab:fa-readme Xarray Docs + fab:fa-readme extension's docs]):::ansNodefmt + help([fab:fa-github Xarray Discussions + fab:fa-discord Xarray Discord + fa:fa-users Xarray Office Hours + fa:fa-globe Pangeo Discourse]):::ansNodefmt + bug([Report and Propose here: + fab:fa-github Xarray Issues]):::ansNodefmt + contrib([fa:fa-book-open Xarray Contributor's Guide]):::ansNodefmt + pr(["fab:fa-github Pull Request (PR)"]):::ansNodefmt + dev([fab:fa-github Comment on your PR + fa:fa-users Developer's Meeting]):::ansNodefmt + report[Thanks for letting us know!]:::quesNodefmt + merged[fa:fa-hands-clapping Your PR was merged. + Thanks for contributing to Xarray!]:::quesNodefmt + + + intro -->|How do I use Xarray?| usage + usage -->|"with extensions (like Dask)"| API + + usage -->|I'd like some more help| help + intro -->|I found a bug| bug + intro -->|I'd like to make a small change| contrib + subgraph bugcontrib[Bugs and Contributions] + bug + contrib + bug -->|I just wanted to tell you| report + bug<-->|I'd like to fix the bug!| contrib + pr -->|my PR was approved| merged + end + + + intro -->|I wish Xarray could...| bug + + + pr <-->|my PR is quiet| dev + contrib -->pr + + classDef quesNodefmt fill:#9DEEF4,stroke:#206C89 + + classDef ansNodefmt fill:#FFAA05,stroke:#E37F17 + + classDef boxfmt fill:#FFF5ED,stroke:#E37F17 + class bugcontrib boxfmt + + linkStyle default font-size:20pt,color:#206C89 + + +- `Xarray Tutorials `__ +- `Xarray Docs `__ +- `Google/Stack Exchange `__ +- `Xarray Discussions `__ +- `Xarray Discord `__ +- `Xarray Office Hours `__ +- `Pangeo Discourse `__ +- `Xarray Issues `__ +- `Xarray Contributors Guide `__ +- `Developer's Meeting `__ diff --git a/doc/index.rst b/doc/index.rst index 138e9d91601..4a5fe4ee080 100644 
--- a/doc/index.rst +++ b/doc/index.rst @@ -14,7 +14,8 @@ efficient, and fun! `Releases `__ | `Stack Overflow `__ | `Mailing List `__ | -`Blog `__ +`Blog `__ | +`Tutorials `__ .. grid:: 1 1 2 2 @@ -65,6 +66,7 @@ efficient, and fun! Tutorials & Videos API Reference How do I ... + Getting Help Ecosystem .. toctree:: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 51a2c98fb9c..c3383a5648a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,9 +40,12 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Adds a flow-chart diagram to help users navigate help resources (`Discussion #8990 `_). + By `Jessica Scheick `_. - Improvements to Zarr & chunking docs (:pull:`9139`, :pull:`9140`, :pull:`9132`) By `Maximilian Roos `_ + Internal Changes ~~~~~~~~~~~~~~~~ From b5180749d351f8b85fd39677bf137caaa90288a7 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 25 Jun 2024 15:18:53 +0200 Subject: [PATCH 08/33] switch to unit `"D"` (#9170) --- xarray/tests/test_missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 3adcc132b61..da9513a7c71 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -84,7 +84,7 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False if non_uniform: # construct a datetime index that has irregular spacing - deltas = pd.to_timedelta(rs.normal(size=shape[0], scale=10), unit="d") + deltas = pd.to_timedelta(rs.normal(size=shape[0], scale=10), unit="D") coords = {"time": (pd.Timestamp("2000-01-01") + deltas).sort_values()} else: coords = {"time": pd.date_range("2000-01-01", freq="D", periods=shape[0])} From 07b175633eba30dbfcd6eb0cf514ef1b1da9cf64 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 26 Jun 2024 11:05:23 -0700 Subject: [PATCH 09/33] Slightly improve DataTree repr (#9064) * Improve DataTree repr * Adjust DataTree repr to include full path * More tweaks * Use "Group:" in repr instead of 
"DataTree:" * Fix errors in new repr tests * Fix repr on windows --- xarray/core/datatree.py | 11 ++++--- xarray/core/datatree_render.py | 11 ++++--- xarray/core/formatting.py | 15 +++------ xarray/core/iterators.py | 19 +++++------ xarray/tests/test_datatree.py | 57 +++++++++++++++++++++++++++++++++ xarray/tests/test_formatting.py | 18 ++++++----- 6 files changed, 94 insertions(+), 37 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 4e4d30885a3..c923ca2eb87 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1314,11 +1314,12 @@ def match(self, pattern: str) -> DataTree: ... } ... ) >>> dt.match("*/B") - DataTree('None', parent=None) - ├── DataTree('a') - │ └── DataTree('B') - └── DataTree('b') - └── DataTree('B') + + Group: / + ├── Group: /a + │ └── Group: /a/B + └── Group: /b + └── Group: /b/B """ matching_nodes = { node.path: node.ds diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py index d069071495e..f10f2540952 100644 --- a/xarray/core/datatree_render.py +++ b/xarray/core/datatree_render.py @@ -57,11 +57,12 @@ def __init__(self): >>> s0a = DataTree(name="sub0A", parent=s0) >>> s1 = DataTree(name="sub1", parent=root) >>> print(RenderDataTree(root)) - DataTree('root', parent=None) - ├── DataTree('sub0') - │ ├── DataTree('sub0B') - │ └── DataTree('sub0A') - └── DataTree('sub1') + + Group: / + ├── Group: /sub0 + │ ├── Group: /sub0/sub0B + │ └── Group: /sub0/sub0A + └── Group: /sub1 """ super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ") diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index ad65a44d7d5..c15df34b5b1 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -1023,20 +1023,21 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): def _single_node_repr(node: DataTree) -> str: """Information about this node, not including its relationships to other nodes.""" - node_info = f"DataTree('{node.name}')" - if 
node.has_data or node.has_attrs: ds_info = "\n" + repr(node.ds) else: ds_info = "" - return node_info + ds_info + return f"Group: {node.path}{ds_info}" def datatree_repr(dt: DataTree): """A printable representation of the structure of this entire tree.""" renderer = RenderDataTree(dt) - lines = [] + name_info = "" if dt.name is None else f" {dt.name!r}" + header = f"" + + lines = [header] for pre, fill, node in renderer: node_repr = _single_node_repr(node) @@ -1051,12 +1052,6 @@ def datatree_repr(dt: DataTree): else: lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") - # Tack on info about whether or not root node has a parent at the start - first_line = lines[0] - parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" - first_line_with_parent = first_line[:-1] + f", parent={parent})" - lines[0] = first_line_with_parent - return "\n".join(lines) diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py index dd5fa7ee97a..ae748b0066c 100644 --- a/xarray/core/iterators.py +++ b/xarray/core/iterators.py @@ -39,15 +39,16 @@ class LevelOrderIter(Iterator): >>> i = DataTree(name="i", parent=g) >>> h = DataTree(name="h", parent=i) >>> print(f) - DataTree('f', parent=None) - ├── DataTree('b') - │ ├── DataTree('a') - │ └── DataTree('d') - │ ├── DataTree('c') - │ └── DataTree('e') - └── DataTree('g') - └── DataTree('i') - └── DataTree('h') + + Group: / + ├── Group: /b + │ ├── Group: /b/a + │ └── Group: /b/d + │ ├── Group: /b/d/c + │ └── Group: /b/d/e + └── Group: /g + └── Group: /g/i + └── Group: /g/i/h >>> [node.name for node in LevelOrderIter(f)] ['f', 'b', 'g', 'a', 'd', 'i', 'c', 'e', 'h'] >>> [node.name for node in LevelOrderIter(f, maxlevel=3)] diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 58fec20d4c6..b0dc2accd3e 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -623,6 +623,63 @@ def test_operation_with_attrs_but_no_data(self): dt.sel(dim_0=0) +class TestRepr: + def 
test_repr(self): + dt: DataTree = DataTree.from_dict( + { + "/": xr.Dataset( + {"e": (("x",), [1.0, 2.0])}, + coords={"x": [2.0, 3.0]}, + ), + "/b": xr.Dataset({"f": (("y",), [3.0])}), + "/b/c": xr.Dataset(), + "/b/d": xr.Dataset({"g": 4.0}), + } + ) + + result = repr(dt) + expected = dedent( + """ + + Group: / + │ Dimensions: (x: 2) + │ Coordinates: + │ * x (x) float64 16B 2.0 3.0 + │ Data variables: + │ e (x) float64 16B 1.0 2.0 + └── Group: /b + │ Dimensions: (y: 1) + │ Dimensions without coordinates: y + │ Data variables: + │ f (y) float64 8B 3.0 + ├── Group: /b/c + └── Group: /b/d + Dimensions: () + Data variables: + g float64 8B 4.0 + """ + ).strip() + assert result == expected + + result = repr(dt.b) + expected = dedent( + """ + + Group: /b + │ Dimensions: (y: 1) + │ Dimensions without coordinates: y + │ Data variables: + │ f (y) float64 8B 3.0 + ├── Group: /b/c + └── Group: /b/d + Dimensions: () + Data variables: + g float64 8B 4.0 + """ + ).strip() + assert result == expected + + class TestRestructuring: def test_drop_nodes(self): sue = DataTree.from_dict({"Mary": None, "Kate": None, "Ashley": None}) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index b9d5f401a4a..d7a46eeaefc 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -555,16 +555,17 @@ def test_array_scalar_format(self) -> None: def test_datatree_print_empty_node(self): dt: DataTree = DataTree(name="root") - printout = dt.__str__() - assert printout == "DataTree('root', parent=None)" + printout = str(dt) + assert printout == "\nGroup: /" def test_datatree_print_empty_node_with_attrs(self): dat = xr.Dataset(attrs={"note": "has attrs"}) dt: DataTree = DataTree(name="root", data=dat) - printout = dt.__str__() + printout = str(dt) assert printout == dedent( """\ - DataTree('root', parent=None) + + Group: / Dimensions: () Data variables: *empty* @@ -575,9 +576,10 @@ def test_datatree_print_empty_node_with_attrs(self): def 
test_datatree_print_node_with_data(self): dat = xr.Dataset({"a": [0, 2]}) dt: DataTree = DataTree(name="root", data=dat) - printout = dt.__str__() + printout = str(dt) expected = [ - "DataTree('root', parent=None)", + "", + "Group: /", "Dimensions", "Coordinates", "a", @@ -591,8 +593,8 @@ def test_datatree_printout_nested_node(self): dat = xr.Dataset({"a": [0, 2]}) root: DataTree = DataTree(name="root") DataTree(name="results", data=dat, parent=root) - printout = root.__str__() - assert printout.splitlines()[2].startswith(" ") + printout = str(root) + assert printout.splitlines()[3].startswith(" ") def test_datatree_repr_of_node_with_data(self): dat = xr.Dataset({"a": [0, 2]}) From 19d0fbfcbd3bd74f5846569a78ded68810446c48 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Wed, 26 Jun 2024 13:14:25 -0500 Subject: [PATCH 10/33] Fix example code formatting for CachingFileManager (#9178) --- xarray/backends/file_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index df901f9a1d9..86d84f532b1 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -63,7 +63,7 @@ class CachingFileManager(FileManager): FileManager.close(), which ensures that closed files are removed from the cache as well. 
- Example usage: + Example usage:: manager = FileManager(open, 'example.txt', mode='w') f = manager.acquire() @@ -71,7 +71,7 @@ class CachingFileManager(FileManager): manager.close() # ensures file is closed Note that as long as previous files are still cached, acquiring a file - multiple times from the same FileManager is essentially free: + multiple times from the same FileManager is essentially free:: f1 = manager.acquire() f2 = manager.acquire() From 651bd12749e56b0b2f992c8cae51dae0ece29c65 Mon Sep 17 00:00:00 2001 From: Pontus Lurcock Date: Wed, 26 Jun 2024 20:16:09 +0200 Subject: [PATCH 11/33] Change np.core.defchararray to np.char (#9165) (#9166) * Change np.core.defchararray to np.char.chararray (#9165) Replace a reference to np.core.defchararray with np.char.chararray in xarray.testing.assertions, since the former no longer works on NumPy 2.0.0 and the latter is the "preferred alias" according to NumPy docs. See Issue #9165. * Add test for assert_allclose on dtype S (#9165) * Use np.char.decode, not np.char.chararray.decode ... in assertions._decode_string_data. See #9166. * List #9165 fix in whats-new.rst * cross-like the fixed function * Improve a parameter ID in tests.test_assertions Co-authored-by: Justus Magin * whats-new normalization --------- Co-authored-by: Justus Magin --- doc/whats-new.rst | 2 ++ xarray/testing/assertions.py | 2 +- xarray/tests/test_assertions.py | 5 +++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c3383a5648a..97631b4c324 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,6 +35,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`). + By `Pontus Lurcock `_. 
Documentation diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 69885868f83..2a4c17e115a 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -36,7 +36,7 @@ def wrapper(*args, **kwargs): def _decode_string_data(data): if data.dtype.kind == "S": - return np.core.defchararray.decode(data, "utf-8", "replace") + return np.char.decode(data, "utf-8", "replace") return data diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index aa0ea46f7db..20b5e163662 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -52,6 +52,11 @@ def test_allclose_regression() -> None: xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}), id="Dataset", ), + pytest.param( + xr.DataArray(np.array("a", dtype="|S1")), + xr.DataArray(np.array("b", dtype="|S1")), + id="DataArray_with_character_dtype", + ), ), ) def test_assert_allclose(obj1, obj2) -> None: From fa41cc0454e6daf47d1417f97a9e72ebb56e3add Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 27 Jun 2024 12:23:55 +0200 Subject: [PATCH 12/33] temporarily pin `numpy<2` (#9181) --- ci/requirements/doc.yml | 2 +- ci/requirements/environment-windows.yml | 2 +- ci/requirements/environment.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 39c2d4d6e88..116eee7f702 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -21,7 +21,7 @@ dependencies: - nbsphinx - netcdf4>=1.5 - numba - - numpy>=1.21 + - numpy>=1.21,<2 - packaging>=21.3 - pandas>=1.4,!=2.1.0 - pooch diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 3b2e6dc62e6..4cdddc676eb 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -23,7 +23,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy + - numpy<2 - packaging - pandas # - pint>=0.22 diff --git 
a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 01521e950f4..f1a10bc040b 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -26,7 +26,7 @@ dependencies: - numba - numbagg - numexpr - - numpy + - numpy<2 - opt_einsum - packaging - pandas From 48a4f7ac6cf20a8b6d0247c701647c67251ded78 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 27 Jun 2024 14:28:48 +0200 Subject: [PATCH 13/33] temporarily remove `pydap` from CI (#9183) (the issue is that with `numpy>=2` `import pydap` succeeds, but `import pydap.lib` raises) --- ci/requirements/all-but-dask.yml | 2 +- ci/requirements/environment-windows.yml | 2 +- ci/requirements/environment.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index 2f47643cc87..119db282ad9 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -27,7 +27,7 @@ dependencies: - pandas - pint>=0.22 - pip - - pydap + # - pydap - pytest - pytest-cov - pytest-env diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 4cdddc676eb..2eedc9b0621 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -29,7 +29,7 @@ dependencies: # - pint>=0.22 - pip - pre-commit - - pydap + # - pydap - pytest - pytest-cov - pytest-env diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index f1a10bc040b..317e1fe5f41 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -35,7 +35,7 @@ dependencies: - pooch - pre-commit - pyarrow # pandas raises a deprecation warning without this, breaking doctests - - pydap + # - pydap - pytest - pytest-cov - pytest-env From f4183ec043de97273efdfdd4a33df2c3dc08ddff Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 27 Jun 2024 19:04:16 +0200 Subject: [PATCH 14/33] also pin `numpy` in the all-but-dask CI 
(#9184) --- ci/requirements/all-but-dask.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index 119db282ad9..abf6a88690a 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -22,7 +22,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy + - numpy<2 - packaging - pandas - pint>=0.22 From 42ed6d30e81dce5b9922ac82f76c5b3cd748b19e Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 28 Jun 2024 10:18:55 +0200 Subject: [PATCH 15/33] promote floating-point numeric datetimes to 64-bit before decoding (#9182) * promote floating-point dates to 64-bit while decoding * add a test to make sure we don't regress * whats-new entry --- doc/whats-new.rst | 2 ++ xarray/coding/times.py | 2 ++ xarray/tests/test_coding_times.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 97631b4c324..c58f73cb1fa 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,8 @@ Bug fixes ~~~~~~~~~ - Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`). By `Pontus Lurcock `_. +- Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`). + By `Justus Magin `_. Documentation diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 466e847e003..34d4f9a23ad 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -278,6 +278,8 @@ def _decode_datetime_with_pandas( # timedelta64 value, and therefore would raise an error in the lines above. if flat_num_dates.dtype.kind in "iu": flat_num_dates = flat_num_dates.astype(np.int64) + elif flat_num_dates.dtype.kind in "f": + flat_num_dates = flat_num_dates.astype(np.float64) # Cast input ordinals to integers of nanoseconds because pd.to_timedelta # works much faster when dealing with integers (GH 1399). 
diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 09221d66066..393f8400c46 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1182,6 +1182,22 @@ def test_decode_0size_datetime(use_cftime): np.testing.assert_equal(expected, actual) +def test_decode_float_datetime(): + num_dates = np.array([1867128, 1867134, 1867140], dtype="float32") + units = "hours since 1800-01-01" + calendar = "standard" + + expected = np.array( + ["2013-01-01T00:00:00", "2013-01-01T06:00:00", "2013-01-01T12:00:00"], + dtype="datetime64[ns]", + ) + + actual = decode_cf_datetime( + num_dates, units=units, calendar=calendar, use_cftime=False + ) + np.testing.assert_equal(actual, expected) + + @requires_cftime def test_scalar_unit() -> None: # test that a scalar units (often NaN when using to_netcdf) does not raise an error From caed27437cc695e6fc83475c24c9ae2268806f28 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 30 Jun 2024 16:03:46 +0200 Subject: [PATCH 16/33] `"source"` encoding for datasets opened from `fsspec` objects (#8923) * draft for setting `source` from pre-opened `fsspec` file objects * refactor to only import `fsspec` if we're actually going to check Could use `getattr(filename_or_obj, "path", filename_or_obj)` to avoid `isinstance` checks. * replace with a simple `getattr` on `"path"` * add a test * whats-new entry * open the file as a context manager --- doc/whats-new.rst | 2 ++ xarray/backends/api.py | 7 +++++-- xarray/tests/test_backends.py | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c58f73cb1fa..0174e16602f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,8 @@ New Features ~~~~~~~~~~~~ - Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`). By `Martin Raspaud `_. +- Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`). 
+ By `Justus Magin `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7054c62126e..521bdf65e6a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -382,8 +382,11 @@ def _dataset_from_backend_dataset( ds.set_close(backend_ds._close) # Ensure source filename always stored in dataset object - if "source" not in ds.encoding and isinstance(filename_or_obj, (str, os.PathLike)): - ds.encoding["source"] = _normalize_path(filename_or_obj) + if "source" not in ds.encoding: + path = getattr(filename_or_obj, "path", filename_or_obj) + + if isinstance(path, (str, os.PathLike)): + ds.encoding["source"] = _normalize_path(path) return ds diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 177700a5404..15485dc178a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5151,6 +5151,21 @@ def test_source_encoding_always_present_with_pathlib() -> None: assert ds.encoding["source"] == tmp +@requires_h5netcdf +@requires_fsspec +def test_source_encoding_always_present_with_fsspec() -> None: + import fsspec + + rnddata = np.random.randn(10) + original = Dataset({"foo": ("x", rnddata)}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + + fs = fsspec.filesystem("file") + with fs.open(tmp) as f, open_dataset(f) as ds: + assert ds.encoding["source"] == tmp + + def _assert_no_dates_out_of_range_warning(record): undesired_message = "dates out of range" for warning in record: From 3deee7bb535dba9a48ee590c7f5119a7f2d779be Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 30 Jun 2024 18:46:30 +0200 Subject: [PATCH 17/33] properly diff objects with arrays as attributes on variables (#9169) * move the attr comparison into a common function * check that we can actually diff objects with array attrs * whats-new entry * Add property test * Add more dtypes * Better test * Fix skip * Use simple attrs strategy --------- Co-authored-by: Deepak Cherian 
Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 2 ++ properties/test_properties.py | 17 +++++++++++++++++ xarray/core/formatting.py | 18 ++++++++++++------ xarray/testing/strategies.py | 6 +++++- xarray/tests/test_formatting.py | 30 ++++++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 properties/test_properties.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0174e16602f..f3ab5d46e1d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,6 +39,8 @@ Bug fixes ~~~~~~~~~ - Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`). By `Pontus Lurcock `_. +- Allow diffing objects with array attributes on variables (:issue:`9153`, :pull:`9169`). + By `Justus Magin `_. - Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`). By `Justus Magin `_. diff --git a/properties/test_properties.py b/properties/test_properties.py new file mode 100644 index 00000000000..fc0a1955539 --- /dev/null +++ b/properties/test_properties.py @@ -0,0 +1,17 @@ +import pytest + +pytest.importorskip("hypothesis") + +from hypothesis import given + +import xarray as xr +import xarray.testing.strategies as xrst + + +@given(attrs=xrst.simple_attrs) +def test_assert_identical(attrs): + v = xr.Variable(dims=(), data=0, attrs=attrs) + xr.testing.assert_identical(v, v.copy(deep=True)) + + ds = xr.Dataset(attrs=attrs) + xr.testing.assert_identical(ds, ds.copy(deep=True)) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index c15df34b5b1..5c4a3015843 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -765,6 +765,12 @@ def _diff_mapping_repr( a_indexes=None, b_indexes=None, ): + def compare_attr(a, b): + if is_duck_array(a) or is_duck_array(b): + return array_equiv(a, b) + else: + return a == b + def extra_items_repr(extra_keys, mapping, ab_side, kwargs): extra_repr = [ summarizer(k, mapping[k], col_width, **kwargs[k]) for k in 
extra_keys @@ -801,11 +807,7 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): is_variable = True except AttributeError: # compare attribute value - if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]): - compatible = array_equiv(a_mapping[k], b_mapping[k]) - else: - compatible = a_mapping[k] == b_mapping[k] - + compatible = compare_attr(a_mapping[k], b_mapping[k]) is_variable = False if not compatible: @@ -821,7 +823,11 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): attrs_to_print = set(a_attrs) ^ set(b_attrs) attrs_to_print.update( - {k for k in set(a_attrs) & set(b_attrs) if a_attrs[k] != b_attrs[k]} + { + k + for k in set(a_attrs) & set(b_attrs) + if not compare_attr(a_attrs[k], b_attrs[k]) + } ) for m in (a_mapping, b_mapping): attr_s = "\n".join( diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py index 449d0c793cc..085b70e518b 100644 --- a/xarray/testing/strategies.py +++ b/xarray/testing/strategies.py @@ -192,10 +192,14 @@ def dimension_sizes( max_side=2, max_dims=2, ), - dtype=npst.scalar_dtypes(), + dtype=npst.scalar_dtypes() + | npst.byte_string_dtypes() + | npst.unicode_string_dtypes(), ) _attr_values = st.none() | st.booleans() | _readable_strings | _small_arrays +simple_attrs = st.dictionaries(_attr_keys, _attr_values) + def attrs() -> st.SearchStrategy[Mapping[Hashable, Any]]: """ diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index d7a46eeaefc..6c49ab456f6 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -399,6 +399,36 @@ def test_diff_attrs_repr_with_array(self) -> None: actual = formatting.diff_attrs_repr(attrs_a, attrs_c, "equals") assert expected == actual + def test__diff_mapping_repr_array_attrs_on_variables(self) -> None: + a = { + "a": xr.DataArray( + dims="x", + data=np.array([1], dtype="int16"), + attrs={"b": np.array([1, 2], dtype="int8")}, + ) + } + b = { + "a": xr.DataArray( + dims="x", + data=np.array([1], 
dtype="int16"), + attrs={"b": np.array([2, 3], dtype="int8")}, + ) + } + actual = formatting.diff_data_vars_repr(a, b, compat="identical", col_width=8) + expected = dedent( + """\ + Differing data variables: + L a (x) int16 2B 1 + Differing variable attributes: + b: [1 2] + R a (x) int16 2B 1 + Differing variable attributes: + b: [2 3] + """.rstrip() + ) + + assert actual == expected + def test_diff_dataset_repr(self) -> None: ds_a = xr.Dataset( data_vars={ From fff82539c7b0f045c35ace332c4f6ecb365a0612 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 30 Jun 2024 23:13:15 +0200 Subject: [PATCH 18/33] Allow str in static typing of reindex, ffill etc. (#9194) * allow str in reindex * add whats-new --- doc/whats-new.rst | 3 ++- xarray/core/alignment.py | 6 +++--- xarray/core/dataarray.py | 8 ++++---- xarray/core/dataset.py | 8 ++++---- xarray/core/resample.py | 16 ++++++++++------ xarray/tests/test_groupby.py | 6 +++--- 6 files changed, 26 insertions(+), 21 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f3ab5d46e1d..ac849c7ec19 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -43,7 +43,8 @@ Bug fixes By `Justus Magin `_. - Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`). By `Justus Magin `_. - +- Fiy static typing of tolerance arguments by allowing `str` type (:issue:`8892`, :pull:`9194`). + By `Michael Niklas `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 13e3400d170..44fc7319170 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -137,7 +137,7 @@ def __init__( exclude_dims: str | Iterable[Hashable] = frozenset(), exclude_vars: Iterable[Hashable] = frozenset(), method: str | None = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, @@ -965,7 +965,7 @@ def reindex( obj: T_Alignable, indexers: Mapping[Any, Any], method: str | None = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, @@ -1004,7 +1004,7 @@ def reindex_like( obj: T_Alignable, other: Dataset | DataArray, method: str | None = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value: Any = dtypes.NA, ) -> T_Alignable: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d3390d26655..b67f8089eb2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1909,7 +1909,7 @@ def reindex_like( other: T_DataArrayOrSet, *, method: ReindexMethodOptions = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value=dtypes.NA, ) -> Self: @@ -1936,7 +1936,7 @@ def reindex_like( - backfill / bfill: propagate next valid index value backward - nearest: use nearest valid index value - tolerance : optional + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels for inexact matches. 
The values of the index at the matching locations must satisfy the equation ``abs(index[indexer] - target) <= tolerance``. @@ -2096,7 +2096,7 @@ def reindex( indexers: Mapping[Any, Any] | None = None, *, method: ReindexMethodOptions = None, - tolerance: float | Iterable[float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any, @@ -2126,7 +2126,7 @@ def reindex( - backfill / bfill: propagate next valid index value backward - nearest: use nearest valid index value - tolerance : float | Iterable[float] | None, default: None + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels for inexact matches. The values of the index at the matching locations must satisfy the equation ``abs(index[indexer] - target) <= tolerance``. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0b8be674675..50cfc7b0c29 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3499,7 +3499,7 @@ def reindex_like( self, other: T_Xarray, method: ReindexMethodOptions = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, ) -> Self: @@ -3526,7 +3526,7 @@ def reindex_like( - "backfill" / "bfill": propagate next valid index value backward - "nearest": use nearest valid index value - tolerance : optional + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels for inexact matches. The values of the index at the matching locations must satisfy the equation ``abs(index[indexer] - target) <= tolerance``. 
@@ -3569,7 +3569,7 @@ def reindex( self, indexers: Mapping[Any, Any] | None = None, method: ReindexMethodOptions = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | str | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, **indexers_kwargs: Any, @@ -3594,7 +3594,7 @@ def reindex( - "backfill" / "bfill": propagate next valid index value backward - "nearest": use nearest valid index value - tolerance : optional + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels for inexact matches. The values of the index at the matching locations must satisfy the equation ``abs(index[indexer] - target) <= tolerance``. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ceab0a891c9..ec86f2a283f 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -66,12 +66,12 @@ def _drop_coords(self) -> T_Xarray: obj = obj.drop_vars([k]) return obj - def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: + def pad(self, tolerance: float | Iterable[float] | str | None = None) -> T_Xarray: """Forward fill new values at up-sampled frequency. Parameters ---------- - tolerance : float | Iterable[float] | None, default: None + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels to limit the up-sampling method. Up-sampled data with indices that satisfy the equation @@ -91,12 +91,14 @@ def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: ffill = pad - def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: + def backfill( + self, tolerance: float | Iterable[float] | str | None = None + ) -> T_Xarray: """Backward fill new values at up-sampled frequency. 
Parameters ---------- - tolerance : float | Iterable[float] | None, default: None + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels to limit the up-sampling method. Up-sampled data with indices that satisfy the equation @@ -116,13 +118,15 @@ def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray bfill = backfill - def nearest(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: + def nearest( + self, tolerance: float | Iterable[float] | str | None = None + ) -> T_Xarray: """Take new values from nearest original coordinate to up-sampled frequency coordinates. Parameters ---------- - tolerance : float | Iterable[float] | None, default: None + tolerance : float | Iterable[float] | str | None, default: None Maximum distance between original and new labels to limit the up-sampling method. Up-sampled data with indices that satisfy the equation diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 47cda064143..f0a0fd14d9d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2037,17 +2037,17 @@ def test_upsample_tolerance(self) -> None: array = DataArray(np.arange(2), [("time", times)]) # Forward fill - actual = array.resample(time="6h").ffill(tolerance="12h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. + actual = array.resample(time="6h").ffill(tolerance="12h") expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Backward fill - actual = array.resample(time="6h").bfill(tolerance="12h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. 
+ actual = array.resample(time="6h").bfill(tolerance="12h") expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Nearest - actual = array.resample(time="6h").nearest(tolerance="6h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. + actual = array.resample(time="6h").nearest(tolerance="6h") expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)]) assert_identical(expected, actual) From 24ab84cb0dbc2706677bab2e3765050f1d4f9646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dieter=20Werthm=C3=BCller?= Date: Mon, 1 Jul 2024 16:22:59 +0200 Subject: [PATCH 19/33] Fix dark-theme in `html[data-theme=dark]`-tags (#9200) * Fix dark-theme in html tag * Add to release notes --- doc/whats-new.rst | 2 ++ xarray/static/css/style.css | 1 + 2 files changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ac849c7ec19..685cdf28194 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,6 +45,8 @@ Bug fixes By `Justus Magin `_. - Fiy static typing of tolerance arguments by allowing `str` type (:issue:`8892`, :pull:`9194`). By `Michael Niklas `_. +- Dark themes are now properly detected for ``html[data-theme=dark]``-tags (:pull:`9200`). + By `Dieter Werthmüller `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/static/css/style.css b/xarray/static/css/style.css index e0a51312b10..dbe61e311c1 100644 --- a/xarray/static/css/style.css +++ b/xarray/static/css/style.css @@ -14,6 +14,7 @@ } html[theme=dark], +html[data-theme=dark], body[data-theme=dark], body.vscode-dark { --xr-font-color0: rgba(255, 255, 255, 1); From 90e44867f7270e7de5e31b8713224039e39d9704 Mon Sep 17 00:00:00 2001 From: Alfonso Ladino Date: Mon, 1 Jul 2024 09:33:15 -0500 Subject: [PATCH 20/33] Add open_datatree benchmark (#9158) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * open_datatree performance improvement on NetCDF files * fixing issue with forward slashes * fixing issue with pytest * open datatree in zarr format improvement * fixing incompatibility in returned object * passing group parameter to opendatatree method and reducing duplicated code * passing group parameter to opendatatree method - NetCDF * Update xarray/backends/netCDF4_.py renaming variables Co-authored-by: Tom Nicholas * renaming variables * renaming variables * renaming group_store variable * removing _open_datatree_netcdf function not used anymore in open_datatree implementations * improving performance of open_datatree method * renaming 'i' variable within list comprehension in open_store method for zarr datatree * using the default generator instead of loading zarr groups in memory * fixing issue with group path to avoid using group[1:] notation. Adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree for h5 files. 
Finally, separating positional from keyword args * fixing issue with group path to avoid using group[1:] notation and adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree method for netCDF files * fixing issue with group path to avoid using group[1:] notation and adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree method for zarr files * adding 'mode' parameter to open_datatree method * adding 'mode' parameter to H5NetCDFStore.open method * adding new entry related to open_datatree performance improvement * adding new entry related to open_datatree performance improvement * Getting rid of unnecessary parameters for 'open_datatree' method for netCDF4 and Hdf5 backends * passing parent argument into _iter_zarr_groups instead of group[1:] for creating group path * adding benchmark test for opening a deeply nested data tree. This include a new class named 'IONestedDataTree' and another class for benchmarck named 'IOReadDataTreeNetCDF4' * Update doc/whats-new.rst --------- Co-authored-by: Tom Nicholas Co-authored-by: Kai Mühlbauer Co-authored-by: Deepak Cherian --- asv_bench/benchmarks/dataset_io.py | 113 ++++++++++++++++++++++++++++- xarray/backends/zarr.py | 2 +- 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index dcc2de0473b..0956be67dad 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -7,6 +7,8 @@ import pandas as pd import xarray as xr +from xarray.backends.api import open_datatree +from xarray.core.datatree import DataTree from . 
import _skip_slow, parameterized, randint, randn, requires_dask @@ -16,7 +18,6 @@ except ImportError: pass - os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" _ENGINES = tuple(xr.backends.list_engines().keys() - {"store"}) @@ -469,6 +470,116 @@ def create_delayed_write(): return ds.to_netcdf("file.nc", engine="netcdf4", compute=False) +class IONestedDataTree: + """ + A few examples that benchmark reading/writing a heavily nested netCDF datatree with + xarray + """ + + timeout = 300.0 + repeat = 1 + number = 5 + + def make_datatree(self, nchildren=10): + # multiple Dataset + self.ds = xr.Dataset() + self.nt = 1000 + self.nx = 90 + self.ny = 45 + self.nchildren = nchildren + + self.block_chunks = { + "time": self.nt / 4, + "lon": self.nx / 3, + "lat": self.ny / 3, + } + + self.time_chunks = {"time": int(self.nt / 36)} + + times = pd.date_range("1970-01-01", periods=self.nt, freq="D") + lons = xr.DataArray( + np.linspace(0, 360, self.nx), + dims=("lon",), + attrs={"units": "degrees east", "long_name": "longitude"}, + ) + lats = xr.DataArray( + np.linspace(-90, 90, self.ny), + dims=("lat",), + attrs={"units": "degrees north", "long_name": "latitude"}, + ) + self.ds["foo"] = xr.DataArray( + randn((self.nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="foo", + attrs={"units": "foo units", "description": "a description"}, + ) + self.ds["bar"] = xr.DataArray( + randn((self.nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="bar", + attrs={"units": "bar units", "description": "a description"}, + ) + self.ds["baz"] = xr.DataArray( + randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32), + coords={"lon": lons, "lat": lats}, + dims=("lon", "lat"), + name="baz", + attrs={"units": "baz units", "description": "a description"}, + ) + + self.ds.attrs = {"history": "created for xarray benchmarking"} + + self.oinds = { + "time": 
randint(0, self.nt, 120), + "lon": randint(0, self.nx, 20), + "lat": randint(0, self.ny, 10), + } + self.vinds = { + "time": xr.DataArray(randint(0, self.nt, 120), dims="x"), + "lon": xr.DataArray(randint(0, self.nx, 120), dims="x"), + "lat": slice(3, 20), + } + root = {f"group_{group}": self.ds for group in range(self.nchildren)} + nested_tree1 = { + f"group_{group}/subgroup_1": xr.Dataset() for group in range(self.nchildren) + } + nested_tree2 = { + f"group_{group}/subgroup_2": xr.DataArray(np.arange(1, 10)).to_dataset( + name="a" + ) + for group in range(self.nchildren) + } + nested_tree3 = { + f"group_{group}/subgroup_2/sub-subgroup_1": self.ds + for group in range(self.nchildren) + } + dtree = root | nested_tree1 | nested_tree2 | nested_tree3 + self.dtree = DataTree.from_dict(dtree) + + +class IOReadDataTreeNetCDF4(IONestedDataTree): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_datatree() + self.format = "NETCDF4" + self.filepath = "datatree.nc4.nc" + dtree = self.dtree + dtree.to_netcdf(filepath=self.filepath) + + def time_load_datatree_netcdf4(self): + open_datatree(self.filepath, engine="netcdf4").load() + + def time_open_datatree_netcdf4(self): + open_datatree(self.filepath, engine="netcdf4") + + class IOWriteNetCDFDask: timeout = 60 repeat = 1 diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 9796fcbf9e2..85a1a6e214c 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -446,7 +446,7 @@ def open_store( stacklevel=stacklevel, zarr_version=zarr_version, ) - group_paths = [str(group / node[1:]) for node in _iter_zarr_groups(zarr_group)] + group_paths = [node for node in _iter_zarr_groups(zarr_group, parent=group)] return { group: cls( zarr_group.get(group), From 6c2d8c3389afe049ccbfd1393e9a81dd5c759f78 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Mon, 1 Jul 2024 16:47:10 +0200 Subject: [PATCH 
21/33] use a `composite` strategy to generate the dataframe with a tz-aware datetime column (#9174) * use a `composite` to generate the dataframe with a tz-aware dt column * remove the `xfail` --- properties/test_pandas_roundtrip.py | 39 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 0249aa59d5b..9e0d4640171 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -9,7 +9,6 @@ import pytest import xarray as xr -from xarray.tests import has_pandas_3 pytest.importorskip("hypothesis") import hypothesis.extra.numpy as npst # isort:skip @@ -25,22 +24,34 @@ numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt)) + +@st.composite +def dataframe_strategy(draw): + tz = draw(st.timezones()) + dtype = pd.DatetimeTZDtype(unit="ns", tz=tz) + + datetimes = st.datetimes( + min_value=pd.Timestamp("1677-09-21T00:12:43.145224193"), + max_value=pd.Timestamp("2262-04-11T23:47:16.854775807"), + timezones=st.just(tz), + ) + + df = pdst.data_frames( + [ + pdst.column("datetime_col", elements=datetimes), + pdst.column("other_col", elements=st.integers()), + ], + index=pdst.range_indexes(min_size=1, max_size=10), + ) + return draw(df).astype({"datetime_col": dtype}) + + an_array = npst.arrays( dtype=numeric_dtypes, shape=npst.array_shapes(max_dims=2), # can only convert 1D/2D to pandas ) -datetime_with_tz_strategy = st.datetimes(timezones=st.timezones()) -dataframe_strategy = pdst.data_frames( - [ - pdst.column("datetime_col", elements=datetime_with_tz_strategy), - pdst.column("other_col", elements=st.integers()), - ], - index=pdst.range_indexes(min_size=1, max_size=10), -) - - @st.composite def datasets_1d_vars(draw) -> xr.Dataset: """Generate datasets with only 1D variables @@ -111,11 +122,7 @@ def test_roundtrip_pandas_dataframe(df) -> None: xr.testing.assert_identical(arr, roundtripped.to_xarray()) 
-@pytest.mark.skipif( - has_pandas_3, - reason="fails to roundtrip on pandas 3 (see https://github.com/pydata/xarray/issues/9098)", -) -@given(df=dataframe_strategy) +@given(df=dataframe_strategy()) def test_roundtrip_pandas_dataframe_datetime(df) -> None: # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'. df.index.name = "rows" From a86c3ff446c68a90c2a3ea9c961d41635b691b91 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 2 Jul 2024 17:35:04 -0700 Subject: [PATCH 22/33] Hierarchical coordinates in DataTree (#9063) * Inheritance of data coordinates * Simplify __init__ * Include path name in alignment errors * Fix some mypy errors * mypy fix * simplify DataTree data model * Add to_dataset(local=True) * Fix mypy failure in tests * Fix to_zarr for inherited coords * Fix to_netcdf for hierarchical coords * Add ChainSet * Revise internal data model; remove ChainSet * add another way to construct inherited indexes * Finish refactoring error message * include inherited dimensions in HTML repr, too * Construct ChainMap objects on demand.
* slightly better error message with mis-aligned data trees * mypy fix * use float64 instead of float32 for windows * clean-up per review * Add note about inheritance to .ds docs --- xarray/core/datatree.py | 443 +++++++++++++------------ xarray/core/datatree_io.py | 4 +- xarray/core/formatting.py | 23 +- xarray/core/formatting_html.py | 2 +- xarray/core/treenode.py | 16 +- xarray/tests/test_backends_datatree.py | 42 ++- xarray/tests/test_datatree.py | 231 ++++++++++++- xarray/tests/test_utils.py | 6 +- 8 files changed, 527 insertions(+), 240 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index c923ca2eb87..38f8f8cd495 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1,7 +1,8 @@ from __future__ import annotations -import copy import itertools +import textwrap +from collections import ChainMap from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape from typing import ( @@ -16,6 +17,7 @@ ) from xarray.core import utils +from xarray.core.alignment import align from xarray.core.common import TreeAttrAccessMixin from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataarray import DataArray @@ -31,7 +33,7 @@ MappedDataWithCoords, ) from xarray.core.datatree_render import RenderDataTree -from xarray.core.formatting import datatree_repr +from xarray.core.formatting import datatree_repr, dims_and_coords_repr from xarray.core.formatting_html import ( datatree_repr as datatree_repr_html, ) @@ -79,11 +81,24 @@ T_Path = Union[str, NodePath] +def _collect_data_and_coord_variables( + data: Dataset, +) -> tuple[dict[Hashable, Variable], dict[Hashable, Variable]]: + data_variables = {} + coord_variables = {} + for k, v in data.variables.items(): + if k in data._coord_names: + coord_variables[k] = v + else: + data_variables[k] = v + return data_variables, coord_variables + + def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: if isinstance(data, 
DataArray): ds = data.to_dataset() elif isinstance(data, Dataset): - ds = data + ds = data.copy(deep=False) elif data is None: ds = Dataset() else: @@ -93,14 +108,57 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: return ds -def _check_for_name_collisions( - children: Iterable[str], variables: Iterable[Hashable] +def _join_path(root: str, name: str) -> str: + return str(NodePath(root) / name) + + +def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: + return Dataset._construct_direct( + variables=parent._variables | ds._variables, + coord_names=parent._coord_names | ds._coord_names, + dims=parent._dims | ds._dims, + attrs=ds._attrs, + indexes=parent._indexes | ds._indexes, + encoding=ds._encoding, + close=ds._close, + ) + + +def _without_header(text: str) -> str: + return "\n".join(text.split("\n")[1:]) + + +def _indented(text: str) -> str: + return textwrap.indent(text, prefix=" ") + + +def _check_alignment( + path: str, + node_ds: Dataset, + parent_ds: Dataset | None, + children: Mapping[str, DataTree], ) -> None: - colliding_names = set(children).intersection(set(variables)) - if colliding_names: - raise KeyError( - f"Some names would collide between variables and children: {list(colliding_names)}" - ) + if parent_ds is not None: + try: + align(node_ds, parent_ds, join="exact") + except ValueError as e: + node_repr = _indented(_without_header(repr(node_ds))) + parent_repr = _indented(dims_and_coords_repr(parent_ds)) + raise ValueError( + f"group {path!r} is not aligned with its parents:\n" + f"Group:\n{node_repr}\nFrom parents:\n{parent_repr}" + ) from e + + if children: + if parent_ds is not None: + base_ds = _inherited_dataset(node_ds, parent_ds) + else: + base_ds = node_ds + + for child_name, child in children.items(): + child_path = str(NodePath(path) / child_name) + child_ds = child.to_dataset(inherited=False) + _check_alignment(child_path, child_ds, base_ds, child.children) class DatasetView(Dataset): @@ -118,7 +176,7 @@ 
class DatasetView(Dataset): __slots__ = ( "_attrs", - "_cache", + "_cache", # used by _CachedAccessor "_coord_names", "_dims", "_encoding", @@ -136,21 +194,27 @@ def __init__( raise AttributeError("DatasetView objects are not to be initialized directly") @classmethod - def _from_node( + def _constructor( cls, - wrapping_node: DataTree, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int], + attrs: dict | None, + indexes: dict[Any, Index], + encoding: dict | None, + close: Callable[[], None] | None, ) -> DatasetView: - """Constructor, using dataset attributes from wrapping node""" - + """Private constructor, from Dataset attributes.""" + # We override Dataset._construct_direct below, so we need a new + # constructor for creating DatasetView objects. obj: DatasetView = object.__new__(cls) - obj._variables = wrapping_node._variables - obj._coord_names = wrapping_node._coord_names - obj._dims = wrapping_node._dims - obj._indexes = wrapping_node._indexes - obj._attrs = wrapping_node._attrs - obj._close = wrapping_node._close - obj._encoding = wrapping_node._encoding - + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding return obj def __setitem__(self, key, val) -> None: @@ -337,27 +401,27 @@ class DataTree( _name: str | None _parent: DataTree | None _children: dict[str, DataTree] + _cache: dict[str, Any] # used by _CachedAccessor + _data_variables: dict[Hashable, Variable] + _node_coord_variables: dict[Hashable, Variable] + _node_dims: dict[Hashable, int] + _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None - _cache: dict[str, Any] - _coord_names: set[Hashable] - _dims: dict[Hashable, int] _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] - _variables: dict[Hashable, Variable] __slots__ = ( "_name", "_parent", "_children", + "_cache", # 
used by _CachedAccessor + "_data_variables", + "_node_coord_variables", + "_node_dims", + "_node_indexes", "_attrs", - "_cache", - "_coord_names", - "_dims", "_encoding", "_close", - "_indexes", - "_variables", ) def __init__( @@ -370,14 +434,15 @@ def __init__( """ Create a single node of a DataTree. - The node may optionally contain data in the form of data and coordinate variables, stored in the same way as - data is stored in an xarray.Dataset. + The node may optionally contain data in the form of data and coordinate + variables, stored in the same way as data is stored in an + xarray.Dataset. Parameters ---------- data : Dataset, DataArray, or None, optional - Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. - Default is None. + Data to store under the .ds attribute of this node. DataArrays will + be promoted to Datasets. Default is None. parent : DataTree, optional Parent node to this node. Default is None. children : Mapping[str, DataTree], optional @@ -393,30 +458,48 @@ def __init__( -------- DataTree.from_dict """ - - # validate input if children is None: children = {} - ds = _coerce_to_dataset(data) - _check_for_name_collisions(children, ds.variables) super().__init__(name=name) + self._set_node_data(_coerce_to_dataset(data)) + self.parent = parent + self.children = children - # set data attributes - self._replace( - inplace=True, - variables=ds._variables, - coord_names=ds._coord_names, - dims=ds._dims, - indexes=ds._indexes, - attrs=ds._attrs, - encoding=ds._encoding, - ) + def _set_node_data(self, ds: Dataset): + data_vars, coord_vars = _collect_data_and_coord_variables(ds) + self._data_variables = data_vars + self._node_coord_variables = coord_vars + self._node_dims = ds._dims + self._node_indexes = ds._indexes + self._encoding = ds._encoding + self._attrs = ds._attrs self._close = ds._close - # set tree attributes (must happen after variables set to avoid initialization errors) - self.children = children - 
self.parent = parent + def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._pre_attach(parent, name) + if name in parent.ds.variables: + raise KeyError( + f"parent {parent.name} already contains a variable named {name}" + ) + path = str(NodePath(parent.path) / name) + node_ds = self.to_dataset(inherited=False) + parent_ds = parent._to_dataset_view(rebuild_dims=False) + _check_alignment(path, node_ds, parent_ds, self.children) + + @property + def _coord_variables(self) -> ChainMap[Hashable, Variable]: + return ChainMap( + self._node_coord_variables, *(p._node_coord_variables for p in self.parents) + ) + + @property + def _dims(self) -> ChainMap[Hashable, int]: + return ChainMap(self._node_dims, *(p._node_dims for p in self.parents)) + + @property + def _indexes(self) -> ChainMap[Hashable, Index]: + return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) @property def parent(self: DataTree) -> DataTree | None: @@ -429,71 +512,87 @@ def parent(self: DataTree, new_parent: DataTree) -> None: raise ValueError("Cannot set an unnamed node as a child of another node") self._set_parent(new_parent, self.name) + def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: + variables = dict(self._data_variables) + variables |= self._coord_variables + if rebuild_dims: + dims = calculate_dimensions(variables) + else: + # Note: rebuild_dims=False can create technically invalid Dataset + # objects because it may not contain all dimensions on its direct + # member variables, e.g., consider: + # tree = DataTree.from_dict( + # { + # "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2 + # "/b/c": xr.Dataset({"d": (("x",), [3])}), # x has size1 + # } + # ) + # However, they are fine for internal use cases, for align() or + # building a repr(). 
+ dims = dict(self._dims) + return DatasetView._constructor( + variables=variables, + coord_names=set(self._coord_variables), + dims=dims, + attrs=self._attrs, + indexes=dict(self._indexes), + encoding=self._encoding, + close=None, + ) + @property def ds(self) -> DatasetView: """ An immutable Dataset-like view onto the data in this node. - For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead. + Includes inherited coordinates and indexes from parent nodes. + + For a mutable Dataset containing the same data as in this node, use + `.to_dataset()` instead. See Also -------- DataTree.to_dataset """ - return DatasetView._from_node(self) + return self._to_dataset_view(rebuild_dims=True) @ds.setter def ds(self, data: Dataset | DataArray | None = None) -> None: - # Known mypy issue for setters with different type to property: - # https://github.com/python/mypy/issues/3004 ds = _coerce_to_dataset(data) + self._replace_node(ds) - _check_for_name_collisions(self.children, ds.variables) - - self._replace( - inplace=True, - variables=ds._variables, - coord_names=ds._coord_names, - dims=ds._dims, - indexes=ds._indexes, - attrs=ds._attrs, - encoding=ds._encoding, - ) - self._close = ds._close - - def _pre_attach(self: DataTree, parent: DataTree) -> None: - """ - Method which superclass calls before setting parent, here used to prevent having two - children with duplicate names (or a data variable with the same name as a child). - """ - super()._pre_attach(parent) - if self.name in list(parent.ds.variables): - raise KeyError( - f"parent {parent.name} already contains a data variable named {self.name}" - ) - - def to_dataset(self) -> Dataset: + def to_dataset(self, inherited: bool = True) -> Dataset: """ Return the data in this node as a new xarray.Dataset object. + Parameters + ---------- + inherited : bool, optional + If False, only include coordinates and indexes defined at the level + of this DataTree node, excluding inherited coordinates. 
+ See Also -------- DataTree.ds """ + coord_vars = self._coord_variables if inherited else self._node_coord_variables + variables = dict(self._data_variables) + variables |= coord_vars + dims = calculate_dimensions(variables) if inherited else dict(self._node_dims) return Dataset._construct_direct( - self._variables, - self._coord_names, - self._dims, - self._attrs, - self._indexes, - self._encoding, + variables, + set(coord_vars), + dims, + None if self._attrs is None else dict(self._attrs), + dict(self._indexes if inherited else self._node_indexes), + None if self._encoding is None else dict(self._encoding), self._close, ) @property - def has_data(self): - """Whether or not there are any data variables in this node.""" - return len(self._variables) > 0 + def has_data(self) -> bool: + """Whether or not there are any variables in this node.""" + return bool(self._data_variables or self._node_coord_variables) @property def has_attrs(self) -> bool: @@ -518,7 +617,7 @@ def variables(self) -> Mapping[Hashable, Variable]: Dataset invariants. It contains all variable objects constituting this DataTree node, including both data variables and coordinates. 
""" - return Frozen(self._variables) + return Frozen(self._data_variables | self._coord_variables) @property def attrs(self) -> dict[Hashable, Any]: @@ -579,7 +678,7 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: def _item_sources(self) -> Iterable[Mapping[Any, Any]]: """Places to look-up items for key-completion""" yield self.data_vars - yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords) # virtual coordinates yield HybridMappingProxy(keys=self.dims, mapping=self) @@ -621,10 +720,10 @@ def __contains__(self, key: object) -> bool: return key in self.variables or key in self.children def __bool__(self) -> bool: - return bool(self.ds.data_vars) or bool(self.children) + return bool(self._data_variables) or bool(self._children) def __iter__(self) -> Iterator[Hashable]: - return itertools.chain(self.ds.data_vars, self.children) + return itertools.chain(self._data_variables, self._children) def __array__(self, dtype=None, copy=None): raise TypeError( @@ -646,122 +745,32 @@ def _repr_html_(self): return f"
{escape(repr(self))}
" return datatree_repr_html(self) - @classmethod - def _construct_direct( - cls, - variables: dict[Any, Variable], - coord_names: set[Hashable], - dims: dict[Any, int] | None = None, - attrs: dict | None = None, - indexes: dict[Any, Index] | None = None, - encoding: dict | None = None, - name: str | None = None, - parent: DataTree | None = None, - children: dict[str, DataTree] | None = None, - close: Callable[[], None] | None = None, - ) -> DataTree: - """Shortcut around __init__ for internal use when we want to skip costly validation.""" + def _replace_node( + self: DataTree, + data: Dataset | Default = _default, + children: dict[str, DataTree] | Default = _default, + ) -> None: - # data attributes - if dims is None: - dims = calculate_dimensions(variables) - if indexes is None: - indexes = {} - if children is None: - children = dict() + ds = self.to_dataset(inherited=False) if data is _default else data - obj: DataTree = object.__new__(cls) - obj._variables = variables - obj._coord_names = coord_names - obj._dims = dims - obj._indexes = indexes - obj._attrs = attrs - obj._close = close - obj._encoding = encoding - - # tree attributes - obj._name = name - obj._children = children - obj._parent = parent + if children is _default: + children = self._children - return obj + for child_name in children: + if child_name in ds.variables: + raise ValueError(f"node already contains a variable named {child_name}") - def _replace( - self: DataTree, - variables: dict[Hashable, Variable] | None = None, - coord_names: set[Hashable] | None = None, - dims: dict[Any, int] | None = None, - attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None = None, - encoding: dict | None | Default = _default, - name: str | None | Default = _default, - parent: DataTree | None | Default = _default, - children: dict[str, DataTree] | None = None, - inplace: bool = False, - ) -> DataTree: - """ - Fastpath constructor for internal use. 
+ parent_ds = ( + self.parent._to_dataset_view(rebuild_dims=False) + if self.parent is not None + else None + ) + _check_alignment(self.path, ds, parent_ds, children) - Returns an object with optionally replaced attributes. + if data is not _default: + self._set_node_data(ds) - Explicitly passed arguments are *not* copied when placed on the new - datatree. It is up to the caller to ensure that they have the right type - and are not used elsewhere. - """ - # TODO Adding new children inplace using this method will cause bugs. - # You will end up with an inconsistency between the name of the child node and the key the child is stored under. - # Use ._set() instead for now - if inplace: - if variables is not None: - self._variables = variables - if coord_names is not None: - self._coord_names = coord_names - if dims is not None: - self._dims = dims - if attrs is not _default: - self._attrs = attrs - if indexes is not None: - self._indexes = indexes - if encoding is not _default: - self._encoding = encoding - if name is not _default: - self._name = name - if parent is not _default: - self._parent = parent - if children is not None: - self._children = children - obj = self - else: - if variables is None: - variables = self._variables.copy() - if coord_names is None: - coord_names = self._coord_names.copy() - if dims is None: - dims = self._dims.copy() - if attrs is _default: - attrs = copy.copy(self._attrs) - if indexes is None: - indexes = self._indexes.copy() - if encoding is _default: - encoding = copy.copy(self._encoding) - if name is _default: - name = self._name # no need to copy str objects or None - if parent is _default: - parent = copy.copy(self._parent) - if children is _default: - children = copy.copy(self._children) - obj = self._construct_direct( - variables, - coord_names, - dims, - attrs, - indexes, - encoding, - name, - parent, - children, - ) - return obj + self._children = children def copy( self: DataTree, @@ -813,9 +822,8 @@ def _copy_node( deep: 
bool = False, ) -> DataTree: """Copy just one node of a tree""" - new_node: DataTree = DataTree() - new_node.name = self.name - new_node.ds = self.to_dataset().copy(deep=deep) # type: ignore[assignment] + data = self.ds.copy(deep=deep) + new_node: DataTree = DataTree(data, name=self.name) return new_node def __copy__(self: DataTree) -> DataTree: @@ -963,11 +971,12 @@ def update( raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + data = Dataset._construct_direct(**vars_merge_result._asdict()) + # TODO are there any subtleties with preserving order of children like this? merged_children = {**self.children, **new_children} - self._replace( - inplace=True, children=merged_children, **vars_merge_result._asdict() - ) + + self._replace_node(data, children=merged_children) def assign( self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any @@ -1042,10 +1051,12 @@ def drop_nodes( if extra: raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") + result = self.copy() children_to_keep = { - name: child for name, child in self.children.items() if name not in names + name: child for name, child in result.children.items() if name not in names } - return self._replace(children=children_to_keep) + result._replace_node(children=children_to_keep) + return result @classmethod def from_dict( @@ -1137,7 +1148,9 @@ def indexes(self) -> Indexes[pd.Index]: @property def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" - return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + return Indexes( + self._indexes, {k: self._coord_variables[k] for k in self._indexes} + ) @property def coords(self) -> DatasetCoordinates: diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index 1473e624d9e..36665a0d153 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -85,7 
+85,7 @@ def _datatree_to_netcdf( unlimited_dims = {} for node in dt.subtree: - ds = node.ds + ds = node.to_dataset(inherited=False) group_path = node.path if ds is None: _create_empty_netcdf_group(filepath, group_path, mode, engine) @@ -151,7 +151,7 @@ def _datatree_to_zarr( ) for node in dt.subtree: - ds = node.ds + ds = node.to_dataset(inherited=False) group_path = node.path if ds is None: _create_empty_zarr_group(store, group_path, mode) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 5c4a3015843..6dca4eba8e8 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -748,6 +748,27 @@ def dataset_repr(ds): return "\n".join(summary) +def dims_and_coords_repr(ds) -> str: + """Partial Dataset repr for use inside DataTree inheritance errors.""" + summary = [] + + col_width = _calculate_col_width(ds.coords) + max_rows = OPTIONS["display_max_rows"] + + dims_start = pretty_print("Dimensions:", col_width) + dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + summary.append(f"{dims_start}({dims_values})") + + if ds.coords: + summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows)) + + unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + return "\n".join(summary) + + def diff_dim_summary(a, b): if a.sizes != b.sizes: return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})" @@ -1030,7 +1051,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): def _single_node_repr(node: DataTree) -> str: """Information about this node, not including its relationships to other nodes.""" if node.has_data or node.has_attrs: - ds_info = "\n" + repr(node.ds) + ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False)) else: ds_info = "" return f"Group: {node.path}{ds_info}" diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 
9bf5befbe3f..24b290031eb 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: def datatree_node_repr(group_title: str, dt: DataTree) -> str: header_components = [f"
{escape(group_title)}
"] - ds = dt.ds + ds = dt._to_dataset_view(rebuild_dims=False) sections = [ children_section(dt.children), diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 6f51e1ffa38..77e7ed23a51 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -138,14 +138,14 @@ def _attach(self, parent: Tree | None, child_name: str | None = None) -> None: "To directly set parent, child needs a name, but child is unnamed" ) - self._pre_attach(parent) + self._pre_attach(parent, child_name) parentchildren = parent._children assert not any( child is self for child in parentchildren ), "Tree is corrupt." parentchildren[child_name] = self self._parent = parent - self._post_attach(parent) + self._post_attach(parent, child_name) else: self._parent = None @@ -415,11 +415,11 @@ def _post_detach(self: Tree, parent: Tree) -> None: """Method call after detaching from `parent`.""" pass - def _pre_attach(self: Tree, parent: Tree) -> None: + def _pre_attach(self: Tree, parent: Tree, name: str) -> None: """Method call before attaching to `parent`.""" pass - def _post_attach(self: Tree, parent: Tree) -> None: + def _post_attach(self: Tree, parent: Tree, name: str) -> None: """Method call after attaching to `parent`.""" pass @@ -567,6 +567,9 @@ def same_tree(self, other: Tree) -> bool: return self.root is other.root +AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode") + + class NamedNode(TreeNode, Generic[Tree]): """ A TreeNode which knows its own name. 
@@ -606,10 +609,9 @@ def __repr__(self, level=0): def __str__(self) -> str: return f"NamedNode('{self.name}')" if self.name else "NamedNode()" - def _post_attach(self: NamedNode, parent: NamedNode) -> None: + def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" - key = next(k for k, v in parent.children.items() if v is self) - self.name = key + self.name = name @property def path(self) -> str: diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 4e819eec0b5..b4c4f481359 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import pytest +import xarray as xr from xarray.backends.api import open_datatree +from xarray.core.datatree import DataTree from xarray.testing import assert_equal from xarray.tests import ( requires_h5netcdf, @@ -13,11 +15,11 @@ ) if TYPE_CHECKING: - from xarray.backends.api import T_NetcdfEngine + from xarray.core.datatree_io import T_DataTreeNetcdfEngine class DatatreeIOBase: - engine: T_NetcdfEngine | None = None + engine: T_DataTreeNetcdfEngine | None = None def test_to_netcdf(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" @@ -27,6 +29,21 @@ def test_to_netcdf(self, tmpdir, simple_datatree): roundtrip_dt = open_datatree(filepath, engine=self.engine) assert_equal(original_dt, roundtrip_dt) + def test_to_netcdf_inherited_coords(self, tmpdir): + filepath = tmpdir / "test.nc" + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}), + "/sub": xr.Dataset({"b": (("x",), [5, 6])}), + } + ) + original_dt.to_netcdf(filepath, engine=self.engine) + + roundtrip_dt = open_datatree(filepath, engine=self.engine) + assert_equal(original_dt, roundtrip_dt) + subtree = 
cast(DataTree, roundtrip_dt["/sub"]) + assert "x" not in subtree.to_dataset(inherited=False).coords + def test_netcdf_encoding(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" original_dt = simple_datatree @@ -48,12 +65,12 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): @requires_netCDF4 class TestNetCDF4DatatreeIO(DatatreeIOBase): - engine: T_NetcdfEngine | None = "netcdf4" + engine: T_DataTreeNetcdfEngine | None = "netcdf4" @requires_h5netcdf class TestH5NetCDFDatatreeIO(DatatreeIOBase): - engine: T_NetcdfEngine | None = "h5netcdf" + engine: T_DataTreeNetcdfEngine | None = "h5netcdf" @requires_zarr @@ -119,3 +136,18 @@ def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree): # with default settings, to_zarr should not overwrite an existing dir with pytest.raises(zarr.errors.ContainsGroupError): simple_datatree.to_zarr(tmpdir) + + def test_to_zarr_inherited_coords(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}), + "/sub": xr.Dataset({"b": (("x",), [5, 6])}), + } + ) + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath) + + roundtrip_dt = open_datatree(filepath, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + subtree = cast(DataTree, roundtrip_dt["/sub"]) + assert "x" not in subtree.to_dataset(inherited=False).coords diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index b0dc2accd3e..f2b58fa2489 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1,3 +1,5 @@ +import re +import typing from copy import copy, deepcopy from textwrap import dedent @@ -149,22 +151,37 @@ def test_is_hollow(self): assert not eve.is_hollow +class TestToDataset: + def test_to_dataset(self): + base = xr.Dataset(coords={"a": 1}) + sub = xr.Dataset(coords={"b": 2}) + tree = DataTree.from_dict({"/": base, "/sub": sub}) + subtree = typing.cast(DataTree, tree["sub"]) + + 
assert_identical(tree.to_dataset(inherited=False), base) + assert_identical(subtree.to_dataset(inherited=False), sub) + + sub_and_base = xr.Dataset(coords={"a": 1, "b": 2}) + assert_identical(tree.to_dataset(inherited=True), base) + assert_identical(subtree.to_dataset(inherited=True), sub_and_base) + + class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) - with pytest.raises(KeyError, match="already contains a data variable named a"): + with pytest.raises(KeyError, match="already contains a variable named a"): DataTree(name="a", data=None, parent=dt) def test_assign_when_already_child_with_variables_name(self): dt: DataTree = DataTree(data=None) DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="names would collide"): + with pytest.raises(ValueError, match="node already contains a variable"): dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] dt.ds = xr.Dataset() # type: ignore[assignment] new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) - with pytest.raises(KeyError, match="names would collide"): + with pytest.raises(ValueError, match="node already contains a variable"): dt.ds = new_ds # type: ignore[assignment] @@ -561,7 +578,9 @@ def test_methods(self): def test_arithmetic(self, create_test_datatree): dt = create_test_datatree() - expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)[ + "set1" + ].to_dataset() result = 10.0 * dt["set1"].ds assert result.identical(expected) @@ -648,13 +667,18 @@ def test_repr(self): │ Data variables: │ e (x) float64 16B 1.0 2.0 └── Group: /b - │ Dimensions: (y: 1) + │ Dimensions: (x: 2, y: 1) + │ Coordinates: + │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: () + Dimensions: (x: 2, y: 1) + Coordinates: + * 
x (x) float64 16B 2.0 3.0 + Dimensions without coordinates: y Data variables: g float64 8B 4.0 """ @@ -666,13 +690,18 @@ def test_repr(self): """ Group: /b - │ Dimensions: (y: 1) + │ Dimensions: (x: 2, y: 1) + │ Coordinates: + │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: () + Dimensions: (x: 2, y: 1) + Coordinates: + * x (x) float64 16B 2.0 3.0 + Dimensions without coordinates: y Data variables: g float64 8B 4.0 """ @@ -680,6 +709,192 @@ def test_repr(self): assert result == expected +def _exact_match(message: str) -> str: + return re.escape(dedent(message).strip()) + return "^" + re.escape(dedent(message.rstrip())) + "$" + + +class TestInheritance: + def test_inherited_dims(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset({"d": (("x",), [1, 2])}), + "/b": xr.Dataset({"e": (("y",), [3])}), + "/c": xr.Dataset({"f": (("y",), [3, 4, 5])}), + } + ) + assert dt.sizes == {"x": 2} + # nodes should include inherited dimensions + assert dt.b.sizes == {"x": 2, "y": 1} + assert dt.c.sizes == {"x": 2, "y": 3} + # dataset objects created from nodes should not + assert dt.b.ds.sizes == {"y": 1} + assert dt.b.to_dataset(inherited=True).sizes == {"y": 1} + assert dt.b.to_dataset(inherited=False).sizes == {"y": 1} + + def test_inherited_coords_index(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset({"d": (("x",), [1, 2])}, coords={"x": [2, 3]}), + "/b": xr.Dataset({"e": (("y",), [3])}), + } + ) + assert "x" in dt["/b"].indexes + assert "x" in dt["/b"].coords + xr.testing.assert_identical(dt["/x"], dt["/b/x"]) + + def test_inherited_coords_override(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 1, "y": 2}), + "/b": xr.Dataset(coords={"x": 4, "z": 3}), + } + ) + assert dt.coords.keys() == {"x", "y"} + root_coords = {"x": 1, "y": 2} + sub_coords = {"x": 4, "y": 2, "z": 3} + xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords)) 
+ xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords)) + assert dt["/b"].coords.keys() == {"x", "y", "z"} + xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords)) + xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) + xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + + def test_inconsistent_dims(self): + expected_msg = _exact_match( + """ + group '/b' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Dimensions without coordinates: x + Data variables: + c (x) float64 8B 3.0 + From parents: + Dimensions: (x: 2) + Dimensions without coordinates: x + """ + ) + + with pytest.raises(ValueError, match=expected_msg): + DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1.0, 2.0])}), + "/b": xr.Dataset({"c": (("x",), [3.0])}), + } + ) + + dt: DataTree = DataTree() + dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) + with pytest.raises(ValueError, match=expected_msg): + dt["/b/c"] = xr.DataArray([3.0], dims=["x"]) + + b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3.0])})) + with pytest.raises(ValueError, match=expected_msg): + DataTree( + data=xr.Dataset({"a": (("x",), [1.0, 2.0])}), + children={"b": b}, + ) + + def test_inconsistent_child_indexes(self): + expected_msg = _exact_match( + """ + group '/b' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Coordinates: + * x (x) float64 8B 2.0 + Data variables: + *empty* + From parents: + Dimensions: (x: 1) + Coordinates: + * x (x) float64 8B 1.0 + """ + ) + + with pytest.raises(ValueError, match=expected_msg): + DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1.0]}), + "/b": xr.Dataset(coords={"x": [2.0]}), + } + ) + + dt: DataTree = DataTree() + dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore + dt["/b"] = DataTree() + with pytest.raises(ValueError, match=expected_msg): + dt["/b"].ds = xr.Dataset(coords={"x": [2.0]}) + + b: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) + with 
pytest.raises(ValueError, match=expected_msg): + DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) + + def test_inconsistent_grandchild_indexes(self): + expected_msg = _exact_match( + """ + group '/b/c' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Coordinates: + * x (x) float64 8B 2.0 + Data variables: + *empty* + From parents: + Dimensions: (x: 1) + Coordinates: + * x (x) float64 8B 1.0 + """ + ) + + with pytest.raises(ValueError, match=expected_msg): + DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1.0]}), + "/b/c": xr.Dataset(coords={"x": [2.0]}), + } + ) + + dt: DataTree = DataTree() + dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore + dt["/b/c"] = DataTree() + with pytest.raises(ValueError, match=expected_msg): + dt["/b/c"].ds = xr.Dataset(coords={"x": [2.0]}) + + c: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) + b: DataTree = DataTree(children={"c": c}) + with pytest.raises(ValueError, match=expected_msg): + DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) + + def test_inconsistent_grandchild_dims(self): + expected_msg = _exact_match( + """ + group '/b/c' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Dimensions without coordinates: x + Data variables: + d (x) float64 8B 3.0 + From parents: + Dimensions: (x: 2) + Dimensions without coordinates: x + """ + ) + + with pytest.raises(ValueError, match=expected_msg): + DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1.0, 2.0])}), + "/b/c": xr.Dataset({"d": (("x",), [3.0])}), + } + ) + + dt: DataTree = DataTree() + dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) + with pytest.raises(ValueError, match=expected_msg): + dt["/b/c/d"] = xr.DataArray([3.0], dims=["x"]) + + class TestRestructuring: def test_drop_nodes(self): sue = DataTree.from_dict({"Mary": None, "Kate": None, "Ashley": None}) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 50061c774a8..ecec3ca507b 100644 --- 
a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -7,7 +7,11 @@ import pytest from xarray.core import duck_array_ops, utils -from xarray.core.utils import either_dict_or_kwargs, infix_dims, iterate_nested +from xarray.core.utils import ( + either_dict_or_kwargs, + infix_dims, + iterate_nested, +) from xarray.tests import assert_array_equal, requires_dask From 52a73711968337f5e623780facfb0d0d6d636633 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 5 Jul 2024 03:26:12 +0200 Subject: [PATCH 23/33] avoid converting custom indexes to pandas indexes when formatting coordinate diffs (#9157) * use `.xindexes` to construct the diff * check that the indexes are never converted to pandas indexes * deduplicate the various custom index dummy implementations * whats-new [skip-rtd] * fill in the pr number --------- Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 2 + xarray/core/formatting.py | 4 +- xarray/tests/test_formatting.py | 76 ++++++++++++++++++++++++--------- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 685cdf28194..753586c32c2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`) + By `Justus Magin `_. - Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`). By `Pontus Lurcock `_. - Allow diffing objects with array attributes on variables (:issue:`9153`, :pull:`9169`). 
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 6dca4eba8e8..6571b288fae 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -898,8 +898,8 @@ def diff_coords_repr(a, b, compat, col_width=None): "Coordinates", summarize_variable, col_width=col_width, - a_indexes=a.indexes, - b_indexes=b.indexes, + a_indexes=a.xindexes, + b_indexes=b.xindexes, ) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 6c49ab456f6..9d0eb81bace 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -10,9 +10,20 @@ import xarray as xr from xarray.core import formatting from xarray.core.datatree import DataTree # TODO: Remove when can do xr.DataTree +from xarray.core.indexes import Index from xarray.tests import requires_cftime, requires_dask, requires_netCDF4 +class CustomIndex(Index): + names: tuple[str, ...] + + def __init__(self, names: tuple[str, ...]): + self.names = names + + def __repr__(self): + return f"CustomIndex(coords={self.names})" + + class TestFormatting: def test_get_indexer_at_least_n_items(self) -> None: cases = [ @@ -219,17 +230,6 @@ def test_attribute_repr(self) -> None: assert "\t" not in tabs def test_index_repr(self) -> None: - from xarray.core.indexes import Index - - class CustomIndex(Index): - names: tuple[str, ...] 
- - def __init__(self, names: tuple[str, ...]): - self.names = names - - def __repr__(self): - return f"CustomIndex(coords={self.names})" - coord_names = ("x", "y") index = CustomIndex(coord_names) names = ("x",) @@ -258,15 +258,6 @@ def _repr_inline_(self, max_width: int): ), ) def test_index_repr_grouping(self, names) -> None: - from xarray.core.indexes import Index - - class CustomIndex(Index): - def __init__(self, names): - self.names = names - - def __repr__(self): - return f"CustomIndex(coords={self.names})" - index = CustomIndex(names) normal = formatting.summarize_index(names, index, col_width=20) @@ -337,6 +328,51 @@ def test_diff_array_repr(self) -> None: # depending on platform, dtype may not be shown in numpy array repr assert actual == expected.replace(", dtype=int64", "") + da_a = xr.DataArray( + np.array([[1, 2, 3], [4, 5, 6]], dtype="int8"), + dims=("x", "y"), + coords=xr.Coordinates( + { + "x": np.array([True, False], dtype="bool"), + "y": np.array([1, 2, 3], dtype="int16"), + }, + indexes={"y": CustomIndex(("y",))}, + ), + ) + + da_b = xr.DataArray( + np.array([1, 2], dtype="int8"), + dims="x", + coords=xr.Coordinates( + { + "x": np.array([True, False], dtype="bool"), + "label": ("x", np.array([1, 2], dtype="int16")), + }, + indexes={"label": CustomIndex(("label",))}, + ), + ) + + expected = dedent( + """\ + Left and right DataArray objects are not equal + Differing dimensions: + (x: 2, y: 3) != (x: 2) + Differing values: + L + array([[1, 2, 3], + [4, 5, 6]], dtype=int8) + R + array([1, 2], dtype=int8) + Coordinates only on the left object: + * y (y) int16 6B 1 2 3 + Coordinates only on the right object: + * label (x) int16 4B 1 2 + """.rstrip() + ) + + actual = formatting.diff_array_repr(da_a, da_b, "equals") + assert actual == expected + va = xr.Variable( "x", np.array([1, 2, 3], dtype="int64"), {"title": "test Variable"} ) From 971d71d1062fc431ed95e4aa7417751cdcf40c4a Mon Sep 17 00:00:00 2001 From: Maximilian Roos 
<5635139+max-sixty@users.noreply.github.com> Date: Sat, 6 Jul 2024 21:19:03 -0700 Subject: [PATCH 24/33] Fix reductions for `np.complex_` dtypes with numbagg (#9210) --- doc/whats-new.rst | 3 +++ xarray/core/nputils.py | 2 +- xarray/tests/test_computation.py | 8 ++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 753586c32c2..8c6b3a099c2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,9 @@ Bug fixes By `Michael Niklas `_. - Dark themes are now properly detected for ``html[data-theme=dark]``-tags (:pull:`9200`). By `Dieter Werthmüller `_. +- Reductions no longer fail for ``np.complex_`` dtype arrays when numbagg is + installed. + By `Maximilian Roos `_ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 6970d37402f..1d30fe9db9e 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -191,7 +191,7 @@ def f(values, axis=None, **kwargs): or kwargs.get("ddof", 0) == 1 ) # TODO: bool? 
- and values.dtype.kind in "uifc" + and values.dtype.kind in "uif" # and values.dtype.isnative and (dtype is None or np.dtype(dtype) == values.dtype) # numbagg.nanquantile only available after 0.8.0 and with linear method diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 080447cede0..4b9b95b27bb 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2598,3 +2598,11 @@ def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None: actual = xr.cross(a, b, dim=dim) xr.testing.assert_duckarray_allclose(expected, actual) + + +@pytest.mark.parametrize("compute_backend", ["numbagg"], indirect=True) +def test_complex_number_reduce(compute_backend): + da = xr.DataArray(np.ones((2,), dtype=np.complex_), dims=["x"]) + # Check that xarray doesn't call into numbagg, which doesn't compile for complex + # numbers at the moment (but will when numba supports dynamic compilation) + da.min() From 04b38a002ca2aea1d85dc367066969e5788f03b5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 7 Jul 2024 14:46:57 -0700 Subject: [PATCH 25/33] Consolidate some numbagg tests (#9211) Some minor repetiton --- xarray/tests/test_missing.py | 51 +++++++++++++++--------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index da9513a7c71..bd75f633b82 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -421,46 +421,37 @@ def test_ffill(): assert_equal(actual, expected) -def test_ffill_use_bottleneck_numbagg(): +@pytest.mark.parametrize("compute_backend", [None], indirect=True) +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +def test_b_ffill_use_bottleneck_numbagg(method, compute_backend): + """ + bfill & ffill fail if both bottleneck and numba are disabled + """ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - with 
xr.set_options(use_bottleneck=False, use_numbagg=False): - with pytest.raises(RuntimeError): - da.ffill("x") + with pytest.raises(RuntimeError): + getattr(da, method)("x") @requires_dask -def test_ffill_use_bottleneck_dask(): +@pytest.mark.parametrize("compute_backend", [None], indirect=True) +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +def test_b_ffill_use_bottleneck_dask(method, compute_backend): + """ + ffill fails if both bottleneck and numba are disabled, on dask arrays + """ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - da = da.chunk({"x": 1}) - with xr.set_options(use_bottleneck=False, use_numbagg=False): - with pytest.raises(RuntimeError): - da.ffill("x") + with pytest.raises(RuntimeError): + getattr(da, method)("x") @requires_numbagg @requires_dask -def test_ffill_use_numbagg_dask(): - with xr.set_options(use_bottleneck=False): - da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - da = da.chunk(x=-1) - # Succeeds with a single chunk: - _ = da.ffill("x").compute() - - -def test_bfill_use_bottleneck(): - da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - with xr.set_options(use_bottleneck=False, use_numbagg=False): - with pytest.raises(RuntimeError): - da.bfill("x") - - -@requires_dask -def test_bfill_use_bottleneck_dask(): +@pytest.mark.parametrize("compute_backend", ["numbagg"], indirect=True) +def test_ffill_use_numbagg_dask(compute_backend): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - da = da.chunk({"x": 1}) - with xr.set_options(use_bottleneck=False, use_numbagg=False): - with pytest.raises(RuntimeError): - da.bfill("x") + da = da.chunk(x=-1) + # Succeeds with a single chunk: + _ = da.ffill("x").compute() @requires_bottleneck From bac01c05a6e299ca61677af47d0bcb841de1787a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 8 Jul 2024 02:49:13 -0700 Subject: [PATCH 26/33] Use numpy 2.0-compat 
`np.complex64` dtype in test (#9217) --- xarray/tests/test_computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 4b9b95b27bb..b000de311af 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2602,7 +2602,7 @@ def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None: @pytest.mark.parametrize("compute_backend", ["numbagg"], indirect=True) def test_complex_number_reduce(compute_backend): - da = xr.DataArray(np.ones((2,), dtype=np.complex_), dims=["x"]) + da = xr.DataArray(np.ones((2,), dtype=np.complex64), dims=["x"]) # Check that xarray doesn't call into numbagg, which doesn't compile for complex # numbers at the moment (but will when numba supports dynamic compilation) da.min() From 179c6706615854fc861335d929bd6c4761485a93 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Jul 2024 10:12:54 -0700 Subject: [PATCH 27/33] Fix two bugs in DataTree.update() (#9214) * Fix two bugs in DataTree.update() 1. Fix handling of coordinates on a Dataset argument (previously these were silently dropped). 2. Do not copy inherited coordinates down to lower level nodes. 
* add mypy annotation --- xarray/core/datatree.py | 40 ++++++++++++++++++++--------------- xarray/tests/test_datatree.py | 37 ++++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 38f8f8cd495..65ff8667cb7 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -61,7 +61,7 @@ import pandas as pd from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes - from xarray.core.merge import CoercibleValue + from xarray.core.merge import CoercibleMapping, CoercibleValue from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes # """ @@ -954,23 +954,29 @@ def update( Just like `dict.update` this is an in-place operation. """ - # TODO separate by type new_children: dict[str, DataTree] = {} - new_variables = {} - for k, v in other.items(): - if isinstance(v, DataTree): - # avoid named node being stored under inconsistent key - new_child: DataTree = v.copy() - # Datatree's name is always a string until we fix that (#8836) - new_child.name = str(k) - new_children[str(k)] = new_child - elif isinstance(v, (DataArray, Variable)): - # TODO this should also accommodate other types that can be coerced into Variables - new_variables[k] = v - else: - raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") - - vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + new_variables: CoercibleMapping + + if isinstance(other, Dataset): + new_variables = other + else: + new_variables = {} + for k, v in other.items(): + if isinstance(v, DataTree): + # avoid named node being stored under inconsistent key + new_child: DataTree = v.copy() + # Datatree's name is always a string until we fix that (#8836) + new_child.name = str(k) + new_children[str(k)] = new_child + elif isinstance(v, (DataArray, Variable)): + # TODO this should also accommodate other types that can be coerced into Variables + new_variables[k] = 
v + else: + raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") + + vars_merge_result = dataset_update_method( + self.to_dataset(inherited=False), new_variables + ) data = Dataset._construct_direct(**vars_merge_result._asdict()) # TODO are there any subtleties with preserving order of children like this? diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f2b58fa2489..f7cff17bab5 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -244,11 +244,6 @@ def test_update(self): dt: DataTree = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree()}) expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) - print(dt) - print(dt.children) - print(dt._children) - print(dt["a"]) - print(expected) assert_equal(dt, expected) def test_update_new_named_dataarray(self): @@ -268,14 +263,38 @@ def test_update_doesnt_alter_child_name(self): def test_update_overwrite(self): actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))}) actual.update({"a": DataTree(xr.Dataset({"x": 2}))}) - expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))}) + assert_equal(actual, expected) - print(actual) - print(expected) - + def test_update_coordinates(self): + expected = DataTree.from_dict({"/": xr.Dataset(coords={"a": 1})}) + actual = DataTree.from_dict({"/": xr.Dataset()}) + actual.update(xr.Dataset(coords={"a": 1})) assert_equal(actual, expected) + def test_update_inherited_coords(self): + expected = DataTree.from_dict( + { + "/": xr.Dataset(coords={"a": 1}), + "/b": xr.Dataset(coords={"c": 1}), + } + ) + actual = DataTree.from_dict( + { + "/": xr.Dataset(coords={"a": 1}), + "/b": xr.Dataset(), + } + ) + actual["/b"].update(xr.Dataset(coords={"c": 1})) + assert_identical(actual, expected) + + # DataTree.identical() currently does not require that non-inherited + # coordinates are defined identically, so we need to check this + # explicitly + actual_node = 
actual.children["b"].to_dataset(inherited=False) + expected_node = expected.children["b"].to_dataset(inherited=False) + assert_identical(actual_node, expected_node) + class TestCopy: def test_copy(self, create_test_datatree): From 3024655e3689c11908d221913cdb922bcfc69037 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 9 Jul 2024 09:09:18 +0200 Subject: [PATCH 28/33] Only use necessary dims when creating temporary dataarray (#9206) * Only use necessary dims when creating temporary dataarray * Update dataset_plot.py * Can't check only data_vars all corrds are no longer added by default * Update dataset_plot.py * Add tests * Update whats-new.rst * Update dataset_plot.py --- doc/whats-new.rst | 2 ++ xarray/plot/dataset_plot.py | 15 +++++++++----- xarray/tests/test_plot.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8c6b3a099c2..0c401c2348e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix scatter plot broadcasting unneccesarily. (:issue:`9129`, :pull:`9206`) + By `Jimmy Westling `_. - Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`) By `Justus Magin `_. - Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`). 
diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index edc2bf43629..96b59f6174e 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -721,8 +721,8 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr """Create a temporary datarray with extra coords.""" from xarray.core.dataarray import DataArray - # Base coords: - coords = dict(ds.coords) + coords = dict(ds[y].coords) + dims = set(ds[y].dims) # Add extra coords to the DataArray from valid kwargs, if using all # kwargs there is a risk that we add unnecessary dataarrays as @@ -732,12 +732,17 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr coord_kwargs = locals_.keys() & valid_coord_kwargs for k in coord_kwargs: key = locals_[k] - if ds.data_vars.get(key) is not None: - coords[key] = ds[key] + darray = ds.get(key) + if darray is not None: + coords[key] = darray + dims.update(darray.dims) + + # Trim dataset from unneccessary dims: + ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future # The dataarray has to include all the dims. 
Broadcast to that shape # and add the additional coords: - _y = ds[y].broadcast_like(ds) + _y = ds[y].broadcast_like(ds_trimmed) return DataArray(_y, coords=coords) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b302ad3af93..fa08e9975ab 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -3416,3 +3416,43 @@ def test_9155() -> None: data = xr.DataArray([1, 2, 3], dims=["x"]) fig, ax = plt.subplots(ncols=1, nrows=1) data.plot(ax=ax) + + +@requires_matplotlib +def test_temp_dataarray() -> None: + from xarray.plot.dataset_plot import _temp_dataarray + + x = np.arange(1, 4) + y = np.arange(4, 6) + var1 = np.arange(x.size * y.size).reshape((x.size, y.size)) + var2 = np.arange(x.size * y.size).reshape((x.size, y.size)) + ds = xr.Dataset( + { + "var1": (["x", "y"], var1), + "var2": (["x", "y"], 2 * var2), + "var3": (["x"], 3 * x), + }, + coords={ + "x": x, + "y": y, + "model": np.arange(7), + }, + ) + + # No broadcasting: + y_ = "var1" + locals_ = {"x": "var2"} + da = _temp_dataarray(ds, y_, locals_) + assert da.shape == (3, 2) + + # Broadcast from 1 to 2dim: + y_ = "var3" + locals_ = {"x": "var1"} + da = _temp_dataarray(ds, y_, locals_) + assert da.shape == (3, 2) + + # Ignore non-valid coord kwargs: + y_ = "var3" + locals_ = dict(x="x", extend="var2") + da = _temp_dataarray(ds, y_, locals_) + assert da.shape == (3,) From 879b06b06fe2a08dcc1104761b32a31b64b7d74b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:34:48 +0200 Subject: [PATCH 29/33] Cleanup test_coding_times.py (#9223) * Cleanup test_coding_times * Update test_coding_times.py --- xarray/tests/test_coding_times.py | 93 +++++++++++++++---------------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 393f8400c46..d568bdc3268 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ 
-14,13 +14,15 @@ Dataset, Variable, cftime_range, - coding, conventions, date_range, decode_cf, ) +from xarray.coding.times import _STANDARD_CALENDARS as _STANDARD_CALENDARS_UNSORTED from xarray.coding.times import ( + CFDatetimeCoder, _encode_datetime_with_cftime, + _netcdf_to_numpy_timeunit, _numpy_to_netcdf_timeunit, _should_cftime_be_used, cftime_to_nptime, @@ -28,6 +30,9 @@ decode_cf_timedelta, encode_cf_datetime, encode_cf_timedelta, + format_cftime_datetime, + infer_datetime_units, + infer_timedelta_units, to_timedelta_unboxed, ) from xarray.coding.variables import SerializationWarning @@ -53,11 +58,9 @@ "all_leap", "366_day", } -_ALL_CALENDARS = sorted( - _NON_STANDARD_CALENDARS_SET.union(coding.times._STANDARD_CALENDARS) -) +_STANDARD_CALENDARS = sorted(_STANDARD_CALENDARS_UNSORTED) +_ALL_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET.union(_STANDARD_CALENDARS)) _NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) -_STANDARD_CALENDARS = sorted(coding.times._STANDARD_CALENDARS) _CF_DATETIME_NUM_DATES_UNITS = [ (np.arange(10), "days since 2000-01-01"), (np.arange(10).astype("float64"), "days since 2000-01-01"), @@ -130,7 +133,7 @@ def test_cf_datetime(num_dates, units, calendar) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(num_dates, units, calendar) + actual = decode_cf_datetime(num_dates, units, calendar) abs_diff = np.asarray(abs(actual - expected)).ravel() abs_diff = pd.to_timedelta(abs_diff.tolist()).to_numpy() @@ -139,17 +142,15 @@ def test_cf_datetime(num_dates, units, calendar) -> None: # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 assert (abs_diff <= np.timedelta64(1, "s")).all() - encoded, _, _ = coding.times.encode_cf_datetime(actual, units, calendar) + encoded, _, _ = encode_cf_datetime(actual, units, calendar) - assert_array_equal(num_dates, np.around(encoded, 1)) + 
assert_array_equal(num_dates, np.round(encoded, 1)) if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: # verify that wrapping with a pandas.Index works # note that it *does not* currently work to put # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = coding.times.encode_cf_datetime( - pd.Index(actual), units, calendar - ) - assert_array_equal(num_dates, np.around(encoded, 1)) + encoded, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) + assert_array_equal(num_dates, np.round(encoded, 1)) @requires_cftime @@ -169,7 +170,7 @@ def test_decode_cf_datetime_overflow() -> None: for i, day in enumerate(days): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - result = coding.times.decode_cf_datetime(day, units) + result = decode_cf_datetime(day, units) assert result == expected[i] @@ -178,7 +179,7 @@ def test_decode_cf_datetime_non_standard_units() -> None: # netCDFs from madis.noaa.gov use this format for their time units # they cannot be parsed by cftime, but pd.Timestamp works units = "hours since 1-1-1970" - actual = coding.times.decode_cf_datetime(np.arange(100), units) + actual = decode_cf_datetime(np.arange(100), units) assert_array_equal(actual, expected) @@ -193,7 +194,7 @@ def test_decode_cf_datetime_non_iso_strings() -> None: (np.arange(100), "hours since 2000-01-01 0:00"), ] for num_dates, units in cases: - actual = coding.times.decode_cf_datetime(num_dates, units) + actual = decode_cf_datetime(num_dates, units) abs_diff = abs(actual - expected.values) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: @@ -212,7 +213,7 @@ def test_decode_standard_calendar_inside_timestamp_range(calendar) -> None: expected = times.values expected_dtype = np.dtype("M8[ns]") - actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + actual = decode_cf_datetime(time, units, 
calendar=calendar) assert actual.dtype == expected_dtype abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, @@ -235,9 +236,7 @@ def test_decode_non_standard_calendar_inside_timestamp_range(calendar) -> None: ) expected_dtype = np.dtype("O") - actual = coding.times.decode_cf_datetime( - non_standard_time, units, calendar=calendar - ) + actual = decode_cf_datetime(non_standard_time, units, calendar=calendar) assert actual.dtype == expected_dtype abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, @@ -264,7 +263,7 @@ def test_decode_dates_outside_timestamp_range(calendar) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + actual = decode_cf_datetime(time, units, calendar=calendar) assert all(isinstance(value, expected_date_type) for value in actual) abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, @@ -282,7 +281,7 @@ def test_decode_standard_calendar_single_element_inside_timestamp_range( for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + actual = decode_cf_datetime(num_time, units, calendar=calendar) assert actual.dtype == np.dtype("M8[ns]") @@ -295,7 +294,7 @@ def test_decode_non_standard_calendar_single_element_inside_timestamp_range( for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + actual = decode_cf_datetime(num_time, units, calendar=calendar) assert actual.dtype == np.dtype("O") @@ -309,9 +308,7 @@ def 
test_decode_single_element_outside_timestamp_range(calendar) -> None: for num_time in [days, [days], [[days]]]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar - ) + actual = decode_cf_datetime(num_time, units, calendar=calendar) expected = cftime.num2date( days, units, calendar, only_use_cftime_datetimes=True @@ -338,7 +335,7 @@ def test_decode_standard_calendar_multidim_time_inside_timestamp_range( expected1 = times1.values expected2 = times2.values - actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + actual = decode_cf_datetime(mdim_time, units, calendar=calendar) assert actual.dtype == np.dtype("M8[ns]") abs_diff1 = abs(actual[:, 0] - expected1) @@ -379,7 +376,7 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( expected_dtype = np.dtype("O") - actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + actual = decode_cf_datetime(mdim_time, units, calendar=calendar) assert actual.dtype == expected_dtype abs_diff1 = abs(actual[:, 0] - expected1) @@ -412,7 +409,7 @@ def test_decode_multidim_time_outside_timestamp_range(calendar) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + actual = decode_cf_datetime(mdim_time, units, calendar=calendar) assert actual.dtype == np.dtype("O") @@ -435,7 +432,7 @@ def test_decode_non_standard_calendar_single_element(calendar, num_time) -> None units = "days since 0001-01-01" - actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + actual = decode_cf_datetime(num_time, units, calendar=calendar) expected = np.asarray( cftime.num2date(num_time, units, calendar, only_use_cftime_datetimes=True) @@ -460,9 +457,7 @@ def test_decode_360_day_calendar() -> None: with 
warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - actual = coding.times.decode_cf_datetime( - num_times, units, calendar=calendar - ) + actual = decode_cf_datetime(num_times, units, calendar=calendar) assert len(w) == 0 assert actual.dtype == np.dtype("O") @@ -476,8 +471,8 @@ def test_decode_abbreviation() -> None: val = np.array([1586628000000.0]) units = "msecs since 1970-01-01T00:00:00Z" - actual = coding.times.decode_cf_datetime(val, units) - expected = coding.times.cftime_to_nptime(cftime.num2date(val, units)) + actual = decode_cf_datetime(val, units) + expected = cftime_to_nptime(cftime.num2date(val, units)) assert_array_equal(actual, expected) @@ -498,7 +493,7 @@ def test_decode_abbreviation() -> None: def test_cf_datetime_nan(num_dates, units, expected_list) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "All-NaN") - actual = coding.times.decode_cf_datetime(num_dates, units) + actual = decode_cf_datetime(num_dates, units) # use pandas because numpy will deprecate timezone-aware conversions expected = pd.to_datetime(expected_list).to_numpy(dtype="datetime64[ns]") assert_array_equal(expected, actual) @@ -510,7 +505,7 @@ def test_decoded_cf_datetime_array_2d() -> None: variable = Variable( ("x", "y"), np.array([[0, 1], [2, 3]]), {"units": "days since 2000-01-01"} ) - result = coding.times.CFDatetimeCoder().decode(variable) + result = CFDatetimeCoder().decode(variable) assert result.dtype == "datetime64[ns]" expected = pd.date_range("2000-01-01", periods=4).values.reshape(2, 2) assert_array_equal(np.asarray(result), expected) @@ -531,7 +526,7 @@ def test_decoded_cf_datetime_array_2d() -> None: def test_infer_datetime_units(freq, units) -> None: dates = pd.date_range("2000", periods=2, freq=freq) expected = f"{units} since 2000-01-01 00:00:00" - assert expected == coding.times.infer_datetime_units(dates) + assert expected == infer_datetime_units(dates) @pytest.mark.parametrize( @@ -549,7 +544,7 @@ def 
test_infer_datetime_units(freq, units) -> None: ], ) def test_infer_datetime_units_with_NaT(dates, expected) -> None: - assert expected == coding.times.infer_datetime_units(dates) + assert expected == infer_datetime_units(dates) _CFTIME_DATETIME_UNITS_TESTS = [ @@ -573,7 +568,7 @@ def test_infer_datetime_units_with_NaT(dates, expected) -> None: def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: date_type = _all_cftime_date_types()[calendar] dates = [date_type(*args) for args in date_args] - assert expected == coding.times.infer_datetime_units(dates) + assert expected == infer_datetime_units(dates) @pytest.mark.filterwarnings("ignore:Timedeltas can't be serialized faithfully") @@ -600,18 +595,18 @@ def test_cf_timedelta(timedeltas, units, numbers) -> None: numbers = np.array(numbers) expected = numbers - actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) + actual, _ = encode_cf_timedelta(timedeltas, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype if units is not None: expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) + actual = decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype expected = np.timedelta64("NaT", "ns") - actual = coding.times.decode_cf_timedelta(np.array(np.nan), "days") + actual = decode_cf_timedelta(np.array(np.nan), "days") assert_array_equal(expected, actual) @@ -622,7 +617,7 @@ def test_cf_timedelta_2d() -> None: timedeltas = np.atleast_2d(to_timedelta_unboxed(["1D", "2D", "3D"])) expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) + actual = decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype @@ -637,7 +632,7 @@ def test_cf_timedelta_2d() -> None: ], ) def test_infer_timedelta_units(deltas, expected) -> None: - assert expected == coding.times.infer_timedelta_units(deltas) + assert expected == 
infer_timedelta_units(deltas) @requires_cftime @@ -653,7 +648,7 @@ def test_infer_timedelta_units(deltas, expected) -> None: def test_format_cftime_datetime(date_args, expected) -> None: date_types = _all_cftime_date_types() for date_type in date_types.values(): - result = coding.times.format_cftime_datetime(date_type(*date_args)) + result = format_cftime_datetime(date_type(*date_args)) assert result == expected @@ -1008,7 +1003,7 @@ def test_decode_ambiguous_time_warns(calendar) -> None: # we don't decode non-standard calendards with # pandas so expect no warning will be emitted - is_standard_calendar = calendar in coding.times._STANDARD_CALENDARS + is_standard_calendar = calendar in _STANDARD_CALENDARS dates = [1, 2, 3] units = "days since 1-1-1" @@ -1043,9 +1038,9 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( pytest.skip("Nanosecond frequency is not valid for cftime dates.") times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" - encoded, _units, _ = coding.times.encode_cf_datetime(times, units) + encoded, _units, _ = encode_cf_datetime(times, units) - numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) + numpy_timeunit = _netcdf_to_numpy_timeunit(encoding_units) encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) if pd.to_timedelta(1, freq) >= encoding_units_as_timedelta: assert encoded.dtype == np.int64 @@ -1202,7 +1197,7 @@ def test_decode_float_datetime(): def test_scalar_unit() -> None: # test that a scalar units (often NaN when using to_netcdf) does not raise an error variable = Variable(("x", "y"), np.array([[0, 1], [2, 3]]), {"units": np.nan}) - result = coding.times.CFDatetimeCoder().decode(variable) + result = CFDatetimeCoder().decode(variable) assert np.isnan(result.attrs["units"]) From 7ff5d8d1f367898f9e09a83ffe993fac9e90047b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 Jul 2024 03:57:16 +0200 Subject: [PATCH 
30/33] Use reshape and ravel from duck_array_ops in coding/times.py (#9225) * Use duck_array_ops.ravel * Use duck_array_ops.reshape --- xarray/coding/times.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 34d4f9a23ad..50a2ba93c09 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,7 +22,7 @@ ) from xarray.core import indexing from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like -from xarray.core.duck_array_ops import asarray +from xarray.core.duck_array_ops import asarray, ravel, reshape from xarray.core.formatting import first_n_items, format_timestamp, last_item from xarray.core.pdcompat import nanosecond_precision_timestamp from xarray.core.utils import emit_user_level_warning @@ -315,7 +315,7 @@ def decode_cf_datetime( cftime.num2date """ num_dates = np.asarray(num_dates) - flat_num_dates = num_dates.ravel() + flat_num_dates = ravel(num_dates) if calendar is None: calendar = "standard" @@ -348,7 +348,7 @@ def decode_cf_datetime( else: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) - return dates.reshape(num_dates.shape) + return reshape(dates, num_dates.shape) def to_timedelta_unboxed(value, **kwargs): @@ -369,8 +369,8 @@ def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray: """ num_timedeltas = np.asarray(num_timedeltas) units = _netcdf_to_numpy_timeunit(units) - result = to_timedelta_unboxed(num_timedeltas.ravel(), unit=units) - return result.reshape(num_timedeltas.shape) + result = to_timedelta_unboxed(ravel(num_timedeltas), unit=units) + return reshape(result, num_timedeltas.shape) def _unit_timedelta_cftime(units: str) -> timedelta: @@ -428,7 +428,7 @@ def infer_datetime_units(dates) -> str: 'hours', 'minutes' or 'seconds' (the first one that can evenly divide all unique time deltas in `dates`) """ - dates = np.asarray(dates).ravel() + dates = ravel(np.asarray(dates)) if 
np.asarray(dates).dtype == "datetime64[ns]": dates = to_datetime_unboxed(dates) dates = dates[pd.notnull(dates)] @@ -456,7 +456,7 @@ def infer_timedelta_units(deltas) -> str: {'days', 'hours', 'minutes' 'seconds'} (the first one that can evenly divide all unique time deltas in `deltas`) """ - deltas = to_timedelta_unboxed(np.asarray(deltas).ravel()) + deltas = to_timedelta_unboxed(ravel(np.asarray(deltas))) unique_timedeltas = np.unique(deltas[pd.notnull(deltas)]) return _infer_time_units_from_diff(unique_timedeltas) @@ -643,7 +643,7 @@ def encode_datetime(d): except TypeError: return np.nan if d is None else cftime.date2num(d, units, calendar) - return np.array([encode_datetime(d) for d in dates.ravel()]).reshape(dates.shape) + return reshape(np.array([encode_datetime(d) for d in ravel(dates)]), dates.shape) def cast_to_int_if_safe(num) -> np.ndarray: @@ -753,7 +753,7 @@ def _eagerly_encode_cf_datetime( # Wrap the dates in a DatetimeIndex to do the subtraction to ensure # an OverflowError is raised if the ref_date is too far away from # dates to be encoded (GH 2272). 
- dates_as_index = pd.DatetimeIndex(dates.ravel()) + dates_as_index = pd.DatetimeIndex(ravel(dates)) time_deltas = dates_as_index - ref_date # retrieve needed units to faithfully encode to int64 @@ -791,7 +791,7 @@ def _eagerly_encode_cf_datetime( floor_division = True num = _division(time_deltas, time_delta, floor_division) - num = num.values.reshape(dates.shape) + num = reshape(num.values, dates.shape) except (OutOfBoundsDatetime, OverflowError, ValueError): num = _encode_datetime_with_cftime(dates, units, calendar) @@ -879,7 +879,7 @@ def _eagerly_encode_cf_timedelta( units = data_units time_delta = _time_units_to_timedelta64(units) - time_deltas = pd.TimedeltaIndex(timedeltas.ravel()) + time_deltas = pd.TimedeltaIndex(ravel(timedeltas)) # retrieve needed units to faithfully encode to int64 needed_units = data_units @@ -911,7 +911,7 @@ def _eagerly_encode_cf_timedelta( floor_division = True num = _division(time_deltas, time_delta, floor_division) - num = num.values.reshape(timedeltas.shape) + num = reshape(num.values, timedeltas.shape) if dtype is not None: num = _cast_to_dtype_if_safe(num, dtype) From eb0fbd7d2692690038e96b32a816a36ea4267a8d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 Jul 2024 03:58:00 +0200 Subject: [PATCH 31/33] Use duckarray assertions in test_coding_times (#9226) --- xarray/tests/test_coding_times.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index d568bdc3268..623e4e9f970 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -44,6 +44,8 @@ FirstElementAccessibleArray, arm_xfail, assert_array_equal, + assert_duckarray_allclose, + assert_duckarray_equal, assert_no_warnings, has_cftime, requires_cftime, @@ -144,13 +146,13 @@ def test_cf_datetime(num_dates, units, calendar) -> None: assert (abs_diff <= np.timedelta64(1, "s")).all() encoded, _, _ = 
encode_cf_datetime(actual, units, calendar) - assert_array_equal(num_dates, np.round(encoded, 1)) + assert_duckarray_allclose(num_dates, encoded) if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: # verify that wrapping with a pandas.Index works # note that it *does not* currently work to put # non-datetime64 compatible dates into a pandas.Index encoded, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) - assert_array_equal(num_dates, np.round(encoded, 1)) + assert_duckarray_allclose(num_dates, encoded) @requires_cftime @@ -893,10 +895,10 @@ def test_time_units_with_timezone_roundtrip(calendar) -> None: ) if calendar in _STANDARD_CALENDARS: - np.testing.assert_array_equal(result_num_dates, expected_num_dates) + assert_duckarray_equal(result_num_dates, expected_num_dates) else: # cftime datetime arithmetic is not quite exact. - np.testing.assert_allclose(result_num_dates, expected_num_dates) + assert_duckarray_allclose(result_num_dates, expected_num_dates) assert result_units == expected_units assert result_calendar == calendar From ff15a08bea27674923afa494b303c6e5cb4d513c Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 10 Jul 2024 22:00:09 -0400 Subject: [PATCH 32/33] Fix time indexing regression in `convert_calendar` (#9192) * MRC -- Selecting with string for cftime See discussion in #9138 This commit and pull request mostly serves as a staging ground for a potential fix. Test with: ``` pytest xarray/tests/test_cftimeindex.py::test_cftime_noleap_with_str ``` * effectively remove fastpath * Add docstring * Revert "effectively remove fastpath" This reverts commit 0f1a5a2271e5522b5dd946d7f4f38f591211286e. 
* Fix by reassigning coordinate * Update what's new entry * Simplify if condition --------- Co-authored-by: Spencer Clark --- doc/whats-new.rst | 6 ++++++ xarray/coding/calendar_ops.py | 12 +++++++++++- xarray/tests/test_calendar_ops.py | 25 ++++++++++++++++++++++++- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0c401c2348e..f237b406bd5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -47,6 +47,12 @@ Bug fixes By `Justus Magin `_. - Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`). By `Justus Magin `_. +- Address regression introduced in :pull:`9002` that prevented objects returned + by py:meth:`DataArray.convert_calendar` to be indexed by a time index in + certain circumstances (:issue:`9138`, :pull:`9192`). By `Mark Harfouche + `_ and `Spencer Clark + `. + - Fiy static typing of tolerance arguments by allowing `str` type (:issue:`8892`, :pull:`9194`). By `Michael Niklas `_. - Dark themes are now properly detected for ``html[data-theme=dark]``-tags (:pull:`9200`). diff --git a/xarray/coding/calendar_ops.py b/xarray/coding/calendar_ops.py index c4fe9e1f4ae..6f492e78bf9 100644 --- a/xarray/coding/calendar_ops.py +++ b/xarray/coding/calendar_ops.py @@ -5,7 +5,10 @@ from xarray.coding.cftime_offsets import date_range_like, get_date_type from xarray.coding.cftimeindex import CFTimeIndex -from xarray.coding.times import _should_cftime_be_used, convert_times +from xarray.coding.times import ( + _should_cftime_be_used, + convert_times, +) from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like try: @@ -222,6 +225,13 @@ def convert_calendar( # Remove NaN that where put on invalid dates in target calendar out = out.where(out[dim].notnull(), drop=True) + if use_cftime: + # Reassign times to ensure time index of output is a CFTimeIndex + # (previously it was an Index due to the presence of NaN values). 
+ # Note this is not needed in the case that the output time index is + # a DatetimeIndex, since DatetimeIndexes can handle NaN values. + out[dim] = CFTimeIndex(out[dim].data) + if missing is not None: time_target = date_range_like(time, calendar=calendar, use_cftime=use_cftime) out = out.reindex({dim: time_target}, fill_value=missing) diff --git a/xarray/tests/test_calendar_ops.py b/xarray/tests/test_calendar_ops.py index 7d229371808..13e9f7a1030 100644 --- a/xarray/tests/test_calendar_ops.py +++ b/xarray/tests/test_calendar_ops.py @@ -1,9 +1,10 @@ from __future__ import annotations import numpy as np +import pandas as pd import pytest -from xarray import DataArray, infer_freq +from xarray import CFTimeIndex, DataArray, infer_freq from xarray.coding.calendar_ops import convert_calendar, interp_calendar from xarray.coding.cftime_offsets import date_range from xarray.testing import assert_identical @@ -286,3 +287,25 @@ def test_interp_calendar_errors(): ValueError, match="Both 'source.x' and 'target' must contain datetime objects." 
): interp_calendar(da1, da2, dim="x") + + +@requires_cftime +@pytest.mark.parametrize( + ("source_calendar", "target_calendar", "expected_index"), + [("standard", "noleap", CFTimeIndex), ("all_leap", "standard", pd.DatetimeIndex)], +) +def test_convert_calendar_produces_time_index( + source_calendar, target_calendar, expected_index +): + # https://github.com/pydata/xarray/issues/9138 + time = date_range("2000-01-01", "2002-01-01", freq="D", calendar=source_calendar) + temperature = np.ones(len(time)) + da = DataArray( + data=temperature, + dims=["time"], + coords=dict( + time=time, + ), + ) + converted = da.convert_calendar(target_calendar) + assert isinstance(converted.indexes["time"], expected_index) From 7087ca49629e07be004d92fd08f916e3359a57e1 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 11 Jul 2024 10:54:18 +0200 Subject: [PATCH 33/33] `numpy` 2 compatibility in the `netcdf4` and `h5netcdf` backends (#9136) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * don't remove `netcdf4` from the upstream-dev environment * also stop removing `h5py` and `hdf5` * hard-code the precision (I believe this was missed in #9081) * don't remove `h5py` either * use on-disk _FillValue as standard expects, use view instead of cast to prevent OverflowError. 
* whats-new * unpin `numpy` * rework UnsignedCoder * add test * Update xarray/coding/variables.py Co-authored-by: Justus Magin --------- Co-authored-by: Kai Mühlbauer Co-authored-by: Kai Mühlbauer Co-authored-by: Deepak Cherian --- ci/install-upstream-wheels.sh | 5 ++-- ci/requirements/all-but-dask.yml | 2 +- ci/requirements/environment-windows.yml | 2 +- ci/requirements/environment.yml | 2 +- doc/whats-new.rst | 4 ++- xarray/coding/variables.py | 16 +++++++----- xarray/tests/test_backends.py | 33 +++++++++++++++++++++++-- 7 files changed, 49 insertions(+), 15 deletions(-) diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index d728768439a..79fae3c46a9 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -13,7 +13,7 @@ $conda remove -y numba numbagg sparse # temporarily remove numexpr $conda remove -y numexpr # temporarily remove backends -$conda remove -y cf_units hdf5 h5py netcdf4 pydap +$conda remove -y cf_units pydap # forcibly remove packages to avoid artifacts $conda remove -y --force \ numpy \ @@ -37,8 +37,7 @@ python -m pip install \ numpy \ scipy \ matplotlib \ - pandas \ - h5py + pandas # for some reason pandas depends on pyarrow already. 
# Remove once a `pyarrow` version compiled with `numpy>=2.0` is on `conda-forge` python -m pip install \ diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index abf6a88690a..119db282ad9 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -22,7 +22,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<2 + - numpy - packaging - pandas - pint>=0.22 diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 2eedc9b0621..896e390ea3e 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -23,7 +23,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<2 + - numpy - packaging - pandas # - pint>=0.22 diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 317e1fe5f41..ef02a3e7f23 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -26,7 +26,7 @@ dependencies: - numba - numbagg - numexpr - - numpy<2 + - numpy - opt_einsum - packaging - pandas diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f237b406bd5..e8369dc2f40 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,6 +45,8 @@ Bug fixes By `Pontus Lurcock `_. - Allow diffing objects with array attributes on variables (:issue:`9153`, :pull:`9169`). By `Justus Magin `_. +- ``numpy>=2`` compatibility in the ``netcdf4`` backend (:pull:`9136`). + By `Justus Magin `_ and `Kai Mühlbauer `_. - Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`). By `Justus Magin `_. - Address regression introduced in :pull:`9002` that prevented objects returned @@ -67,7 +69,7 @@ Documentation - Adds a flow-chart diagram to help users navigate help resources (`Discussion #8990 `_). By `Jessica Scheick `_. - Improvements to Zarr & chunking docs (:pull:`9139`, :pull:`9140`, :pull:`9132`) - By `Maximilian Roos `_ + By `Maximilian Roos `_. 
Internal Changes diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index d31cb6e626a..d19f285d2b9 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -516,10 +516,13 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = unpack_for_encoding(variable) pop_to(encoding, attrs, "_Unsigned") - signed_dtype = np.dtype(f"i{data.dtype.itemsize}") + # we need the on-disk type here + # trying to get it from encoding, resort to an int with the same precision as data.dtype if not available + signed_dtype = np.dtype(encoding.get("dtype", f"i{data.dtype.itemsize}")) if "_FillValue" in attrs: - new_fill = signed_dtype.type(attrs["_FillValue"]) - attrs["_FillValue"] = new_fill + new_fill = np.array(attrs["_FillValue"]) + # use view here to prevent OverflowError + attrs["_FillValue"] = new_fill.view(signed_dtype).item() data = duck_array_ops.astype(duck_array_ops.around(data), signed_dtype) return Variable(dims, data, attrs, encoding, fastpath=True) @@ -535,10 +538,11 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: if unsigned == "true": unsigned_dtype = np.dtype(f"u{data.dtype.itemsize}") transform = partial(np.asarray, dtype=unsigned_dtype) - data = lazy_elemwise_func(data, transform, unsigned_dtype) if "_FillValue" in attrs: - new_fill = unsigned_dtype.type(attrs["_FillValue"]) - attrs["_FillValue"] = new_fill + new_fill = np.array(attrs["_FillValue"], dtype=data.dtype) + # use view here to prevent OverflowError + attrs["_FillValue"] = new_fill.view(unsigned_dtype).item() + data = lazy_elemwise_func(data, transform, unsigned_dtype) elif data.dtype.kind == "u": if unsigned == "false": signed_dtype = np.dtype(f"i{data.dtype.itemsize}") diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 15485dc178a..0b90a05262d 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -166,7 +166,7 @@ def 
create_encoded_masked_and_scaled_data(dtype: np.dtype) -> Dataset: def create_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: encoding = { - "_FillValue": 255, + "_FillValue": np.int8(-1), "_Unsigned": "true", "dtype": "i1", "add_offset": dtype.type(10), @@ -925,6 +925,35 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn, dtype) -> None: assert decoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) + @pytest.mark.parametrize("fillvalue", [np.int8(-1), np.uint8(255)]) + def test_roundtrip_unsigned(self, fillvalue): + # regression/numpy2 test for + encoding = { + "_FillValue": fillvalue, + "_Unsigned": "true", + "dtype": "i1", + } + x = np.array([0, 1, 127, 128, 254, np.nan], dtype=np.float32) + decoded = Dataset({"x": ("t", x, {}, encoding)}) + + attributes = { + "_FillValue": fillvalue, + "_Unsigned": "true", + } + # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned + sb = np.asarray([0, 1, 127, -128, -2, -1], dtype="i1") + encoded = Dataset({"x": ("t", sb, attributes)}) + + with self.roundtrip(decoded) as actual: + for k in decoded.variables: + assert decoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(decoded, actual, decode_bytes=False) + + with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: + for k in encoded.variables: + assert encoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(encoded, actual, decode_bytes=False) + @staticmethod def _create_cf_dataset(): original = Dataset( @@ -4285,7 +4314,7 @@ def test_roundtrip_coordinates_with_space(self) -> None: def test_roundtrip_numpy_datetime_data(self) -> None: # Override method in DatasetIOBase - remove not applicable # save_kwargs - times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"]) + times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"], unit="ns") expected = Dataset({"t": ("t", times), "t0": times[0]}) with self.roundtrip(expected) as 
actual: assert_identical(expected, actual)