From a765ae0e2b76ec8d66cfacf1e9920b6b3c2d69b7 Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Mon, 25 Nov 2024 17:55:39 +0100 Subject: [PATCH 01/23] Upgrade ruff to 0.8.0 (#9816) --- .pre-commit-config.yaml | 2 +- asv_bench/benchmarks/dataset_io.py | 2 +- pyproject.toml | 3 +- xarray/__init__.py | 6 +-- xarray/coding/cftime_offsets.py | 3 +- xarray/conventions.py | 7 +-- xarray/core/dataset.py | 8 ++- xarray/plot/utils.py | 7 ++- xarray/testing/assertions.py | 8 +-- xarray/tests/__init__.py | 3 +- xarray/tests/test_backends.py | 3 +- xarray/tests/test_dataset.py | 5 +- xarray/tests/test_formatting.py | 14 +++-- xarray/ufuncs.py | 82 +++++++++++++++--------------- 14 files changed, 70 insertions(+), 83 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e2bffbfefde..d543a36edd3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - id: text-unicode-replacement-char - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.7.2 + rev: v0.8.0 hooks: - id: ruff-format - id: ruff diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 3a09288c8dc..f1296a8b44f 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -305,7 +305,7 @@ def make_ds(self, nfiles=10): ds.attrs = {"history": "created for xarray benchmarking"} self.ds_list.append(ds) - self.filenames_list.append("test_netcdf_%i.nc" % i) + self.filenames_list.append(f"test_netcdf_{i}.nc") class IOWriteMultipleNetCDF3(IOMultipleNetCDF): diff --git a/pyproject.toml b/pyproject.toml index dab280f0eba..3ac1a024195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dev = [ "pytest-env", "pytest-xdist", "pytest-timeout", - "ruff", + "ruff>=0.8.0", "sphinx", "sphinx_autosummary_accessors", "xarray[complete]", @@ -256,7 +256,6 @@ ignore = [ "E501", # line too long - let the formatter worry about that "E731", # do not assign a lambda expression, use a def "UP007", # use X | Y for type annotations - "UP027", # deprecated "C40", # unnecessary generator, comprehension, or literal "PIE790", # unnecessary pass statement "PERF203", # try-except within a loop incurs performance overhead diff --git a/xarray/__init__.py b/xarray/__init__.py index 634f67a61a2..622c927b468 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -64,7 +64,7 @@ # A hardcoded __all__ variable is necessary to appease # `mypy --strict` running in projects that import xarray. 
-__all__ = ( +__all__ = ( # noqa: RUF022 # Sub-packages "groupers", "testing", @@ -117,8 +117,8 @@ "Context", "Coordinates", "DataArray", - "Dataset", "DataTree", + "Dataset", "Index", "IndexSelResult", "IndexVariable", @@ -131,6 +131,6 @@ "SerializationWarning", "TreeIsomorphismError", # Constants - "__version__", "ALL_DIMS", + "__version__", ) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 9677a406471..a994eb9661f 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -1451,8 +1451,7 @@ def date_range_like(source, calendar, use_cftime=None): from xarray.core.dataarray import DataArray if not isinstance(source, pd.DatetimeIndex | CFTimeIndex) and ( - isinstance(source, DataArray) - and (source.ndim != 1) + (isinstance(source, DataArray) and (source.ndim != 1)) or not _contains_datetime_like_objects(source.variable) ): raise ValueError( diff --git a/xarray/conventions.py b/xarray/conventions.py index e4e71a481e8..5b57c160850 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -726,11 +726,8 @@ def _encode_coordinates( ) # if coordinates set to None, don't write coordinates attribute - if ( - "coordinates" in attrs - and attrs.get("coordinates") is None - or "coordinates" in encoding - and encoding.get("coordinates") is None + if ("coordinates" in attrs and attrs.get("coordinates") is None) or ( + "coordinates" in encoding and encoding.get("coordinates") is None ): # make sure "coordinates" is removed from attrs/encoding attrs.pop("coordinates", None) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e80ce5fa64a..ee6d272ad66 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5401,11 +5401,9 @@ def _get_stack_index( and var.dims[0] == dim and ( # stack: must be a single coordinate index - not multi - and not self.xindexes.is_multi(name) + (not multi and not self.xindexes.is_multi(name)) # unstack: must be an index that implements .unstack - or multi - and type(index).unstack is not Index.unstack + or (multi and type(index).unstack is not Index.unstack) ) ): if stack_index is not None and index is not stack_index: @@ -7617,7 +7615,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: if isinstance(idx, pd.MultiIndex): dims = tuple( - name if name is not None else "level_%i" % n # type: ignore[redundant-expr] + name if name is not None else f"level_{n}" # type: ignore[redundant-expr] for n, name in enumerate(idx.names) ) for dim, lev in zip(dims, idx.levels, strict=True): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 717adc41ffd..c1b5b29c7bf 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -869,11 +869,11 @@ def _infer_interval_breaks(coord, axis=0, scale=None, check_monotonic=False): if check_monotonic and not _is_monotonic(coord, axis=axis): raise ValueError( "The input coordinate is not sorted in increasing " - "order along axis %d. This can lead to unexpected " + f"order along axis {axis}. This can lead to unexpected " "results. Consider calling the `sortby` method on " "the input DataArray. To plot data with categorical " "axes, consider using the `heatmap` function from " - "the `seaborn` statistical plotting library." % axis + "the `seaborn` statistical plotting library." 
) # If logscale, compute the intervals in the logarithmic space @@ -1708,8 +1708,7 @@ def _determine_guide( if ( not add_colorbar and (hueplt_norm.data is not None and hueplt_norm.data_is_numeric is False) - or sizeplt_norm.data is not None - ): + ) or sizeplt_norm.data is not None: add_legend = True else: add_legend = False diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index d2b01677ce4..026fff6ffba 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -124,8 +124,8 @@ def assert_equal(a, b, check_dim_order: bool = True): numpy.testing.assert_array_equal """ __tracebackhide__ = True - assert ( - type(a) is type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + assert type(a) is type(b) or ( + isinstance(a, Coordinates) and isinstance(b, Coordinates) ) b = maybe_transpose_dims(a, b, check_dim_order) if isinstance(a, Variable | DataArray): @@ -163,8 +163,8 @@ def assert_identical(a, b): assert_equal, assert_allclose, Dataset.equals, DataArray.equals """ __tracebackhide__ = True - assert ( - type(a) is type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + assert type(a) is type(b) or ( + isinstance(a, Coordinates) and isinstance(b, Coordinates) ) if isinstance(a, Variable): assert a.identical(b), formatting.diff_array_repr(a, b, "identical") diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 442a0e51398..3aafbfcb0c1 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -210,8 +210,7 @@ def __call__(self, dsk, keys, **kwargs): self.total_computes += 1 if self.total_computes > self.max_computes: raise RuntimeError( - "Too many computes. Total: %d > max: %d." - % (self.total_computes, self.max_computes) + f"Too many computes. Total: {self.total_computes} > max: {self.max_computes}." 
) return dask.get(dsk, keys, **kwargs) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3ca7c08eb2f..8cb26f8482c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -963,8 +963,7 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn, dtype) -> None: decoded = decoded_fn(dtype) encoded = encoded_fn(dtype) if decoded["x"].encoding["dtype"] == "u1" and not ( - self.engine == "netcdf4" - and self.file_format is None + (self.engine == "netcdf4" and self.file_format is None) or self.file_format == "NETCDF4" ): pytest.skip("uint8 data can't be written to non-NetCDF4 data") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 67d38aac0fe..d9c5aac57ab 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -388,16 +388,15 @@ def test_unicode_data(self) -> None: byteorder = "<" if sys.byteorder == "little" else ">" expected = dedent( - """\ + f"""\ Size: 12B Dimensions: (foø: 1) Coordinates: - * foø (foø) %cU3 12B %r + * foø (foø) {byteorder}U3 12B {'ba®'!r} Data variables: *empty* Attributes: å: ∑""" - % (byteorder, "ba®") ) actual = str(data) assert expected == actual diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index c7af13415c0..946d491bd61 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -295,7 +295,7 @@ def test_diff_array_repr(self) -> None: byteorder = "<" if sys.byteorder == "little" else ">" expected = dedent( - """\ + f"""\ Left and right DataArray objects are not identical Differing dimensions: (x: 2, y: 3) != (x: 2) @@ -306,8 +306,8 @@ def test_diff_array_repr(self) -> None: R array([1, 2], dtype=int64) Differing coordinates: - L * x (x) %cU1 8B 'a' 'b' - R * x (x) %cU1 8B 'a' 'c' + L * x (x) {byteorder}U1 8B 'a' 'b' + R * x (x) {byteorder}U1 8B 'a' 'c' Coordinates only on the left object: * y (y) int64 24B 1 2 3 Coordinates only on the right object: @@ -317,7 +317,6 @@ def test_diff_array_repr(self) -> None: R units: kg Attributes only on the left object: description: desc""" - % (byteorder, byteorder) ) actual = formatting.diff_array_repr(da_a, da_b, "identical") @@ -496,15 +495,15 @@ def test_diff_dataset_repr(self) -> None: byteorder = "<" if sys.byteorder == "little" else ">" expected = dedent( - """\ + f"""\ Left and right Dataset objects are not identical Differing dimensions: (x: 2, y: 3) != (x: 2) Differing coordinates: - L * x (x) %cU1 8B 'a' 'b' + L * x (x) {byteorder}U1 8B 'a' 'b' Differing variable attributes: foo: bar - R * x (x) %cU1 8B 'a' 'c' + R * x (x) {byteorder}U1 8B 'a' 'c' Differing variable attributes: source: 0 foo: baz @@ -522,7 +521,6 @@ def test_diff_dataset_repr(self) -> None: R title: newtitle Attributes only on the left object: description: desc""" - % (byteorder, byteorder) ) actual = formatting.diff_dataset_repr(ds_a, ds_b, "identical") diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index cedece4c68f..e25657216fd 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -247,70 +247,45 @@ def _dedent(doc): "absolute", "acos", "acosh", + "add", + "angle", "arccos", "arccosh", "arcsin", "arcsinh", "arctan", + "arctan2", "arctanh", "asin", "asinh", "atan", + "atan2", "atanh", + "bitwise_and", "bitwise_count", "bitwise_invert", + "bitwise_left_shift", "bitwise_not", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", "cbrt", "ceil", "conj", "conjugate", + "copysign", "cos", "cosh", "deg2rad", "degrees", + "divide", + "equal", "exp", "exp2", "expm1", "fabs", - 
"floor", - "invert", - "isfinite", - "isinf", - "isnan", - "isnat", - "log", - "log10", - "log1p", - "log2", - "logical_not", - "negative", - "positive", - "rad2deg", - "radians", - "reciprocal", - "rint", - "sign", - "signbit", - "sin", - "sinh", - "spacing", - "sqrt", - "square", - "tan", - "tanh", - "trunc", - "add", - "arctan2", - "atan2", - "bitwise_and", - "bitwise_left_shift", - "bitwise_or", - "bitwise_right_shift", - "bitwise_xor", - "copysign", - "divide", - "equal", "float_power", + "floor", "floor_divide", "fmax", "fmin", @@ -320,29 +295,54 @@ def _dedent(doc): "greater_equal", "heaviside", "hypot", + "invert", + "iscomplex", + "isfinite", + "isinf", + "isnan", + "isnat", + "isreal", "lcm", "ldexp", "left_shift", "less", "less_equal", + "log", + "log1p", + "log2", + "log10", "logaddexp", "logaddexp2", "logical_and", + "logical_not", "logical_or", "logical_xor", "maximum", "minimum", "mod", "multiply", + "negative", "nextafter", "not_equal", + "positive", "pow", "power", + "rad2deg", + "radians", + "reciprocal", "remainder", "right_shift", + "rint", + "sign", + "signbit", + "sin", + "sinh", + "spacing", + "sqrt", + "square", "subtract", + "tan", + "tanh", "true_divide", - "angle", - "isreal", - "iscomplex", + "trunc", ] From caf62d3cc37da7af4530fbe87ee189cb9e436765 Mon Sep 17 00:00:00 2001 From: Sam Levang <39069044+slevang@users.noreply.github.com> Date: Tue, 26 Nov 2024 00:31:19 +0000 Subject: [PATCH 02/23] Improved duck array wrapping (#9798) * lots more duck array compat, plus tests * merge sliding_window_view * namespaces constant * revert dask allowed * fix up some tests * backwards compat sparse mask * add as_array methods * to_like_array helper * only cast non-numpy * better idxminmax approach * fix mypy * naming, add is_array_type * add public doc and whats new * update comments * add support for chunked arrays in as_array_type * revert array_type methods * fix up whats new * comment about bool_ * add jax to complete ci envs * add pint and sparse to tests * remove from windows * mypy, xfail one more sparse * add dask and a few other methods * move whats new --- ci/requirements/environment-3.13.yml | 2 + ci/requirements/environment.yml | 2 + doc/whats-new.rst | 3 + xarray/core/array_api_compat.py | 38 ++ xarray/core/common.py | 4 +- xarray/core/computation.py | 16 +- xarray/core/dataset.py | 10 +- xarray/core/duck_array_ops.py | 140 ++++--- xarray/core/nanops.py | 2 +- xarray/core/nputils.py | 9 + xarray/core/rolling.py | 3 + xarray/core/variable.py | 36 +- xarray/tests/test_duck_array_wrapping.py | 510 +++++++++++++++++++++++ xarray/tests/test_strategies.py | 19 +- xarray/tests/test_variable.py | 39 +- 15 files changed, 730 insertions(+), 103 deletions(-) create mode 100644 xarray/tests/test_duck_array_wrapping.py diff --git a/ci/requirements/environment-3.13.yml b/ci/requirements/environment-3.13.yml index dbb446f4454..937cb013711 100644 --- a/ci/requirements/environment-3.13.yml +++ b/ci/requirements/environment-3.13.yml @@ -47,3 +47,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 43938880592..364ae03666f 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -49,3 +49,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 
0da34df2c1a..906fd0a25b2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v.2024.11.1 (unreleased) New Features ~~~~~~~~~~~~ +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). + By `Sam Levang <https://github.com/slevang>`_. Breaking changes diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index da072de5b69..e1e5d5c5bdc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -1,5 +1,7 @@ import numpy as np +from xarray.namedarray.pycompat import array_type + def is_weak_scalar_type(t): return isinstance(t, bool | int | float | complex | str | bytes) @@ -42,3 +44,39 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return xp.result_type(*arrays_and_dtypes) else: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + + +def get_array_namespace(*values): + def _get_single_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + elif isinstance(x, array_type("cupy")): + # cupy is fully compliant from xarray's perspective, but will not expose + # __array_namespace__ until at least v14. Special case it for now + import cupy as cp + + return cp + else: + return np + + namespaces = {_get_single_namespace(t) for t in values} + non_numpy = namespaces - {np} + + if len(non_numpy) > 1: + names = [module.__name__ for module in non_numpy] + raise TypeError(f"Mixed array types {names} are not supported.") + elif non_numpy: + [xp] = non_numpy + else: + xp = np + + return xp + + +def to_like_array(array, like): + # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays + xp = get_array_namespace(like) + if xp is not np: + return xp.asarray(array) + # avoid casting things like pint quantities to numpy arrays + return array diff --git a/xarray/core/common.py b/xarray/core/common.py index 6f788f408d0..32135996d3c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -496,7 +496,7 @@ def clip( keep_attrs = _get_keep_attrs(default=True) return apply_ufunc( - np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" ) def get_index(self, key: Hashable) -> pd.Index: @@ -1760,7 +1760,7 @@ def _full_like_variable( **from_array_kwargs, ) else: - data = np.full_like(other.data, fill_value, dtype=dtype) + data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b15ed7f3f34..6e233425e95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,6 +24,7 @@ from xarray.core import dtypes, duck_array_ops, utils from xarray.core.alignment import align, deep_align +from xarray.core.array_api_compat import to_like_array from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.formatting import limit_lines @@ -1702,7 +1703,7 @@ def cross( ) c = apply_ufunc( - np.cross, + duck_array_ops.cross, a, b, input_core_dims=[[dim], [dim]], @@ -2170,13 +2171,14 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks, strict=True)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) data = dask_coord[duck_array_ops.ravel(indx.data)] - res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) - # we need to attach back
the dim name - res.name = dim else: - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] + arr_coord = to_like_array(array[dim].data, array.data) + data = arr_coord[duck_array_ops.ravel(indx.data)] + + # rebuild like the argmin/max output, and rename as the dim name + data = duck_array_ops.reshape(data, indx.shape) + res = indx.copy(data=data) + res.name = dim if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ee6d272ad66..f96b62f701e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -55,6 +55,7 @@ align, ) from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import ( DataWithCoords, _contains_datetime_like_objects, @@ -127,7 +128,7 @@ calculate_dimensions, ) from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import array_type, is_chunked_array +from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy from xarray.plot.accessor import DatasetPlotAccessor from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims @@ -6620,7 +6621,7 @@ def dropna( array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += to_numpy(array.count(dims).data) size += math.prod([self.sizes[d] for d in dims]) if thresh is not None: @@ -8734,16 +8735,17 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: + coord_data = to_like_array(coord_var.data, like=v.data) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: integ = duck_array_ops.cumulative_trapezoid( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = v.dims else: integ = duck_array_ops.trapz( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = list(v.dims) v_dims.remove(dim) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 0b915166279..7e7333fd8ea 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,21 +18,16 @@ import pandas as pd from numpy import all as array_all # noqa: F401 from numpy import any as array_any # noqa: F401 -from numpy import concatenate as _concatenate from numpy import ( # noqa: F401 - full_like, - gradient, isclose, - isin, isnat, take, - tensordot, - transpose, unravel_index, ) from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils +from xarray.core.array_api_compat import get_array_namespace from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -52,28 +47,6 @@ dask_available = module_available("dask") -def get_array_namespace(*values): - def _get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - namespaces = {_get_array_namespace(t) for t in values} - non_numpy = namespaces - {np} - - if len(non_numpy) > 1: - raise TypeError( - "cannot deal with more 
than one type supporting the array API at the same time" - ) - elif non_numpy: - [xp] = non_numpy - else: - xp = np - - return xp - - def einsum(*args, **kwargs): from xarray.core.options import OPTIONS @@ -82,7 +55,23 @@ def einsum(*args, **kwargs): return opt_einsum.contract(*args, **kwargs) else: - return np.einsum(*args, **kwargs) + xp = get_array_namespace(*args) + return xp.einsum(*args, **kwargs) + + +def tensordot(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.tensordot(*args, **kwargs) + + +def cross(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.cross(*args, **kwargs) + + +def gradient(f, *varargs, axis=None, edge_order=1): + xp = get_array_namespace(f) + return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order) def _dask_or_eager_func( @@ -131,15 +120,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" ) -# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), -# so we need to hand-code this. -sliding_window_view = _dask_or_eager_func( - "sliding_window_view", - eager_module=np.lib.stride_tricks, - dask_module=dask_array_compat, - dask_only_kwargs=("automatic_rechunk",), - numpy_only_kwargs=("subok", "writeable"), -) + +def sliding_window_view(array, window_shape, axis=None, **kwargs): + # TODO: some libraries (e.g. jax) don't have this, implement an alternative? + xp = get_array_namespace(array) + # sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), + # so we need to hand-code this. + func = _dask_or_eager_func( + "sliding_window_view", + eager_module=xp.lib.stride_tricks, + dask_module=dask_array_compat, + dask_only_kwargs=("automatic_rechunk",), + numpy_only_kwargs=("subok", "writeable"), + ) + return func(array, window_shape, axis=axis, **kwargs) def round(array): @@ -172,7 +166,9 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=bool, fill_value=False) + # bool_ is for backwards compat with numpy<2, and cupy + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + return full_like(data, dtype=dtype, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): @@ -213,11 +209,23 @@ def cumulative_trapezoid(y, x, axis): # Pad so that 'axis' has same length in result as it did in y pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] - integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) + + xp = get_array_namespace(y, x) + integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) return cumsum(integrand, axis=axis, skipna=False) +def full_like(a, fill_value, **kwargs): + xp = get_array_namespace(a) + return xp.full_like(a, fill_value, **kwargs) + + +def empty_like(a, **kwargs): + xp = get_array_namespace(a) + return xp.empty_like(a, **kwargs) + + def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) @@ -348,7 +356,8 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes""" - return np.sum(np.logical_not(isnull(data)), axis=axis) + xp = get_array_namespace(data) + return xp.sum(xp.logical_not(isnull(data)), axis=axis) def sum_where(data, axis=None, dtype=None, where=None): @@ -363,7 +372,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument 
where() with better dtype promotion rules.""" - xp = get_array_namespace(condition) + xp = get_array_namespace(condition, x, y) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) @@ -380,15 +389,25 @@ def fillna(data, other): return where(notnull(data), data, other) +def logical_not(data): + xp = get_array_namespace(data) + return xp.logical_not(data) + + +def clip(data, min=None, max=None): + xp = get_array_namespace(data) + return xp.clip(data, min, max) + + def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" - # TODO: remove the additional check once `numpy` adds `concat` to its array namespace - if hasattr(arrays[0], "__array_namespace__") and not isinstance( - arrays[0], np.ndarray - ): - xp = get_array_namespace(arrays[0]) + # TODO: `concat` is the xp compliant name, but fallback to concatenate for + # older numpy and for cupy + xp = get_array_namespace(*arrays) + if hasattr(xp, "concat"): return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) - return _concatenate(as_shared_dtype(arrays), axis=axis) + else: + return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis) def stack(arrays, axis=0): @@ -406,6 +425,26 @@ def ravel(array): return reshape(array, (-1,)) +def transpose(array, axes=None): + xp = get_array_namespace(array) + return xp.transpose(array, axes) + + +def moveaxis(array, source, destination): + xp = get_array_namespace(array) + return xp.moveaxis(array, source, destination) + + +def pad(array, pad_width, **kwargs): + xp = get_array_namespace(array) + return xp.pad(array, pad_width, **kwargs) + + +def quantile(array, q, axis=None, **kwargs): + xp = get_array_namespace(array) + return xp.quantile(array, q, axis=axis, **kwargs) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -747,6 +786,11 @@ def last(values, axis, skipna=None): return take(values, -1, axis=axis) +def isin(element, test_elements, **kwargs): + xp = get_array_namespace(element, test_elements) + return xp.isin(element, test_elements, **kwargs) + + def least_squares(lhs, rhs, rcond=None, skipna=False): """Return the coefficients and residuals of a least-squares fit.""" if is_duck_dask_array(rhs): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 7fbb63068c0..4894cf02be2 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -128,7 +128,7 @@ def nanmean(a, axis=None, dtype=None, out=None): "ignore", r"Mean of empty slice", category=RuntimeWarning ) - return np.nanmean(a, axis=axis, dtype=dtype) + return nputils.nanmean(a, axis=axis, dtype=dtype) def nanmedian(a, axis=None, out=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index bf5dfa1bc32..3211ab296e6 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ import pandas as pd from packaging.version import Version +from xarray.core.array_api_compat import get_array_namespace from xarray.core.utils import is_duck_array, module_available from xarray.namedarray import pycompat @@ -179,6 +180,11 @@ def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype") bn_func = getattr(bn, name, None) + xp = get_array_namespace(values) + if xp is not np: + func = getattr(xp, name, None) + if func is not None: + return func(values, axis=axis, **kwargs) if ( module_available("numbagg") and OPTIONS["use_numbagg"] @@ -229,6 +235,9 @@ def f(values, axis=None, **kwargs): # bottleneck does not take care dtype, min_count kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) + # bottleneck returns 
python scalars for reduction over all axes + if isinstance(result, float): + result = np.float64(result) else: result = getattr(npmodule, name)(values, axis=axis, **kwargs) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index cb16c3723ca..fde87841d32 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -708,6 +708,7 @@ def _array_reduce( ) del kwargs["dim"] + xp = duck_array_ops.get_array_namespace(self.obj.data) if ( OPTIONS["use_numbagg"] and module_available("numbagg") @@ -722,6 +723,7 @@ def _array_reduce( # TODO: we could also allow this, probably as part of a refactoring of this # module, so we can use the machinery in `self.reduce`. and self.ndim == 1 + and xp is np ): import numbagg @@ -744,6 +746,7 @@ def _array_reduce( or module_available("dask", "2024.11.0") ) and self.ndim == 1 + and xp is np ): return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9f660d0878a..07113d66b5b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,6 +19,7 @@ import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import AbstractArray from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( @@ -828,7 +829,7 @@ def __getitem__(self, key) -> Self: data = indexing.apply_indexer(indexable, indexer) if new_order: - data = np.moveaxis(data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self, dims, data) -> Self: @@ -866,12 +867,15 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - data = duck_array_ops.where(np.logical_not(mask), data, fill_value) + mask = to_like_array(mask, data) + data = duck_array_ops.where( + duck_array_ops.logical_not(mask), data, fill_value + ) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. 
mask = indexing.create_mask(indexer, self.shape) - data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) + data = duck_array_ops.broadcast_to(fill_value, getattr(mask, "shape", ())) if new_order: data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) @@ -902,7 +906,7 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] - value = np.moveaxis(value, new_order, range(len(new_order))) + value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexing.set_with_indexer(indexable, index_tuple, value) @@ -1122,7 +1126,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): dim_pad = (width, 0) if count >= 0 else (0, width) pads = [(0, 0) if d != dim else dim_pad for d in self.dims] - data = np.pad( + data = duck_array_ops.pad( duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", @@ -1268,7 +1272,7 @@ def pad( if reflect_type is not None: pad_option_kwargs["reflect_type"] = reflect_type - array = np.pad( + array = duck_array_ops.pad( duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, @@ -1557,14 +1561,16 @@ def _unstack_once( if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial( + duck_array_ops.full_like, fill_value=fill_value + ) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) - create_template = np.empty_like + create_template = duck_array_ops.empty_like else: dtype = self.dtype - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial(duck_array_ops.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1654,7 +1660,8 @@ def clip(self, min=None, max=None): """ from xarray.core.computation import apply_ufunc - return apply_ufunc(np.clip, self, min, max, dask="allowed") + xp = duck_array_ops.get_array_namespace(self.data) + return apply_ufunc(xp.clip, self, min, max, dask="allowed") def reduce( # type: ignore[override] self, @@ -1947,7 +1954,7 @@ def quantile( if skipna or (skipna is None and self.dtype.kind in "cfO"): _quantile_func = nputils.nanquantile else: - _quantile_func = np.quantile + _quantile_func = duck_array_ops.quantile if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -1961,11 +1968,14 @@ def quantile( if utils.is_scalar(dim): dim = [dim] + xp = duck_array_ops.get_array_namespace(self.data) + def _wrapper(npa, **kwargs): # move quantile axis to end. 
required for apply_ufunc - return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1) + return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) - axis = np.arange(-1, -1 * len(dim) - 1, -1) + # jax requires hashable + axis = tuple(range(-1, -1 * len(dim) - 1, -1)) kwargs = {"q": q, "axis": axis, "method": method} diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py new file mode 100644 index 00000000000..59928dce370 --- /dev/null +++ b/xarray/tests/test_duck_array_wrapping.py @@ -0,0 +1,510 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr + +# Don't run cupy in CI because it requires a GPU +NAMESPACE_ARRAYS = { + "cupy": { + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, + "xfails": {"quantile": "no nanquantile"}, + }, + "dask.array": { + "attrs": { + "array": "Array", + "constructor": "from_array", + }, + "xfails": { + "argsort": "no argsort", + "conjugate": "conj but no conjugate", + "searchsorted": "dask.array.searchsorted but no Array.searchsorted", + }, + }, + "jax.numpy": { + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, + "xfails": { + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + }, + }, + "pint": { + "attrs": { + "array": "Quantity", + "constructor": "Quantity", + }, + "xfails": { + "all": "returns a bool", + "any": "returns a bool", + "argmax": "returns an int", + "argmin": "returns an int", + "argsort": "returns an int", + "count": "returns an int", + "dot": "no tensordot", + "full_like": "should work, see: https://github.com/hgrecco/pint/pull/1669", + "idxmax": "returns the coordinate", + "idxmin": "returns the coordinate", + "isin": "returns a bool", + "isnull": "returns a bool", + "notnull": "returns a bool", + "rolling_reduce": "no dispatch for numbagg/bottleneck", + "cumulative_reduce": "no dispatch for numbagg/bottleneck", + "searchsorted": "returns an int", + "weighted": "no tensordot", + }, + }, + "sparse": { + "attrs": { + "array": "COO", + "constructor": "COO", + }, + "xfails": { + "cov": "dense output", + "corr": "no nanstd", + "cross": "no cross", + "count": "dense output", + "dot": "fails on some platforms/versions", + "isin": "no isin", + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + "coarsen_construct": "pad constant_values must be fill_value", + "coarsen_reduce": "pad constant_values must be fill_value", + "weighted": "fill_value error", + "coarsen": "pad constant_values must be fill_value", + "quantile": "no non skipping version", + "differentiate": "no gradient", + "argmax": "no nan skipping version", + "argmin": "no nan skipping version", + "idxmax": "no nan skipping version", + "idxmin": "no nan skipping version", + "median": "no nan skipping version", + "std": "no nan skipping version", + "var": "no nan skipping version", + "cumsum": "no cumsum", + "cumprod": "no cumprod", + "argsort": "no argsort", + "conjugate": "no conjugate", + "searchsorted": "no searchsorted", + "shift": "pad constant_values must be fill_value", + "pad": "pad constant_values must be fill_value", + }, + }, +} + + +class _BaseTest: + def setup_for_test(self, request, namespace): + self.namespace = namespace + self.xp = pytest.importorskip(namespace) + self.Array = getattr(self.xp, 
NAMESPACE_ARRAYS[namespace]["attrs"]["array"]) + self.constructor = getattr( + self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["constructor"] + ) + xarray_method = request.node.name.split("test_")[1].split("[")[0] + if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: + reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] + pytest.xfail(f"xfail for {self.namespace}: {reason}") + + def get_test_dataarray(self): + data = np.asarray([[1, 2, 3, np.nan, 5]]) + x = np.arange(5) + data = self.constructor(data) + return xr.DataArray( + data, + dims=["y", "x"], + coords={"y": [1], "x": x}, + name="foo", + ) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestTopLevelMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x1 = self.get_test_dataarray() + self.x2 = self.get_test_dataarray().assign_coords(x=np.arange(2, 7)) + + def test_apply_ufunc(self): + func = lambda x: x + 1 + result = xr.apply_ufunc(func, self.x1, dask="parallelized") + assert isinstance(result.data, self.Array) + + def test_align(self): + result = xr.align(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_broadcast(self): + result = xr.broadcast(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_concat(self): + result = xr.concat([self.x1, self.x2], dim="x") + assert isinstance(result.data, self.Array) + + def test_merge(self): + result = xr.merge([self.x1, self.x2], compat="override") + assert isinstance(result.foo.data, self.Array) + + def test_where(self): + x1, x2 = xr.align(self.x1, self.x2, join="inner") + result = xr.where(x1 > 2, x1, x2) + assert isinstance(result.data, self.Array) + + def test_full_like(self): + result = xr.full_like(self.x1, 0) + assert isinstance(result.data, self.Array) + + def test_cov(self): + result = xr.cov(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_corr(self): + result = xr.corr(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_cross(self): + x1, x2 = xr.align(self.x1.squeeze(), self.x2.squeeze(), join="inner") + result = xr.cross(x1, x2, dim="x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = xr.dot(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_map_blocks(self): + result = xr.map_blocks(lambda x: x + 1, self.x1) + assert isinstance(result.data, self.Array) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestDataArrayMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x = self.get_test_dataarray() + + def test_loc(self): + result = self.x.loc[{"x": slice(1, 3)}] + assert isinstance(result.data, self.Array) + + def test_isel(self): + result = self.x.isel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_sel(self): + result = self.x.sel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_squeeze(self): + result = self.x.squeeze("y") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interp uses numpy and scipy") + def test_interp(self): + # TODO: some cases could be made to work + result = self.x.interp(x=2.5) + assert isinstance(result.data, self.Array) + + def test_isnull(self): + result = self.x.isnull() + assert 
isinstance(result.data, self.Array) + + def test_notnull(self): + result = self.x.notnull() + assert isinstance(result.data, self.Array) + + def test_count(self): + result = self.x.count() + assert isinstance(result.data, self.Array) + + def test_dropna(self): + result = self.x.dropna(dim="x") + assert isinstance(result.data, self.Array) + + def test_fillna(self): + result = self.x.fillna(0) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="ffill uses bottleneck or numbagg") + def test_ffill(self): + result = self.x.ffill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="bfill uses bottleneck or numbagg") + def test_bfill(self): + result = self.x.bfill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interpolate_na uses numpy and scipy") + def test_interpolate_na(self): + result = self.x.interpolate_na() + assert isinstance(result.data, self.Array) + + def test_where(self): + result = self.x.where(self.x > 2) + assert isinstance(result.data, self.Array) + + def test_isin(self): + test_elements = self.constructor(np.asarray([1])) + result = self.x.isin(test_elements) + assert isinstance(result.data, self.Array) + + def test_groupby(self): + result = self.x.groupby("x").mean() + assert isinstance(result.data, self.Array) + + def test_groupby_bins(self): + result = self.x.groupby_bins("x", bins=[0, 2, 4, 6]).mean() + assert isinstance(result.data, self.Array) + + def test_rolling_iter(self): + result = self.x.rolling(x=3) + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_rolling_construct(self): + result = self.x.rolling(x=3).construct(x="window") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_rolling_reduce(self, skipna): + result = self.x.rolling(x=3).mean(skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rolling_exp uses numbagg") + def test_rolling_exp_reduce(self): + result = self.x.rolling_exp(x=3).mean() + assert isinstance(result.data, self.Array) + + def test_cumulative_iter(self): + result = self.x.cumulative("x") + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_cumulative_construct(self): + result = self.x.cumulative("x").construct(x="window") + assert isinstance(result.data, self.Array) + + def test_cumulative_reduce(self): + result = self.x.cumulative("x").sum() + assert isinstance(result.data, self.Array) + + def test_weighted(self): + result = self.x.weighted(self.x.fillna(0)).mean() + assert isinstance(result.data, self.Array) + + def test_coarsen_construct(self): + result = self.x.coarsen(x=2, boundary="pad").construct(x=["a", "b"]) + assert isinstance(result.data, self.Array) + + def test_coarsen_reduce(self): + result = self.x.coarsen(x=2, boundary="pad").mean() + assert isinstance(result.data, self.Array) + + def test_resample(self): + time_coord = pd.date_range("2000-01-01", periods=5) + result = self.x.assign_coords(x=time_coord).resample(x="D").mean() + assert isinstance(result.data, self.Array) + + def test_diff(self): + result = self.x.diff("x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = self.x.dot(self.x) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_quantile(self, skipna): + result = self.x.quantile(0.5, skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_differentiate(self): + # 
edge_order is not implemented in jax, and only supports passing None + edge_order = None if self.namespace == "jax.numpy" else 1 + result = self.x.differentiate("x", edge_order=edge_order) + assert isinstance(result.data, self.Array) + + def test_integrate(self): + result = self.x.integrate("x") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="polyfit uses numpy linalg") + def test_polyfit(self): + # TODO: this could work, there are just a lot of different linalg calls + result = self.x.polyfit("x", 1) + assert isinstance(result.polyfit_coefficients.data, self.Array) + + def test_map_blocks(self): + result = self.x.map_blocks(lambda x: x + 1) + assert isinstance(result.data, self.Array) + + def test_all(self): + result = self.x.all(dim="x") + assert isinstance(result.data, self.Array) + + def test_any(self): + result = self.x.any(dim="x") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmax(self, skipna): + result = self.x.argmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmin(self, skipna): + result = self.x.argmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmax(self, skipna): + result = self.x.idxmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmin(self, skipna): + result = self.x.idxmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_max(self, skipna): + result = self.x.max(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_min(self, skipna): + result = self.x.min(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_mean(self, skipna): + result = self.x.mean(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_median(self, skipna): + result = self.x.median(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_prod(self, skipna): + result = self.x.prod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_sum(self, skipna): + result = self.x.sum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_std(self, skipna): + result = self.x.std(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_var(self, skipna): + result = self.x.var(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumsum(self, skipna): + result = self.x.cumsum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumprod(self, skipna): + result = self.x.cumprod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_argsort(self): + result = self.x.argsort() + assert isinstance(result.data, self.Array) + + def test_astype(self): + 
result = self.x.astype(int) + assert isinstance(result.data, self.Array) + + def test_clip(self): + result = self.x.clip(min=2.0, max=4.0) + assert isinstance(result.data, self.Array) + + def test_conj(self): + result = self.x.conj() + assert isinstance(result.data, self.Array) + + def test_conjugate(self): + result = self.x.conjugate() + assert isinstance(result.data, self.Array) + + def test_imag(self): + result = self.x.imag + assert isinstance(result.data, self.Array) + + def test_searchsorted(self): + v = self.constructor(np.asarray([3])) + result = self.x.squeeze().searchsorted(v) + assert isinstance(result, self.Array) + + def test_round(self): + result = self.x.round() + assert isinstance(result.data, self.Array) + + def test_real(self): + result = self.x.real + assert isinstance(result.data, self.Array) + + def test_T(self): + result = self.x.T + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rank uses bottleneck") + def test_rank(self): + # TODO: scipy has rankdata, as does jax, so this can work + result = self.x.rank() + assert isinstance(result.data, self.Array) + + def test_transpose(self): + result = self.x.transpose() + assert isinstance(result.data, self.Array) + + def test_stack(self): + result = self.x.stack(z=("x", "y")) + assert isinstance(result.data, self.Array) + + def test_unstack(self): + result = self.x.stack(z=("x", "y")).unstack("z") + assert isinstance(result.data, self.Array) + + def test_shift(self): + result = self.x.shift(x=1) + assert isinstance(result.data, self.Array) + + def test_roll(self): + result = self.x.roll(x=1) + assert isinstance(result.data, self.Array) + + def test_pad(self): + result = self.x.pad(x=1) + assert isinstance(result.data, self.Array) + + def test_sortby(self): + result = self.x.sortby("x") + assert isinstance(result.data, self.Array) + + def test_broadcast_like(self): + result = self.x.broadcast_like(self.x) + assert isinstance(result.data, self.Array) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py index 798f5f732d1..48819333ca2 100644 --- a/xarray/tests/test_strategies.py +++ b/xarray/tests/test_strategies.py @@ -13,6 +13,7 @@ from hypothesis import given from hypothesis.extra.array_api import make_strategies_namespace +from xarray.core.options import set_options from xarray.core.variable import Variable from xarray.testing.strategies import ( attrs, @@ -267,14 +268,14 @@ def test_mean(self, data, var): Test that given a Variable of at least one dimension, the mean of the Variable is always equal to the mean of the underlying array. 
""" + with set_options(use_numbagg=False): + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) - # specify arbitrary reduction along at least one dimension - reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) - # create expected result (using nanmean because arrays with Nans will be generated) - reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) - expected = np.nanmean(var.data, axis=reduction_axes) - - # assert property is always satisfied - result = var.mean(dim=reduction_dims).data - npt.assert_equal(expected, result) + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9c6f50037d3..1461489e731 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1978,26 +1978,27 @@ def test_reduce_funcs(self): def test_reduce_keepdims(self): v = Variable(["x", "y"], self.d) - assert_identical( - v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) - ) - assert_identical( - v.mean(dim="x", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), - ) - assert_identical( - v.mean(dim="y", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), - ) - assert_identical( - v.mean(dim=["y", "x"], keepdims=True), - Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), - ) + with set_options(use_numbagg=False): + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + ) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) - v = Variable([], 1.0) - assert_identical( - v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) - ) + v = Variable([], 1.0) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) @requires_dask def test_reduce_keepdims_dask(self): From dafcde2d603b1ce9e80e9b02cdfd238e34f21cfa Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 26 Nov 2024 07:45:05 +0100 Subject: [PATCH 03/23] Use compute instead of load in plot (#9818) * Only compute this array, load computes in place * Add test * Update whats-new.rst * Update whats-new.rst --- doc/whats-new.rst | 3 ++- xarray/plot/dataarray_plot.py | 2 +- xarray/tests/test_plot.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 906fd0a25b2..4fb23123f4b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,7 +36,8 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Fix unintended load on datasets when calling :py:meth:`DataArray.plot.scatter` (:pull:`9818`). + By `Jimmy Westling `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index c668d78660c..cca9fe4f561 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -946,7 +946,7 @@ def newplotfunc( # Remove any nulls, .where(m, drop=True) doesn't work when m is # a dask array, so load the array to memory. # It will have to be loaded to memory at some point anyway: - darray = darray.load() + darray = darray.compute() darray = darray.where(darray.notnull(), drop=True) else: size_ = kwargs.pop("_size", linewidth) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 50254ef4198..1e07459061f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -33,6 +33,7 @@ assert_no_warnings, requires_cartopy, requires_cftime, + requires_dask, requires_matplotlib, requires_seaborn, ) @@ -3326,6 +3327,24 @@ def test_datarray_scatter( ) +@requires_dask +@requires_matplotlib +@pytest.mark.parametrize( + "plotfunc", + ["scatter"], +) +def test_dataarray_not_loading_inplace(plotfunc: str) -> None: + ds = xr.tutorial.scatter_example_dataset() + ds = ds.chunk() + + with figure_context(): + getattr(ds.A.plot, plotfunc)(x="x") + + from dask.array import Array + + assert isinstance(ds.A.data, Array) + + @requires_matplotlib def test_assert_valid_xy() -> None: ds = xr.tutorial.scatter_example_dataset() From 1317337b6c831023e337c5ca6229778d85d166bb Mon Sep 17 00:00:00 2001 From: Janukan Sivajeyan <28988453+JanukanS@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:58:45 -0700 Subject: [PATCH 04/23] Described default centre argument behaviour in rolling functions (#9819) * Described default centre argument behaviour in rolling functions. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/core/dataarray.py | 3 ++- xarray/core/dataset.py | 3 ++- xarray/core/rolling.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index eae11c0c491..f989990bbd4 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -7081,7 +7081,8 @@ def rolling( (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : bool or Mapping to int, default: False - Set the labels at the center of the window. + Set the labels at the center of the window. The default, False, + sets the labels at the right edge of the window. **window_kwargs : optional The keyword arguments form of ``dim``. One of dim or window_kwargs must be provided. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f96b62f701e..b9f932196ad 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10728,7 +10728,8 @@ def rolling( (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : bool or Mapping to int, default: False - Set the labels at the center of the window. + Set the labels at the center of the window. The default, False, + sets the labels at the right edge of the window. **window_kwargs : optional The keyword arguments form of ``dim``. One of dim or window_kwargs must be provided. 
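A minimal sketch of the default behaviour documented here: with ``center=False`` (the default) each window's label sits at the window's right edge, while ``center=True`` centers it.

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(5.0), dims="x")
    da.rolling(x=3).mean().values               # [nan, nan, 1., 2., 3.] (right edge)
    da.rolling(x=3, center=True).mean().values  # [nan, 1., 2., 3., nan] (centered)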
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index fde87841d32..6186f4dacfe 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -282,7 +282,8 @@ def __init__( (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : bool, default: False - Set the labels at the center of the window. + Set the labels at the center of the window. The default, False, + sets the labels at the right edge of the window. Returns ------- @@ -793,7 +794,8 @@ def __init__( (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : bool or mapping of hashable to bool, default: False - Set the labels at the center of the window. + Set the labels at the center of the window. The default, False, + sets the labels at the right edge of the window. Returns ------- From e0ec77ea18fe7a9767e11d7c8bd5c1123c08d723 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:06:02 -0700 Subject: [PATCH 05/23] Bump codecov/codecov-action from 5.0.2 to 5.0.7 in the actions group (#9820) Bumps the actions group with 1 update: [codecov/codecov-action](https://github.com/codecov/codecov-action). Updates `codecov/codecov-action` from 5.0.2 to 5.0.7 - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v5.0.2...v5.0.7) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 8 ++++---- .github/workflows/ci.yaml | 2 +- .github/workflows/upstream-dev-ci.yaml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 97e764b0882..91c63528741 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -123,7 +123,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v5.0.2 + uses: codecov/codecov-action@v5.0.7 with: file: mypy_report/cobertura.xml flags: mypy @@ -174,7 +174,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v5.0.2 + uses: codecov/codecov-action@v5.0.7 with: file: mypy_report/cobertura.xml flags: mypy-min @@ -230,7 +230,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v5.0.2 + uses: codecov/codecov-action@v5.0.7 with: file: pyright_report/cobertura.xml flags: pyright @@ -286,7 +286,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v5.0.2 + uses: codecov/codecov-action@v5.0.7 with: file: pyright_report/cobertura.xml flags: pyright39 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4a31a13af7d..b0996acf6fe 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -159,7 +159,7 @@ jobs: path: pytest.xml - name: Upload code coverage to Codecov - uses: 
codecov/codecov-action@v5.0.2 + uses: codecov/codecov-action@v5.0.7 with: file: ./coverage.xml flags: unittests diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 55e72bfa065..30047673187 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -140,7 +140,7 @@ jobs: run: | python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v5.0.2 + uses: codecov/codecov-action@v5.0.7 with: file: mypy_report/cobertura.xml flags: mypy From f3a65d58d6cdf578da9594f0352da88d88f4286c Mon Sep 17 00:00:00 2001 From: Bruce Merry <1963944+bmerry@users.noreply.github.com> Date: Tue, 26 Nov 2024 21:16:03 +0200 Subject: [PATCH 06/23] Fix type annotations for `get_axis_num` (GH 9822) (#9827) * Fix type annotations for `get_axis_num` (GH 9822) Explicitly annotate that a single `str` argument leads to an `int` return, overriding the match against Iterable. * Suppress mypy errors related to get_axis_num * Add PR number to whats-new.rst --- doc/whats-new.rst | 2 ++ xarray/core/common.py | 3 +++ xarray/namedarray/core.py | 3 +++ 3 files changed, 8 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4fb23123f4b..8bd57339180 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix type annotations for ``get_axis_num``. (:issue:`9822`, :pull:`9827`). + By `Bruce Merry `_. - Fix unintended load on datasets when calling :py:meth:`DataArray.plot.scatter` (:pull:`9818`). By `Jimmy Westling `_. diff --git a/xarray/core/common.py b/xarray/core/common.py index 32135996d3c..3a70c9ec585 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -213,6 +213,9 @@ def __iter__(self: Any) -> Iterator[Any]: raise TypeError("iteration over a 0-d array") return self._iter() + @overload + def get_axis_num(self, dim: str) -> int: ... # type: ignore [overload-overlap] + @overload def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 98d96c73e91..683a1266472 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -669,6 +669,9 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) + @overload + def get_axis_num(self, dim: str) -> int: ... # type: ignore [overload-overlap] + @overload def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... From 0b97969eba1a7dbc3017848359397e14036cdf53 Mon Sep 17 00:00:00 2001 From: Bruce Merry <1963944+bmerry@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:52:53 +0200 Subject: [PATCH 07/23] Test type annotations for Variable.get_axis_num (#9832) This is a regression test for GH9822. 
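For reference, a small sketch of the typing behaviour the new test pins down (illustrative only; the variable below is a made-up example, mirroring the assertions in the diff):

    import numpy as np
    from xarray import Variable

    v = Variable(["x", "y", "z"], np.zeros((2, 3, 4)))

    # A single dimension name hits the new str overload and types as a plain int ...
    axis: int = v.get_axis_num("x")  # 0

    # ... while an iterable of names still types as a tuple of ints.
    axes: tuple[int, ...] = v.get_axis_num(("z", "y"))  # (2, 1)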
---
 xarray/tests/test_variable.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 1461489e731..7dc5ef0db94 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -1604,7 +1604,7 @@ def test_squeeze(self):
         with pytest.raises(ValueError, match=r"cannot select a dimension"):
             v.squeeze("y")
 
-    def test_get_axis_num(self):
+    def test_get_axis_num(self) -> None:
         v = Variable(["x", "y", "z"], np.random.randn(2, 3, 4))
         assert v.get_axis_num("x") == 0
         assert v.get_axis_num(["x"]) == (0,)
@@ -1612,6 +1612,11 @@ def test_get_axis_num(self):
         assert v.get_axis_num(["z", "y", "x"]) == (2, 1, 0)
         with pytest.raises(ValueError, match=r"not found in array dim"):
             v.get_axis_num("foobar")
+        # Test the type annotations: mypy will complain if the inferred
+        # type is wrong
+        v.get_axis_num("x") + 0
+        v.get_axis_num(["x"]) + ()
+        v.get_axis_num(("x", "y")) + ()
 
     def test_set_dims(self):
         v = Variable(["x"], [0, 1])

From 7fd572d374df45b863c54e380323d898d060db5a Mon Sep 17 00:00:00 2001
From: Scott Huberty <52462026+scott-huberty@users.noreply.github.com>
Date: Wed, 27 Nov 2024 16:08:58 -0800
Subject: [PATCH 08/23] FIX: gracefully handle missing seaborn dependency (#9835)

xref: https://github.com/pydata/xarray/pull/9561#discussion_r1861273284

Fixes regression introduced in #9561

This specific seaborn import should be handled gracefully, as it was
before #9561.
---
 xarray/plot/utils.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index c1b5b29c7bf..c1f0cc11f54 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -143,14 +143,11 @@ def _color_palette(cmap, n_colors):
             pal = cmap(colors_i)
         except ValueError:
             # ValueError happens when mpl doesn't like a colormap, try seaborn
-            if TYPE_CHECKING:
-                import seaborn as sns
-            else:
-                sns = attempt_import("seaborn")
-
             try:
-                pal = sns.color_palette(cmap, n_colors=n_colors)
-            except ValueError:
+                from seaborn import color_palette
+
+                pal = color_palette(cmap, n_colors=n_colors)
+            except (ValueError, ImportError):
                 # or maybe we just got a single color as a string
                 cmap = ListedColormap([cmap] * n_colors)
                 pal = cmap(colors_i)
@@ -192,7 +189,10 @@ def _determine_cmap_params(
     cmap_params : dict
         Use depends on the type of the plotting function
     """
-    import matplotlib as mpl
+    if TYPE_CHECKING:
+        import matplotlib as mpl
+    else:
+        mpl = attempt_import("matplotlib")
 
     if isinstance(levels, Iterable):
         levels = sorted(levels)

From 14a544c93839b1ea4ebf882f76985375be879fb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20M=C3=BChlbauer?=
Date: Mon, 2 Dec 2024 13:06:22 +0100
Subject: [PATCH 09/23] move ensure_dtype_not_object from conventions to backends (#9828)

* move ensure_dtype_not_object from conventions to backends

* add whats-new.rst entry
---
 doc/whats-new.rst                    |   4 +-
 xarray/backends/common.py            | 111 ++++++++++++++++++++++++++-
 xarray/backends/zarr.py              |   2 +
 xarray/conventions.py                | 100 ------------------------
 xarray/tests/test_backends.py        |  16 ++++
 xarray/tests/test_backends_common.py |  15 +++-
 xarray/tests/test_conventions.py     |  19 -----
 7 files changed, 142 insertions(+), 125 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 8bd57339180..9a8154d3791 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -47,8 +47,8 @@ Documentation
 
 Internal Changes
 ~~~~~~~~~~~~~~~~
-
-
+- Move non-CF related ``ensure_dtype_not_object`` from conventions to backends 
(:pull:`9828`). + By `Kai Mühlbauer `_. .. _whats-new.2024.11.0: diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 3756de90b60..58a98598a5b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -4,29 +4,36 @@ import os import time import traceback -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Mapping, Sequence from glob import glob -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload import numpy as np +import pandas as pd +from xarray.coding import strings, variables +from xarray.coding.variables import SerializationWarning from xarray.conventions import cf_encoder from xarray.core import indexing -from xarray.core.datatree import DataTree +from xarray.core.datatree import DataTree, Variable from xarray.core.types import ReadBuffer from xarray.core.utils import ( FrozenDict, NdimSizeLenMixin, attempt_import, + emit_user_level_warning, is_remote_uri, ) from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray.utils import is_duck_dask_array if TYPE_CHECKING: from xarray.core.dataset import Dataset from xarray.core.types import NestedSequence + T_Name = Union[Hashable, None] + # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -527,6 +534,101 @@ def set_dimensions(self, variables, unlimited_dims=None): self.set_dimension(dim, length, is_unlimited) +def _infer_dtype(array, name=None): + """Given an object array with no missing values, infer its dtype from all elements.""" + if array.dtype.kind != "O": + raise TypeError("infer_type must be called on a dtype=object array") + + if array.size == 0: + return np.dtype(float) + + native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) + if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: + raise ValueError( + "unable to infer dtype on variable {!r}; object array " + "contains mixed native types: {}".format( + name, ", ".join(x.__name__ for x in native_dtypes) + ) + ) + + element = array[(0,) * array.ndim] + # We use the base types to avoid subclasses of bytes and str (which might + # not play nice with e.g. hdf5 datatypes), such as those from numpy + if isinstance(element, bytes): + return strings.create_vlen_dtype(bytes) + elif isinstance(element, str): + return strings.create_vlen_dtype(str) + + dtype = np.array(element).dtype + if dtype.kind != "O": + return dtype + + raise ValueError( + f"unable to infer dtype on variable {name!r}; xarray " + "cannot serialize arbitrary Python objects" + ) + + +def _copy_with_dtype(data, dtype: np.typing.DTypeLike): + """Create a copy of an array with the given dtype. + + We use this instead of np.array() to ensure that custom object dtypes end + up on the resulting array. + """ + result = np.empty(data.shape, dtype) + result[...] 
= data + return result + + +def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: + if var.dtype.kind == "O": + dims, data, attrs, encoding = variables.unpack_for_encoding(var) + + # leave vlen dtypes unchanged + if strings.check_vlen_dtype(data.dtype) is not None: + return var + + if is_duck_dask_array(data): + emit_user_level_warning( + f"variable {name} has data in the form of a dask array with " + "dtype=object, which means it is being loaded into memory " + "to determine a data type that can be safely stored on disk. " + "To avoid this, coerce this variable to a fixed-size dtype " + "with astype() before saving it.", + category=SerializationWarning, + ) + data = data.compute() + + missing = pd.isnull(data) + if missing.any(): + # nb. this will fail for dask.array data + non_missing_values = data[~missing] + inferred_dtype = _infer_dtype(non_missing_values, name) + + # There is no safe bit-pattern for NA in typical binary string + # formats, we so can't set a fill_value. Unfortunately, this means + # we can't distinguish between missing values and empty strings. + fill_value: bytes | str + if strings.is_bytes_dtype(inferred_dtype): + fill_value = b"" + elif strings.is_unicode_dtype(inferred_dtype): + fill_value = "" + else: + # insist on using float for numeric values + if not np.issubdtype(inferred_dtype, np.floating): + inferred_dtype = np.dtype(float) + fill_value = inferred_dtype.type(np.nan) + + data = _copy_with_dtype(data, dtype=inferred_dtype) + data[missing] = fill_value + else: + data = _copy_with_dtype(data, dtype=_infer_dtype(data, name)) + + assert data.dtype.kind != "O" or data.dtype.metadata + var = Variable(dims, data, attrs, encoding, fastpath=True) + return var + + class WritableCFDataStore(AbstractWritableDataStore): __slots__ = () @@ -534,6 +636,9 @@ def encode(self, variables, attributes): # All NetCDF files get CF encoded by default, without this attempting # to write times, for example, would fail. variables, attributes = cf_encoder(variables, attributes) + variables = { + k: ensure_dtype_not_object(v, name=k) for k, v in variables.items() + } variables = {k: self.encode_variable(v) for k, v in variables.items()} attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} return variables, attributes diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1acc0a502e6..fda99b131d8 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -19,6 +19,7 @@ _encode_variable_name, _normalize_path, datatree_from_dict_with_io_cleanup, + ensure_dtype_not_object, ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing @@ -507,6 +508,7 @@ def encode_zarr_variable(var, needs_copy=True, name=None): """ var = conventions.encode_cf_variable(var, name=name) + var = ensure_dtype_not_object(var, name=name) # zarr allows unicode, but not variable-length strings, so it's both # simpler and more compact to always encode as UTF-8 explicitly. 
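To make the relocated helper concrete, a minimal sketch of what it does to an object-dtype variable (illustrative only; it calls the function at its new backends location, as introduced by this patch):

    import numpy as np
    from xarray import Variable
    from xarray.backends.common import ensure_dtype_not_object

    # An object array holding strings plus a missing value ...
    var = Variable(("x",), np.array(["foo", None, "bar"], dtype=object))

    # ... is coerced to a variable-length string dtype, and the missing value
    # becomes "" because binary string formats have no NA bit pattern.
    encoded = ensure_dtype_not_object(var, name="x")
    print(encoded.values)  # ['foo' '' 'bar']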
diff --git a/xarray/conventions.py b/xarray/conventions.py index 5b57c160850..57407a15f51 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union import numpy as np -import pandas as pd from xarray.coding import strings, times, variables from xarray.coding.variables import SerializationWarning, pop_to @@ -50,41 +49,6 @@ T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore] -def _infer_dtype(array, name=None): - """Given an object array with no missing values, infer its dtype from all elements.""" - if array.dtype.kind != "O": - raise TypeError("infer_type must be called on a dtype=object array") - - if array.size == 0: - return np.dtype(float) - - native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) - if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: - raise ValueError( - "unable to infer dtype on variable {!r}; object array " - "contains mixed native types: {}".format( - name, ", ".join(x.__name__ for x in native_dtypes) - ) - ) - - element = array[(0,) * array.ndim] - # We use the base types to avoid subclasses of bytes and str (which might - # not play nice with e.g. hdf5 datatypes), such as those from numpy - if isinstance(element, bytes): - return strings.create_vlen_dtype(bytes) - elif isinstance(element, str): - return strings.create_vlen_dtype(str) - - dtype = np.array(element).dtype - if dtype.kind != "O": - return dtype - - raise ValueError( - f"unable to infer dtype on variable {name!r}; xarray " - "cannot serialize arbitrary Python objects" - ) - - def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: # only the pandas multi-index dimension coordinate cannot be serialized (tuple values) if isinstance(var._data, indexing.PandasMultiIndexingAdapter): @@ -99,67 +63,6 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: ) -def _copy_with_dtype(data, dtype: np.typing.DTypeLike): - """Create a copy of an array with the given dtype. - - We use this instead of np.array() to ensure that custom object dtypes end - up on the resulting array. - """ - result = np.empty(data.shape, dtype) - result[...] = data - return result - - -def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: - # TODO: move this from conventions to backends? (it's not CF related) - if var.dtype.kind == "O": - dims, data, attrs, encoding = variables.unpack_for_encoding(var) - - # leave vlen dtypes unchanged - if strings.check_vlen_dtype(data.dtype) is not None: - return var - - if is_duck_dask_array(data): - emit_user_level_warning( - f"variable {name} has data in the form of a dask array with " - "dtype=object, which means it is being loaded into memory " - "to determine a data type that can be safely stored on disk. " - "To avoid this, coerce this variable to a fixed-size dtype " - "with astype() before saving it.", - category=SerializationWarning, - ) - data = data.compute() - - missing = pd.isnull(data) - if missing.any(): - # nb. this will fail for dask.array data - non_missing_values = data[~missing] - inferred_dtype = _infer_dtype(non_missing_values, name) - - # There is no safe bit-pattern for NA in typical binary string - # formats, we so can't set a fill_value. Unfortunately, this means - # we can't distinguish between missing values and empty strings. 
- fill_value: bytes | str - if strings.is_bytes_dtype(inferred_dtype): - fill_value = b"" - elif strings.is_unicode_dtype(inferred_dtype): - fill_value = "" - else: - # insist on using float for numeric values - if not np.issubdtype(inferred_dtype, np.floating): - inferred_dtype = np.dtype(float) - fill_value = inferred_dtype.type(np.nan) - - data = _copy_with_dtype(data, dtype=inferred_dtype) - data[missing] = fill_value - else: - data = _copy_with_dtype(data, dtype=_infer_dtype(data, name)) - - assert data.dtype.kind != "O" or data.dtype.metadata - var = Variable(dims, data, attrs, encoding, fastpath=True) - return var - - def encode_cf_variable( var: Variable, needs_copy: bool = True, name: T_Name = None ) -> Variable: @@ -196,9 +99,6 @@ def encode_cf_variable( ]: var = coder.encode(var, name=name) - # TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends: - var = ensure_dtype_not_object(var, name=name) - for attr_name in CF_RELATED_DATA: pop_to(var.encoding, var.attrs, attr_name) return var diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 8cb26f8482c..7ea9239fb80 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1400,6 +1400,22 @@ def test_multiindex_not_implemented(self) -> None: with self.roundtrip(ds_reset) as actual: assert_identical(actual, ds_reset) + @requires_dask + def test_string_object_warning(self) -> None: + original = Dataset( + { + "x": ( + [ + "y", + ], + np.array(["foo", "bar"], dtype=object), + ) + } + ).chunk() + with pytest.warns(SerializationWarning, match="dask array with dtype=object"): + with self.roundtrip(original) as actual: + assert_identical(original, actual) + class NetCDFBase(CFEncodedBase): """Tests for all netCDF3 and netCDF4 backends.""" diff --git a/xarray/tests/test_backends_common.py b/xarray/tests/test_backends_common.py index c7dba36ea58..dc89ecefbfe 100644 --- a/xarray/tests/test_backends_common.py +++ b/xarray/tests/test_backends_common.py @@ -1,8 +1,9 @@ from __future__ import annotations +import numpy as np import pytest -from xarray.backends.common import robust_getitem +from xarray.backends.common import _infer_dtype, robust_getitem class DummyFailure(Exception): @@ -30,3 +31,15 @@ def test_robust_getitem() -> None: array = DummyArray(failures=3) with pytest.raises(DummyFailure): robust_getitem(array, ..., catch=DummyFailure, initial_delay=1, max_retries=2) + + +@pytest.mark.parametrize( + "data", + [ + np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object), + np.array([["x", 1], ["y", 2]], dtype="object"), + ], +) +def test_infer_dtype_error_on_mixed_types(data): + with pytest.raises(ValueError, match="unable to infer dtype on variable"): + _infer_dtype(data, "test") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 39950b4f9b8..495d760c534 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -249,13 +249,6 @@ def test_emit_coordinates_attribute_in_encoding(self) -> None: assert enc["b"].attrs.get("coordinates") == "t" assert "coordinates" not in enc["b"].encoding - @requires_dask - def test_string_object_warning(self) -> None: - original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() - with pytest.warns(SerializationWarning, match="dask array with dtype=object"): - encoded = conventions.encode_cf_variable(original) - assert_identical(original, encoded) - @requires_cftime class TestDecodeCF: @@ -593,18 +586,6 @@ def 
test_encoding_kwarg_fixed_width_string(self) -> None:
         pass
 
 
-@pytest.mark.parametrize(
-    "data",
-    [
-        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
-        np.array([["x", 1], ["y", 2]], dtype="object"),
-    ],
-)
-def test_infer_dtype_error_on_mixed_types(data):
-    with pytest.raises(ValueError, match="unable to infer dtype on variable"):
-        conventions._infer_dtype(data, "test")
-
-
 class TestDecodeCFVariableWithArrayUnits:
     def test_decode_cf_variable_with_array_units(self) -> None:
         v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})

From cc62e82b1fea8839ffab6ba7403aeab7f00f4e0e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 2 Dec 2024 11:36:25 -0800
Subject: [PATCH 10/23] Update pre-commit hooks (#9849)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.8.0 → v0.8.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.0...v0.8.1)
- [github.com/rbubley/mirrors-prettier: v3.3.3 → v3.4.1](https://github.com/rbubley/mirrors-prettier/compare/v3.3.3...v3.4.1)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .pre-commit-config.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d543a36edd3..7e22cf8ab37 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,7 +25,7 @@ repos:
       - id: text-unicode-replacement-char
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.8.0
+    rev: v0.8.1
     hooks:
       - id: ruff-format
      - id: ruff
@@ -37,7 +37,7 @@ repos:
         exclude: "generate_aggregations.py"
         additional_dependencies: ["black==24.8.0"]
   - repo: https://github.com/rbubley/mirrors-prettier
-    rev: v3.3.3
+    rev: v3.4.1
     hooks:
       - id: prettier
         args: [--cache-location=.prettier_cache/cache]

From 05f24f75d370555842d8597d6f6d8a769c156067 Mon Sep 17 00:00:00 2001
From: Janukan Sivajeyan <28988453+JanukanS@users.noreply.github.com>
Date: Tue, 3 Dec 2024 03:18:21 -0700
Subject: [PATCH 11/23] Fixed function links in dataarray.py and dataset.py (#9850)

---
 xarray/core/dataarray.py | 33 +++++++++++++++++----------------
 xarray/core/dataset.py   | 33 +++++++++++++++++----------------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index f989990bbd4..5150998aebb 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1504,8 +1504,8 @@ def isel(
 
         See Also
         --------
-        Dataset.isel
-        DataArray.sel
+        :func:`Dataset.isel <xarray.Dataset.isel>`
+        :func:`DataArray.sel <xarray.DataArray.sel>`
 
         :doc:`xarray-tutorial:intermediate/indexing/indexing`
             Tutorial material on indexing with Xarray objects
@@ -1642,8 +1642,8 @@ def sel(
 
         See Also
         --------
-        Dataset.sel
-        DataArray.isel
+        :func:`Dataset.sel <xarray.Dataset.sel>`
+        :func:`DataArray.isel <xarray.DataArray.isel>`
 
         :doc:`xarray-tutorial:intermediate/indexing/indexing`
             Tutorial material on indexing with Xarray objects
@@ -5612,8 +5612,9 @@ def map_blocks(
 
         See Also
         --------
-        dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks
-        xarray.DataArray.map_blocks
+        :func:`dask.array.map_blocks <dask.array.map_blocks>`
+        :func:`xarray.apply_ufunc <xarray.apply_ufunc>`
+        :func:`xarray.Dataset.map_blocks <xarray.Dataset.map_blocks>`
 
         :doc:`xarray-tutorial:advanced/map_blocks/map_blocks`
             Advanced Tutorial on map_blocks with dask
@@ -6905,13 +6906,13 @@ def groupby(
         :doc:`xarray-tutorial:fundamentals/03.2_groupby_with_xarray`
             Tutorial on :py:func:`~xarray.DataArray.Groupby` demonstrating reductions, 
transformation and comparison with :py:func:`~xarray.DataArray.resample`
-        DataArray.groupby_bins
-        Dataset.groupby
-        core.groupby.DataArrayGroupBy
-        DataArray.coarsen
-        pandas.DataFrame.groupby
-        Dataset.resample
-        DataArray.resample
+        :external:py:meth:`pandas.DataFrame.groupby <pandas.DataFrame.groupby>`
+        :func:`DataArray.groupby_bins <xarray.DataArray.groupby_bins>`
+        :func:`Dataset.groupby <xarray.Dataset.groupby>`
+        :func:`core.groupby.DataArrayGroupBy <xarray.core.groupby.DataArrayGroupBy>`
+        :func:`DataArray.coarsen <xarray.DataArray.coarsen>`
+        :func:`Dataset.resample <xarray.Dataset.resample>`
+        :func:`DataArray.resample <xarray.DataArray.resample>`
         """
         from xarray.core.groupby import (
             DataArrayGroupBy,
@@ -7048,7 +7049,7 @@ def weighted(self, weights: DataArray) -> DataArrayWeighted:
 
         See Also
         --------
-        Dataset.weighted
+        :func:`Dataset.weighted <xarray.Dataset.weighted>`
 
         :ref:`comput.weighted`
             User guide on weighted array reduction using :py:func:`~xarray.DataArray.weighted`
@@ -7332,8 +7333,8 @@ def coarsen(
 
         See Also
         --------
-        core.rolling.DataArrayCoarsen
-        Dataset.coarsen
+        :class:`core.rolling.DataArrayCoarsen <xarray.core.rolling.DataArrayCoarsen>`
+        :func:`Dataset.coarsen <xarray.Dataset.coarsen>`
 
         :ref:`reshape.coarsen`
             User guide describing :py:func:`~xarray.DataArray.coarsen`
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index b9f932196ad..2c1f5cfd4ac 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3082,8 +3082,8 @@ def isel(
 
         See Also
         --------
-        Dataset.sel
-        DataArray.isel
+        :func:`Dataset.sel <xarray.Dataset.sel>`
+        :func:`DataArray.isel <xarray.DataArray.isel>`
 
         :doc:`xarray-tutorial:intermediate/indexing/indexing`
             Tutorial material on indexing with Xarray objects
@@ -3236,8 +3236,8 @@ def sel(
 
         See Also
         --------
-        Dataset.isel
-        DataArray.sel
+        :func:`Dataset.isel <xarray.Dataset.isel>`
+        :func:`DataArray.sel <xarray.DataArray.sel>`
 
         :doc:`xarray-tutorial:intermediate/indexing/indexing`
             Tutorial material on indexing with Xarray objects
@@ -9020,8 +9020,9 @@ def map_blocks(
 
         See Also
         --------
-        dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks
-        xarray.DataArray.map_blocks
+        :func:`dask.array.map_blocks <dask.array.map_blocks>`
+        :func:`xarray.apply_ufunc <xarray.apply_ufunc>`
+        :func:`xarray.DataArray.map_blocks <xarray.DataArray.map_blocks>`
 
         :doc:`xarray-tutorial:advanced/map_blocks/map_blocks`
             Advanced Tutorial on map_blocks with dask
@@ -10551,13 +10552,13 @@ def groupby(
         :doc:`xarray-tutorial:fundamentals/03.2_groupby_with_xarray`
             Tutorial on :py:func:`~xarray.Dataset.Groupby` demonstrating reductions,
             transformation and comparison with :py:func:`~xarray.Dataset.resample`.
-        Dataset.groupby_bins
-        DataArray.groupby
-        core.groupby.DatasetGroupBy
-        pandas.DataFrame.groupby
-        Dataset.coarsen
-        Dataset.resample
-        DataArray.resample
+        :external:py:meth:`pandas.DataFrame.groupby <pandas.DataFrame.groupby>`
+        :func:`Dataset.groupby_bins <xarray.Dataset.groupby_bins>`
+        :func:`DataArray.groupby <xarray.DataArray.groupby>`
+        :class:`core.groupby.DatasetGroupBy`
+        :func:`Dataset.coarsen <xarray.Dataset.coarsen>`
+        :func:`Dataset.resample <xarray.Dataset.resample>`
+        :func:`DataArray.resample <xarray.DataArray.resample>`
         """
         from xarray.core.groupby import (
             DatasetGroupBy,
@@ -10695,7 +10696,7 @@ def weighted(self, weights: DataArray) -> DatasetWeighted:
 
         See Also
         --------
-        DataArray.weighted
+        :func:`DataArray.weighted <xarray.DataArray.weighted>`
 
         :ref:`comput.weighted`
             User guide on weighted array reduction using :py:func:`~xarray.Dataset.weighted`
@@ -10824,8 +10825,8 @@ def coarsen(
 
         See Also
         --------
-        core.rolling.DatasetCoarsen
-        DataArray.coarsen
+        :class:`core.rolling.DatasetCoarsen`
+        :func:`DataArray.coarsen <xarray.DataArray.coarsen>`
 
         :ref:`reshape.coarsen`
             User guide describing :py:func:`~xarray.Dataset.coarsen`

From 7445012d7255c071d39762ad4f723fd2511ea692 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 3 Dec 2024 08:41:40 -0700
Subject: [PATCH 12/23] Fix seed for random test data. (#9844)

* Fix seed for random test data.

Also switch to using default_rng instead of RandomState.

* Fixes

* one more fix.

* more fixes

* last one? 
* one more
---
 asv_bench/benchmarks/__init__.py     |  6 +--
 asv_bench/benchmarks/reindexing.py   |  2 +-
 asv_bench/benchmarks/unstacking.py   |  2 +-
 doc/user-guide/computation.rst       |  3 +-
 doc/user-guide/dask.rst              |  2 +-
 doc/user-guide/pandas.rst            |  2 +-
 xarray/tests/__init__.py             |  8 ++--
 xarray/tests/test_backends.py        |  9 +++--
 xarray/tests/test_computation.py     |  6 +--
 xarray/tests/test_dask.py            |  4 +-
 xarray/tests/test_dataarray.py       |  8 ++--
 xarray/tests/test_dataset.py         | 56 ++++++++++++++--------------
 xarray/tests/test_duck_array_ops.py  |  8 ++--
 xarray/tests/test_formatting_html.py |  2 +-
 xarray/tests/test_groupby.py         |  2 +-
 xarray/tests/test_missing.py         |  2 +-
 16 files changed, 63 insertions(+), 59 deletions(-)

diff --git a/asv_bench/benchmarks/__init__.py b/asv_bench/benchmarks/__init__.py
index 697fcb58494..4a9613ce026 100644
--- a/asv_bench/benchmarks/__init__.py
+++ b/asv_bench/benchmarks/__init__.py
@@ -30,13 +30,13 @@ def requires_sparse():
 
 
 def randn(shape, frac_nan=None, chunks=None, seed=0):
-    rng = np.random.RandomState(seed)
+    rng = np.random.default_rng(seed)
     if chunks is None:
         x = rng.standard_normal(shape)
     else:
         import dask.array as da
 
-        rng = da.random.RandomState(seed)
+        rng = da.random.default_rng(seed)
         x = rng.standard_normal(shape, chunks=chunks)
 
     if frac_nan is not None:
@@ -47,7 +47,7 @@ def randn(shape, frac_nan=None, chunks=None, seed=0):
 
 
 def randint(low, high=None, size=None, frac_minus=None, seed=0):
-    rng = np.random.RandomState(seed)
-    x = rng.randint(low, high, size)
+    rng = np.random.default_rng(seed)
+    x = rng.integers(low, high, size)
     if frac_minus is not None:
         inds = rng.choice(range(x.size), int(x.size * frac_minus))
diff --git a/asv_bench/benchmarks/reindexing.py b/asv_bench/benchmarks/reindexing.py
index 9d0767fc3b3..61e6b2213f3 100644
--- a/asv_bench/benchmarks/reindexing.py
+++ b/asv_bench/benchmarks/reindexing.py
@@ -11,7 +11,7 @@
 
 class Reindex:
     def setup(self):
-        data = np.random.RandomState(0).randn(ntime, nx, ny)
+        data = np.random.default_rng(0).random((ntime, nx, ny))
         self.ds = xr.Dataset(
             {"temperature": (("time", "x", "y"), data)},
             coords={"time": np.arange(ntime), "x": np.arange(nx), "y": np.arange(ny)},
diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py
index dc8bc3307c3..b3af5eac19c 100644
--- a/asv_bench/benchmarks/unstacking.py
+++ b/asv_bench/benchmarks/unstacking.py
@@ -8,7 +8,7 @@
 
 class Unstacking:
     def setup(self):
-        data = np.random.RandomState(0).randn(250, 500)
+        data = np.random.default_rng(0).random((250, 500))
         self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...])
         self.da_missing = self.da_full[:-1]
         self.df_missing = self.da_missing.to_pandas()
diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst
index ff12902cf56..5d7002484c2 100644
--- a/doc/user-guide/computation.rst
+++ b/doc/user-guide/computation.rst
@@ -30,7 +30,8 @@ numpy) over all array values:
 .. ipython:: python
 
     arr = xr.DataArray(
-        np.random.RandomState(0).randn(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])]
+        np.random.default_rng(0).random((2, 3)),
+        [("x", ["a", "b"]), ("y", [10, 20, 30])],
     )
     arr - 3
    abs(arr)
diff --git a/doc/user-guide/dask.rst b/doc/user-guide/dask.rst
index 3ad84133d0b..cadb7962f1c 100644
--- a/doc/user-guide/dask.rst
+++ b/doc/user-guide/dask.rst
@@ -292,7 +292,7 @@ work as a streaming operation, when run on arrays loaded from disk:
 
 .. 
ipython::
     :verbatim:
 
-    In [56]: rs = np.random.RandomState(0)
+    In [56]: rs = np.random.default_rng(0)
 
-    In [57]: array1 = xr.DataArray(rs.randn(1000, 100000), dims=["place", "time"])  # 800MB
+    In [57]: array1 = xr.DataArray(rs.standard_normal((1000, 100000)), dims=["place", "time"])  # 800MB
diff --git a/doc/user-guide/pandas.rst b/doc/user-guide/pandas.rst
index 5fe5e15fa63..3d4a6b42bc8 100644
--- a/doc/user-guide/pandas.rst
+++ b/doc/user-guide/pandas.rst
@@ -202,7 +202,7 @@ Let's take a look:
 
 .. ipython:: python
 
-    data = np.random.RandomState(0).rand(2, 3, 4)
+    data = np.random.default_rng(0).random((2, 3, 4))
     items = list("ab")
     major_axis = list("mno")
     minor_axis = pd.date_range(start="2000", periods=4, name="date")
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 3aafbfcb0c1..1f2eedcd8f0 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -298,12 +298,12 @@ def assert_allclose(a, b, check_default_indexes=True, **kwargs):
 
 
 def create_test_data(
-    seed: int | None = None,
+    seed: int = 12345,
     add_attrs: bool = True,
     dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES,
     use_extension_array: bool = False,
 ) -> Dataset:
-    rs = np.random.RandomState(seed)
+    rs = np.random.default_rng(seed)
     _vars = {
         "var1": ["dim1", "dim2"],
         "var2": ["dim1", "dim2"],
@@ -329,7 +329,7 @@ def create_test_data(
             "dim1",
             pd.Categorical(
                 rs.choice(
-                    list(string.ascii_lowercase[: rs.randint(1, 5)]),
+                    list(string.ascii_lowercase[: rs.integers(1, 5)]),
                     size=dim_sizes[0],
                 )
             ),
@@ -337,7 +337,7 @@ def create_test_data(
     if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
         numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
     else:
-        numbers_values = rs.randint(0, 3, _dims["dim3"], dtype="int64")
+        numbers_values = rs.integers(0, 3, _dims["dim3"], dtype="int64")
     obj.coords["numbers"] = ("dim3", numbers_values)
     obj.encoding = {"foo": "bar"}
     assert_writeable(obj)
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 7ea9239fb80..ff254225321 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -2868,8 +2868,11 @@ def test_append_with_new_variable(self) -> None:
 
         # check append mode for new variable
         with self.create_zarr_target() as store_target:
-            xr.concat([ds, ds_to_append], dim="time").to_zarr(
-                store_target, mode="w", **self.version_kwargs
+            combined = xr.concat([ds, ds_to_append], dim="time")
+            combined.to_zarr(store_target, mode="w", **self.version_kwargs)
+            assert_identical(
+                combined,
+                xr.open_dataset(store_target, engine="zarr", **self.version_kwargs),
             )
             ds_with_new_var.to_zarr(store_target, mode="a", **self.version_kwargs)
             combined = xr.concat([ds, ds_to_append], dim="time")
@@ -6494,7 +6497,7 @@ def test_zarr_safe_chunk_region(tmp_path):
         arr.isel(a=slice(5, -1)).chunk(a=5).to_zarr(store, region="auto", mode="r+")
 
     # Test if the code is detecting the last chunk correctly
-    data = np.random.RandomState(0).randn(2920, 25, 53)
+    data = np.random.default_rng(0).random((2920, 25, 53))
     ds = xr.Dataset({"temperature": (("time", "lat", "lon"), data)})
     chunks = {"time": 1000, "lat": 25, "lon": 53}
     ds.chunk(chunks).to_zarr(store, compute=False, mode="w")
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 4610aa62f64..fd9f6ef41ea 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1293,9 +1293,9 @@ def covariance(x, y):
         (x - x.mean(axis=-1, keepdims=True)) * (y - y.mean(axis=-1, keepdims=True))
     ).mean(axis=-1)
 
-    rs = np.random.RandomState(42)
-    array1 = da.from_array(rs.randn(4, 4), chunks=(2, 4))
-    array2 = da.from_array(rs.randn(4, 4), 
chunks=(2, 4)) + rs = np.random.default_rng(42) + array1 = da.from_array(rs.random((4, 4)), chunks=(2, 4)) + array2 = da.from_array(rs.random((4, 4)), chunks=(2, 4)) data_array_1 = xr.DataArray(array1, dims=("x", "z")) data_array_2 = xr.DataArray(array2, dims=("y", "z")) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 54ae80a1d9d..e3e12599926 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -38,7 +38,7 @@ def test_raise_if_dask_computes(): - data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) + data = da.from_array(np.random.default_rng(0).random((4, 6)), chunks=(2, 2)) with pytest.raises(RuntimeError, match=r"Too many computes"): with raise_if_dask_computes(): data.compute() @@ -77,7 +77,7 @@ def assertLazyAndAllClose(self, expected, actual): @pytest.fixture(autouse=True) def setUp(self): - self.values = np.random.RandomState(0).randn(4, 6) + self.values = np.random.default_rng(0).random((4, 6)) self.data = da.from_array(self.values, chunks=(2, 2)) self.eager_var = Variable(("x", "y"), self.values) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index c8b438948de..ea5186e59d0 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2284,7 +2284,7 @@ class NdArraySubclass(np.ndarray): assert isinstance(converted_subok.data, NdArraySubclass) def test_is_null(self) -> None: - x = np.random.RandomState(42).randn(5, 6) + x = np.random.default_rng(42).random((5, 6)) x[x < 0] = np.nan original = DataArray(x, [-np.arange(5), np.arange(6)], ["x", "y"]) expected = DataArray(pd.isnull(x), [-np.arange(5), np.arange(6)], ["x", "y"]) @@ -3528,7 +3528,7 @@ def test_from_multiindex_series_sparse(self) -> None: idx = pd.MultiIndex.from_product([np.arange(3), np.arange(5)], names=["a", "b"]) series: pd.Series = pd.Series( - np.random.RandomState(0).random(len(idx)), index=idx + np.random.default_rng(0).random(len(idx)), index=idx ).sample(n=5, random_state=3) dense = DataArray.from_series(series, sparse=False) @@ -3703,8 +3703,8 @@ def test_to_dict_with_numpy_attrs(self) -> None: assert expected_attrs == actual["attrs"] def test_to_masked_array(self) -> None: - rs = np.random.RandomState(44) - x = rs.random_sample(size=(10, 20)) + rs = np.random.default_rng(44) + x = rs.random(size=(10, 20)) x_masked = np.ma.masked_where(x < 0.5, x) da = DataArray(x_masked) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d9c5aac57ab..d92b26fcee5 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -99,7 +99,7 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: - rs = np.random.RandomState(seed) + rs = np.random.default_rng(seed) lat = [2, 1, 0] lon = [0, 1, 2] @@ -126,7 +126,7 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: ds = xr.Dataset( data_vars={ "da": xr.DataArray( - rs.rand(3, 3, nt1), + rs.random((3, 3, nt1)), coords=[lat, lon, time1], dims=["lat", "lon", "time"], ), @@ -141,7 +141,7 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: ds_to_append = xr.Dataset( data_vars={ "da": xr.DataArray( - rs.rand(3, 3, nt2), + rs.random((3, 3, nt2)), coords=[lat, lon, time2], dims=["lat", "lon", "time"], ), @@ -156,7 +156,7 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: ds_with_new_var = xr.Dataset( data_vars={ "new_var": xr.DataArray( - rs.rand(3, 3, nt1 + nt2), + rs.random((3, 3, nt1 + nt2)), coords=[lat, lon, 
time1.append(time2)], dims=["lat", "lon", "time"], ) @@ -293,14 +293,14 @@ def test_repr(self) -> None: numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: - var1 (dim1, dim2) float64 576B -1.086 0.9973 0.283 ... 0.4684 -0.8312 - var2 (dim1, dim2) float64 576B 1.162 -1.097 -2.123 ... 1.267 0.3328 - var3 (dim3, dim1) float64 640B 0.5565 -0.2121 0.4563 ... -0.2452 -0.3616 + var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 + var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 + var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 Attributes: foo: bar""".format(data["dim3"].dtype) ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) - print(actual) + assert expected == actual with set_options(display_width=100): @@ -7161,13 +7161,13 @@ def test_raise_no_warning_assert_close(ds) -> None: @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("edge_order", [1, 2]) def test_differentiate(dask, edge_order) -> None: - rs = np.random.RandomState(42) + rs = np.random.default_rng(42) coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] da = xr.DataArray( - rs.randn(8, 6), + rs.random((8, 6)), dims=["x", "y"], - coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.randn(8, 6))}, + coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.random((8, 6)))}, ) if dask and has_dask: da = da.chunk({"x": 4}) @@ -7210,7 +7210,7 @@ def test_differentiate(dask, edge_order) -> None: @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.parametrize("dask", [True, False]) def test_differentiate_datetime(dask) -> None: - rs = np.random.RandomState(42) + rs = np.random.default_rng(42) coord = np.array( [ "2004-07-13", @@ -7226,9 +7226,9 @@ def test_differentiate_datetime(dask) -> None: ) da = xr.DataArray( - rs.randn(8, 6), + rs.random((8, 6)), dims=["x", "y"], - coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.randn(8, 6))}, + coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.random((8, 6)))}, ) if dask and has_dask: da = da.chunk({"x": 4}) @@ -7260,12 +7260,12 @@ def test_differentiate_datetime(dask) -> None: @requires_cftime @pytest.mark.parametrize("dask", [True, False]) def test_differentiate_cftime(dask) -> None: - rs = np.random.RandomState(42) + rs = np.random.default_rng(42) coord = xr.cftime_range("2000", periods=8, freq="2ME") da = xr.DataArray( - rs.randn(8, 6), - coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.randn(8, 6))}, + rs.random((8, 6)), + coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.random((8, 6)))}, dims=["time", "y"], ) @@ -7289,17 +7289,17 @@ def test_differentiate_cftime(dask) -> None: @pytest.mark.parametrize("dask", [True, False]) def test_integrate(dask) -> None: - rs = np.random.RandomState(42) + rs = np.random.default_rng(42) coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] da = xr.DataArray( - rs.randn(8, 6), + rs.random((8, 6)), dims=["x", "y"], coords={ "x": coord, - "x2": (("x",), rs.randn(8)), + "x2": (("x",), rs.random(8)), "z": 3, - "x2d": (("x", "y"), rs.randn(8, 6)), + "x2d": (("x", "y"), rs.random((8, 6))), }, ) if dask and has_dask: @@ -7343,17 +7343,17 @@ def test_integrate(dask) -> None: @requires_scipy @pytest.mark.parametrize("dask", [True, False]) def test_cumulative_integrate(dask) -> None: - rs = np.random.RandomState(43) + rs = np.random.default_rng(43) coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] da = xr.DataArray( - rs.randn(8, 6), + rs.random((8, 6)), dims=["x", "y"], coords={ "x": coord, - 
"x2": (("x",), rs.randn(8)), + "x2": (("x",), rs.random(8)), "z": 3, - "x2d": (("x", "y"), rs.randn(8, 6)), + "x2d": (("x", "y"), rs.random((8, 6))), }, ) if dask and has_dask: @@ -7406,7 +7406,7 @@ def test_cumulative_integrate(dask) -> None: @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) def test_trapezoid_datetime(dask, which_datetime) -> None: - rs = np.random.RandomState(42) + rs = np.random.default_rng(42) coord: ArrayLike if which_datetime == "np": coord = np.array( @@ -7428,8 +7428,8 @@ def test_trapezoid_datetime(dask, which_datetime) -> None: coord = xr.cftime_range("2000", periods=8, freq="2D") da = xr.DataArray( - rs.randn(8, 6), - coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.randn(8, 6))}, + rs.random((8, 6)), + coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.random((8, 6)))}, dims=["time", "y"], ) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index a2f5631ce1b..e1306964757 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -341,16 +341,16 @@ def test_types(self, val1, val2, val3, null): def construct_dataarray(dim_num, dtype, contains_nan, dask): # dimnum <= 3 - rng = np.random.RandomState(0) + rng = np.random.default_rng(0) shapes = [16, 8, 4][:dim_num] dims = ("x", "y", "z")[:dim_num] if np.issubdtype(dtype, np.floating): - array = rng.randn(*shapes).astype(dtype) + array = rng.random(shapes).astype(dtype) elif np.issubdtype(dtype, np.integer): - array = rng.randint(0, 10, size=shapes).astype(dtype) + array = rng.integers(0, 10, size=shapes).astype(dtype) elif np.issubdtype(dtype, np.bool_): - array = rng.randint(0, 1, size=shapes).astype(dtype) + array = rng.integers(0, 1, size=shapes).astype(dtype) elif dtype is str: array = rng.choice(["a", "b", "c", "d"], size=shapes) else: diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index b518e7e95d9..7c9cdbeaaf5 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -11,7 +11,7 @@ @pytest.fixture def dataarray() -> xr.DataArray: - return xr.DataArray(np.random.RandomState(0).randn(4, 6)) + return xr.DataArray(np.random.default_rng(0).random((4, 6))) @pytest.fixture diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 2d5d6c4c16c..e4383dd58a9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2879,7 +2879,7 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None: ) b = xr.DataArray( - np.random.RandomState(0).randn(2, 3, 4), + np.random.default_rng(0).random((2, 3, 4)), coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})}, dims=["x", "y", "z"], ) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 58d8a9dcf5d..eb21cca0861 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -62,7 +62,7 @@ def ds(): def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False): - rs = np.random.RandomState(seed) + rs = np.random.default_rng(seed) vals = rs.normal(size=shape) if frac_nan == 1: vals[:] = np.nan From f81018567ed30ba489c3944a20c1064b89a606a5 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Dec 2024 08:41:53 -0700 Subject: [PATCH 13/23] Set `zarr_format` for `zarr.consolidate_metadata` (#9848) Avoids a few useless gets with Zarr v3 stores. 
---
 xarray/backends/zarr.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index fda99b131d8..cb3ab375c31 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -972,6 +972,7 @@ def store(
         if _zarr_v3():
             # https://github.com/zarr-developers/zarr-python/pull/2113#issuecomment-2386718323
             kwargs["path"] = self.zarr_group.name.lstrip("/")
+            kwargs["zarr_format"] = self.zarr_group.metadata.zarr_format
         zarr.consolidate_metadata(self.zarr_group.store, **kwargs)
 
     def sync(self):

From 66e13a28420edab9841fb70a46dfda72dfde5a50 Mon Sep 17 00:00:00 2001
From: Patrick Peglar
Date: Tue, 3 Dec 2024 15:50:15 +0000
Subject: [PATCH 14/23] Reference ncdata in iris conversion docs. (#9847)

---
 ci/requirements/doc.yml           |  1 +
 doc/getting-started-guide/faq.rst |  6 ++---
 doc/user-guide/io.rst             | 38 ++++++++++++++++++++++++++++---
 3 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml
index 3bf9640ec39..fc1234b787a 100644
--- a/ci/requirements/doc.yml
+++ b/ci/requirements/doc.yml
@@ -20,6 +20,7 @@ dependencies:
   - jupyter_client
   - matplotlib-base
   - nbsphinx
+  - ncdata
   - netcdf4>=1.5
   - numba
   - numpy>=2
diff --git a/doc/getting-started-guide/faq.rst b/doc/getting-started-guide/faq.rst
index b7ffd89b74a..af3b55a4086 100644
--- a/doc/getting-started-guide/faq.rst
+++ b/doc/getting-started-guide/faq.rst
@@ -173,9 +173,9 @@ integration with Cartopy_.
 
 We think the design decisions we have made for xarray (namely, basing it on
 pandas) make it a faster and more flexible data analysis tool. That said, Iris
-has some great domain specific functionality, and xarray includes
-methods for converting back and forth between xarray and Iris. See
-:py:meth:`~xarray.DataArray.to_iris` for more details.
+has some great domain specific functionality, and there are dedicated methods for
+converting back and forth between xarray and Iris. See
+:ref:`Reading and Writing Iris data <io.iris>` for more details.
 
 What other projects leverage xarray?
 ------------------------------------
diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst
index 6f0be112024..8561d37ed40 100644
--- a/doc/user-guide/io.rst
+++ b/doc/user-guide/io.rst
@@ -13,6 +13,8 @@ format (recommended).
 
     import os
 
+    import iris
+    import ncdata.iris_xarray
     import numpy as np
     import pandas as pd
     import xarray as xr
@@ -1072,8 +1074,11 @@ Iris
 
 The Iris_ tool allows easy reading of common meteorological and climate model
 formats (including GRIB and UK MetOffice PP files) into ``Cube`` objects which are in many ways very
-similar to ``DataArray`` objects, while enforcing a CF-compliant data model. If iris is
-installed, xarray can convert a ``DataArray`` into a ``Cube`` using
+similar to ``DataArray`` objects, while enforcing a CF-compliant data model.
+
+DataArray ``to_iris`` and ``from_iris``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If iris is installed, xarray can convert a ``DataArray`` into a ``Cube`` using
 :py:meth:`DataArray.to_iris`:
 
 .. ipython:: python
@@ -1095,9 +1100,36 @@ Conversely, we can create a new ``DataArray`` object from a ``Cube`` using
 
     da_cube = xr.DataArray.from_iris(cube)
     da_cube
 
+Ncdata
+~~~~~~
+Ncdata_ provides more sophisticated means of transferring data, including entire
+datasets. It uses the file saving and loading functions in both projects to provide a
+more "correct" translation between them, while keeping overhead very low and without
+using actual disk files.
 
-.. _Iris: https://scitools.org.uk/iris
+For example:
+
+.. 
ipython:: python
+    :okwarning:
+
+    ds = xr.tutorial.open_dataset("air_temperature_gradient")
+    cubes = ncdata.iris_xarray.cubes_from_xarray(ds)
+    print(cubes)
+    print(cubes[1])
+
+.. ipython:: python
+    :okwarning:
+
+    ds = ncdata.iris_xarray.cubes_to_xarray(cubes)
+    print(ds)
+
+Ncdata can also adjust file data within load and save operations, to fix data loading
+problems or provide exact save formatting without needing to modify files on disk.
+See, for example, the `ncdata usage examples`_.
+
+.. _Iris: https://scitools.org.uk/iris
+.. _Ncdata: https://ncdata.readthedocs.io/en/latest/index.html
+.. _ncdata usage examples: https://github.com/pp-mo/ncdata/tree/v0.1.2?tab=readme-ov-file#correct-a-miscoded-attribute-in-iris-input
 
 OPeNDAP
 -------

From 99ee8c6ca54057a9b994d7685f36236f2d5a69d9 Mon Sep 17 00:00:00 2001
From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com>
Date: Tue, 3 Dec 2024 20:17:45 +0100
Subject: [PATCH 15/23] Add Pyproject pre-commit hooks (#9840)

* add taplo pre-commit hook

* move TODO
https://github.com/tamasfe/taplo/issues/706

* add validate-pyproject pre-commit hook
---
 .pre-commit-config.yaml | 10 ++++++++++
 pyproject.toml          | 21 ++++++++++++++++-----
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7e22cf8ab37..8d2b2ee809e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -63,3 +63,13 @@ repos:
     rev: ebf0b5e44d67f8beaa1cd13a0d0393ea04c6058d
     hooks:
       - id: validate-cff
+  - repo: https://github.com/ComPWA/taplo-pre-commit
+    rev: v0.9.3
+    hooks:
+      - id: taplo-format
+        args: ["--option", "array_auto_collapse=false"]
+  - repo: https://github.com/abravalheri/validate-pyproject
+    rev: v0.23
+    hooks:
+      - id: validate-pyproject
+        additional_dependencies: ["validate-pyproject-schema-store[all]"]
diff --git a/pyproject.toml b/pyproject.toml
index 3ac1a024195..e6c63c4d010 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 authors = [
-  {name = "xarray Developers", email = "xarray@googlegroups.com"},
+  { name = "xarray Developers", email = "xarray@googlegroups.com" },
 ]
 classifiers = [
   "Development Status :: 5 - Production/Stable",
@@ -16,7 +16,7 @@ classifiers = [
 ]
 description = "N-D labeled arrays and datasets in Python"
 dynamic = ["version"]
-license = {text = "Apache-2.0"}
+license = { text = "Apache-2.0" }
 name = "xarray"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -50,7 +50,16 @@ dev = [
   "sphinx_autosummary_accessors",
   "xarray[complete]",
 ]
-io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"]
+io = [
+  "netCDF4",
+  "h5netcdf",
+  "scipy",
+  'pydap; python_version<"3.10"',
+  "zarr",
+  "fsspec",
+  "cftime",
+  "pooch",
+]
 etc = ["sparse"]
 parallel = ["dask[complete]"]
 viz = ["cartopy", "matplotlib", "nc-time-axis", "seaborn"]
@@ -249,7 +258,7 @@ extend-select = [
   "RUF",
 ]
 extend-safe-fixes = [
-  "TID252",  # absolute imports
+  "TID252", # absolute imports
 ]
 ignore = [
   "E402", # module level import not at top of file
@@ -327,7 +336,9 @@ filterwarnings = [
   "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable",
   "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning",
   "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core",
-  "default:::xarray.tests.test_strategies", # TODO: remove once we know how to deal with a changed signature in protocols
+
+  # TODO: remove once we know how to deal with a changed signature in protocols 
+ "default:::xarray.tests.test_strategies", ] log_cli_level = "INFO" From fab900cfe7661da4f9778f9ade181cac91bddb9a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Dec 2024 10:23:45 -0700 Subject: [PATCH 16/23] dask tests: Avoid check for non-copies, xfail pandas comparison (#9857) --- xarray/tests/test_dask.py | 1 + xarray/tests/test_variable.py | 25 +++++++++++-------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e3e12599926..068f57ed42d 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -791,6 +791,7 @@ def test_tokenize_duck_dask_array(self): class TestToDaskDataFrame: + @pytest.mark.xfail(reason="https://github.com/dask/dask/issues/11584") def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames x = np.random.randn(10) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 7dc5ef0db94..f4f353eda7d 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -278,34 +278,30 @@ def test_0d_time_data(self): @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime64_conversion(self): times = pd.date_range("2000-01-01", periods=3) - for values, preserve_source in [ - (times, True), - (times.values, True), - (times.values.astype("datetime64[s]"), False), - (times.to_pydatetime(), False), + for values in [ + times, + times.values, + times.values.astype("datetime64[s]"), + times.to_pydatetime(), ]: v = self.cls(["t"], values) assert v.dtype == np.dtype("datetime64[ns]") assert_array_equal(v.values, times.values) assert v.values.dtype == np.dtype("datetime64[ns]") - same_source = source_ndarray(v.values) is source_ndarray(values) - assert preserve_source == same_source @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_timedelta64_conversion(self): times = pd.timedelta_range(start=0, periods=3) - for values, preserve_source in [ - (times, True), - (times.values, True), - (times.values.astype("timedelta64[s]"), False), - (times.to_pytimedelta(), False), + for values in [ + times, + times.values, + times.values.astype("timedelta64[s]"), + times.to_pytimedelta(), ]: v = self.cls(["t"], values) assert v.dtype == np.dtype("timedelta64[ns]") assert_array_equal(v.values, times.values) assert v.values.dtype == np.dtype("timedelta64[ns]") - same_source = source_ndarray(v.values) is source_ndarray(values) - assert preserve_source == same_source def test_object_conversion(self): data = np.arange(5).astype(str).astype(object) @@ -2372,6 +2368,7 @@ def test_dask_rolling(self, dim, window, center): assert actual.shape == expected.shape assert_equal(actual, expected) + @pytest.mark.xfail(reason="https://github.com/dask/dask/issues/11585") def test_multiindex(self): super().test_multiindex() From eac5105470ec7fec767e5897fefdec9320689184 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 5 Dec 2024 22:44:46 +0100 Subject: [PATCH 17/23] Avoid local functions in push (#9856) * Avoid local functions in push * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian --- xarray/core/dask_array_ops.py | 71 +++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 8bf9c68b727..2dca38538e1 100644 --- 
a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import math
+from functools import partial

 from xarray.core import dtypes, nputils
@@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     return coeffs, residuals


+def _fill_with_last_one(a, b):
+    import numpy as np
+
+    # cumreduction applies the push func over all the blocks first, so
+    # the only missing part is filling the missing values using the
+    # last data of the previous chunk
+    return np.where(np.isnan(b), a, b)
+
+
+def _dtype_push(a, axis, dtype=None):
+    from xarray.core.duck_array_ops import _push
+
+    # Not sure why the blelloch algorithm forces receiving a dtype
+    return _push(a, axis=axis)
+
+
+def _reset_cumsum(a, axis, dtype=None):
+    import numpy as np
+
+    cumsum = np.cumsum(a, axis=axis)
+    reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
+    return cumsum - reset_points
+
+
+def _last_reset_cumsum(a, axis, keepdims=None):
+    import numpy as np
+
+    # Take the last cumulative sum, taking the resets into account
+    # This is useful for the blelloch method
+    return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
+
+
+def _combine_reset_cumsum(a, b, axis):
+    import numpy as np
+
+    # It is going to sum the previous result until the first
+    # non-NaN value
+    bitmask = np.cumprod(b != 0, axis=axis)
+    return np.where(bitmask, b + a, b)
+
+
 def push(array, n, axis, method="blelloch"):
     """
     Dask-aware bottleneck.push
@@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"):
     # TODO: Replace all this function
     # once https://github.com/pydata/xarray/issues/9229 being implemented

-    def _fill_with_last_one(a, b):
-        # cumreduction apply the push func over all the blocks first so,
-        # the only missing part is filling the missing values using the
-        # last data of the previous chunk
-        return np.where(np.isnan(b), a, b)
-
-    def _dtype_push(a, axis, dtype=None):
-        # Not sure why the blelloch algorithm force to receive a dtype
-        return _push(a, axis=axis)
-
     pushed_array = da.reductions.cumreduction(
         func=_dtype_push,
         binop=_fill_with_last_one,
         ident=np.nan,
         x=array,
         axis=axis,
         dtype=array.dtype,
         method=method,
         preop=nputils.nanlast,
     )

     if n is not None and 0 < n < array.shape[axis] - 1:
-
-        def _reset_cumsum(a, axis, dtype=None):
-            cumsum = np.cumsum(a, axis=axis)
-            reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
-            return cumsum - reset_points
-
-        def _last_reset_cumsum(a, axis, keepdims=None):
-            # Take the last cumulative sum taking into account the reset
-            # This is useful for blelloch method
-            return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
-
-        def _combine_reset_cumsum(a, b):
-            # It is going to sum the previous result until the first
-            # non nan value
-            bitmask = np.cumprod(b != 0, axis=axis)
-            return np.where(bitmask, b + a, b)
-
         valid_positions = da.reductions.cumreduction(
             func=_reset_cumsum,
-            binop=_combine_reset_cumsum,
+            binop=partial(_combine_reset_cumsum, axis=axis),
             ident=0,
             x=da.isnan(array, dtype=int),
             axis=axis,
From f9ed7275ce5cd15e2fc92a88ec352a16ef24fa3c Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Sun, 8 Dec 2024 22:22:49 -0800
Subject: [PATCH 18/23] Add token to codecov (#9865)

Without this, we can't upload on `main`, I think. I added a token into our Actions Secrets.
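For reference, the upload step should end up looking roughly like this
(a sketch mirroring the diff below; the action version shown is simply
whatever the workflow pins at the time):

    - name: Upload code coverage to Codecov
      uses: codecov/codecov-action@v5.0.7
      env:
        CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
      with:
        file: ./coverage.xml
        flags: unittests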
ref https://github.com/pydata/xarray/issues/9860
---
 .github/workflows/ci.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index b0996acf6fe..eb7d2c858af 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -160,6 +160,8 @@ jobs:

       - name: Upload code coverage to Codecov
         uses: codecov/codecov-action@v5.0.7
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         with:
           file: ./coverage.xml
           flags: unittests
From 2d628b6d31e782c48da11a84ce03d572b1dd9e11 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 9 Dec 2024 02:35:25 -0800
Subject: [PATCH 19/23] Bump codecov/codecov-action from 5.0.7 to 5.1.1 in the actions group (#9866)

---
 .github/workflows/ci-additional.yaml | 8 ++++----
 .github/workflows/ci.yaml | 2 +-
 .github/workflows/upstream-dev-ci.yaml | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml
index 91c63528741..84114056312 100644
--- a/.github/workflows/ci-additional.yaml
+++ b/.github/workflows/ci-additional.yaml
@@ -123,7 +123,7 @@ jobs:
           python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report

       - name: Upload mypy coverage to Codecov
-        uses: codecov/codecov-action@v5.0.7
+        uses: codecov/codecov-action@v5.1.1
         with:
           file: mypy_report/cobertura.xml
           flags: mypy
@@ -174,7 +174,7 @@ jobs:
           python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report

       - name: Upload mypy coverage to Codecov
-        uses: codecov/codecov-action@v5.0.7
+        uses: codecov/codecov-action@v5.1.1
         with:
           file: mypy_report/cobertura.xml
           flags: mypy-min
@@ -230,7 +230,7 @@ jobs:
           python -m pyright xarray/

       - name: Upload pyright coverage to Codecov
-        uses: codecov/codecov-action@v5.0.7
+        uses: codecov/codecov-action@v5.1.1
         with:
           file: pyright_report/cobertura.xml
           flags: pyright
@@ -286,7 +286,7 @@ jobs:
           python -m pyright xarray/

       - name: Upload pyright coverage to Codecov
-        uses: codecov/codecov-action@v5.0.7
+        uses: codecov/codecov-action@v5.1.1
         with:
           file: pyright_report/cobertura.xml
           flags: pyright39
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index eb7d2c858af..ad710e36247 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -159,7 +159,7 @@ jobs:
           path: pytest.xml

       - name: Upload code coverage to Codecov
-        uses: codecov/codecov-action@v5.0.7
+        uses: codecov/codecov-action@v5.1.1
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         with:
diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml
index 30047673187..6a8b8d777c4 100644
--- a/.github/workflows/upstream-dev-ci.yaml
+++ b/.github/workflows/upstream-dev-ci.yaml
@@ -140,7 +140,7 @@ jobs:
         run: |
           python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
       - name: Upload mypy coverage to Codecov
-        uses: codecov/codecov-action@v5.0.7
+        uses: codecov/codecov-action@v5.1.1
         with:
           file: mypy_report/cobertura.xml
           flags: mypy
From 96e0ff7d70c605a1505ff89a2d62b5c4138b0305 Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Mon, 9 Dec 2024 15:29:03 -0800
Subject: [PATCH 20/23] Remove deprecated behavior for non-dim positional args (#9864)

* Remove deprecated behavior for non-dim positional args

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot]
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 doc/whats-new.rst | 11 ++++++++---
 xarray/core/dataarray.py | 15 ---------------
 xarray/core/dataset.py | 17 +++++------------
 xarray/core/groupby.py | 2 --
 xarray/core/weighted.py | 7 -------
 5 files changed, 13 insertions(+), 39 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 9a8154d3791..6a08246182c 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -14,9 +14,9 @@ What's New
     np.random.seed(123456)

-.. _whats-new.2024.11.1:
+.. _whats-new.2024.12.0:

-v.2024.11.1 (unreleased)
+v.2024.12.0 (unreleased)
 ------------------------

 New Features
@@ -28,7 +28,12 @@ New Features

 Breaking changes
 ~~~~~~~~~~~~~~~~
-
+- Methods including ``dropna``, ``rank``, ``idxmax``, ``idxmin`` require
+  non-dimension arguments to be passed as keyword arguments. The previous
+  behavior, which allowed ``.idxmax('foo', 'all')``, was too easily confused with
+  ``'all'`` being a dimension. The updated equivalent is ``.idxmax('foo',
+  how='all')``. The previous behavior was deprecated in v2023.10.0.
+  By `Maximilian Roos `_.

 Deprecations
 ~~~~~~~~~~~~
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 5150998aebb..d287564cfe5 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1026,7 +1026,6 @@ def reset_coords(
         drop: Literal[True],
     ) -> Self: ...

-    @_deprecate_positional_args("v2023.10.0")
     def reset_coords(
         self,
         names: Dims = None,
@@ -1364,7 +1363,6 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
         all_variables = [self.variable] + [c.variable for c in self.coords.values()]
         return get_chunksizes(all_variables)

-    @_deprecate_positional_args("v2023.10.0")
     def chunk(
         self,
         chunks: T_ChunksFreq = {},  # noqa: B006  # {} even though it's technically unsafe, is being used intentionally here (#4667)
@@ -1835,7 +1833,6 @@ def thin(
         ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs)
         return self._from_temp_dataset(ds)

-    @_deprecate_positional_args("v2023.10.0")
     def broadcast_like(
         self,
         other: T_DataArrayOrSet,
@@ -1948,7 +1945,6 @@ def _reindex_callback(

         return da

-    @_deprecate_positional_args("v2023.10.0")
     def reindex_like(
         self,
         other: T_DataArrayOrSet,
@@ -2135,7 +2131,6 @@ def reindex_like(
             fill_value=fill_value,
         )

-    @_deprecate_positional_args("v2023.10.0")
     def reindex(
         self,
         indexers: Mapping[Any, Any] | None = None,
@@ -2960,7 +2955,6 @@ def stack(
         )
         return self._from_temp_dataset(ds)

-    @_deprecate_positional_args("v2023.10.0")
     def unstack(
         self,
         dim: Dims = None,
@@ -3385,7 +3379,6 @@ def drop_isel(
         dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs)
         return self._from_temp_dataset(dataset)

-    @_deprecate_positional_args("v2023.10.0")
     def dropna(
         self,
         dim: Hashable,
@@ -4889,7 +4882,6 @@ def _title_for_slice(self, truncate: int = 50) -> str:

         return title

-    @_deprecate_positional_args("v2023.10.0")
     def diff(
         self,
         dim: Hashable,
@@ -5198,7 +5190,6 @@ def sortby(
         ds = self._to_temp_dataset().sortby(variables, ascending=ascending)
         return self._from_temp_dataset(ds)

-    @_deprecate_positional_args("v2023.10.0")
     def quantile(
         self,
         q: ArrayLike,
@@ -5318,7 +5309,6 @@ def quantile(
         )
         return self._from_temp_dataset(ds)

-    @_deprecate_positional_args("v2023.10.0")
     def rank(
         self,
         dim: Hashable,
@@ -5897,7 +5887,6 @@ def pad(
         )
         return self._from_temp_dataset(ds)

-    @_deprecate_positional_args("v2023.10.0")
     def idxmin(
         self,
         dim: Hashable | None = None,
@@ -5995,7 +5984,6 @@ def idxmin(
             keep_attrs=keep_attrs,
         )

-    @_deprecate_positional_args("v2023.10.0")
     def idxmax(
         self,
dim: Hashable = None, @@ -6093,7 +6081,6 @@ def idxmax( keep_attrs=keep_attrs, ) - @_deprecate_positional_args("v2023.10.0") def argmin( self, dim: Dims = None, @@ -6195,7 +6182,6 @@ def argmin( else: return self._replace_maybe_drop_dims(result) - @_deprecate_positional_args("v2023.10.0") def argmax( self, dim: Dims = None, @@ -6544,7 +6530,6 @@ def curvefit( kwargs=kwargs, ) - @_deprecate_positional_args("v2023.10.0") def drop_duplicates( self, dim: Hashable | Iterable[Hashable], diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2c1f5cfd4ac..ea17a69f827 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3276,9 +3276,11 @@ def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self: subset = self[[name for name in self._variables if name not in is_chunked]] no_slices: list[list[int]] = [ - list(range(*idx.indices(self.sizes[dim]))) - if isinstance(idx, slice) - else idx + ( + list(range(*idx.indices(self.sizes[dim]))) + if isinstance(idx, slice) + else idx + ) for idx in indices ] no_slices = [idx for idx in no_slices if idx] @@ -5102,7 +5104,6 @@ def set_index( variables, coord_names=coord_names, indexes=indexes_ ) - @_deprecate_positional_args("v2023.10.0") def reset_index( self, dims_or_levels: Hashable | Sequence[Hashable], @@ -5740,7 +5741,6 @@ def _unstack_full_reindex( variables, coord_names=coord_names, indexes=indexes ) - @_deprecate_positional_args("v2023.10.0") def unstack( self, dim: Dims = None, @@ -6502,7 +6502,6 @@ def transpose( ds._variables[name] = var.transpose(*var_dims) return ds - @_deprecate_positional_args("v2023.10.0") def dropna( self, dim: Hashable, @@ -7976,7 +7975,6 @@ def _copy_attrs_from(self, other): if v in self.variables: self.variables[v].attrs = other.variables[v].attrs - @_deprecate_positional_args("v2023.10.0") def diff( self, dim: Hashable, @@ -8324,7 +8322,6 @@ def sortby( indices[key] = order if ascending else order[::-1] return aligned_self.isel(indices) - @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, @@ -8505,7 +8502,6 @@ def quantile( ) return new.assign_coords(quantile=q) - @_deprecate_positional_args("v2023.10.0") def rank( self, dim: Hashable, @@ -9476,7 +9472,6 @@ def pad( attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) - @_deprecate_positional_args("v2023.10.0") def idxmin( self, dim: Hashable | None = None, @@ -9575,7 +9570,6 @@ def idxmin( ) ) - @_deprecate_positional_args("v2023.10.0") def idxmax( self, dim: Hashable | None = None, @@ -10258,7 +10252,6 @@ def _wrapper(Y, *args, **kwargs): return result - @_deprecate_positional_args("v2023.10.0") def drop_duplicates( self, dim: Hashable | Iterable[Hashable], diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9596d19e735..ceae79031f8 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -50,7 +50,6 @@ ) from xarray.core.variable import IndexVariable, Variable from xarray.namedarray.pycompat import is_chunked_array -from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -1183,7 +1182,6 @@ def fillna(self, value: Any) -> T_Xarray: """ return ops.fillna(self, value) - @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 2c6e7d4282a..269cb49a2c1 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -11,7 +11,6 @@ from 
xarray.core.computation import apply_ufunc, dot from xarray.core.types import Dims, T_DataArray, T_Xarray from xarray.namedarray.utils import is_duck_dask_array -from xarray.util.deprecation_helpers import _deprecate_positional_args # Weighted quantile methods are a subset of the numpy supported quantile methods. QUANTILE_METHODS = Literal[ @@ -454,7 +453,6 @@ def _weighted_quantile_1d( def _implementation(self, func, dim, **kwargs): raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") - @_deprecate_positional_args("v2023.10.0") def sum_of_weights( self, dim: Dims = None, @@ -465,7 +463,6 @@ def sum_of_weights( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs ) - @_deprecate_positional_args("v2023.10.0") def sum_of_squares( self, dim: Dims = None, @@ -477,7 +474,6 @@ def sum_of_squares( self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) - @_deprecate_positional_args("v2023.10.0") def sum( self, dim: Dims = None, @@ -489,7 +485,6 @@ def sum( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) - @_deprecate_positional_args("v2023.10.0") def mean( self, dim: Dims = None, @@ -501,7 +496,6 @@ def mean( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) - @_deprecate_positional_args("v2023.10.0") def var( self, dim: Dims = None, @@ -513,7 +507,6 @@ def var( self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) - @_deprecate_positional_args("v2023.10.0") def std( self, dim: Dims = None, From 49502fcde4db6ea3da1f60ead589580cfdad5c98 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 11 Dec 2024 16:51:30 -0700 Subject: [PATCH 21/23] Fix/silence upstream tests (#9879) --- xarray/tests/test_cftimeindex.py | 1 + xarray/tests/test_formatting.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 116487e2bcf..e34714a344a 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1188,6 +1188,7 @@ def test_to_datetimeindex_feb_29(calendar): index.to_datetimeindex() +@pytest.mark.xfail(reason="fails on pandas main branch") @requires_cftime def test_multiindex(): index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 946d491bd61..9b658fa0d66 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -1020,9 +1020,10 @@ def test_display_nbytes() -> None: assert actual == expected actual = repr(xds["foo"]) - expected = """ + array_repr = repr(xds.foo.data).replace("\n ", "") + expected = f""" Size: 2kB -array([ 0, 1, 2, ..., 1197, 1198, 1199], dtype=int16) +{array_repr} Coordinates: * foo (foo) int16 2kB 0 1 2 3 4 5 6 ... 
1194 1195 1196 1197 1198 1199 """.strip() From b7e6036555e54ffdaa685ba3cc6a94cd9664c4b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Fri, 13 Dec 2024 16:06:45 +0100 Subject: [PATCH 22/23] finalize deprecation of "closed"-parameter (#9882) * finalize deprecation of "closed" to "inclusive" in date_range and cftime_range * add whats-new.rst entry * fix tests * fix test * remove stale function --- doc/whats-new.rst | 4 +- xarray/coding/cftime_offsets.py | 73 ++--------------------------- xarray/tests/test_cftime_offsets.py | 48 ------------------- 3 files changed, 8 insertions(+), 117 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6a08246182c..08e6218ca14 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,7 +37,9 @@ Breaking changes Deprecations ~~~~~~~~~~~~ - +- Finalize deprecation of ``closed`` parameters of :py:func:`cftime_range` and + :py:func:`date_range` (:pull:`9882`). + By `Kai Mühlbauer `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index a994eb9661f..89c06e56ea7 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -62,15 +62,13 @@ ) from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like from xarray.core.pdcompat import ( - NoDefault, count_not_none, nanosecond_precision_timestamp, - no_default, ) from xarray.core.utils import attempt_import, emit_user_level_warning if TYPE_CHECKING: - from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias + from xarray.core.types import InclusiveOptions, Self, TypeAlias DayOption: TypeAlias = Literal["start", "end"] @@ -943,42 +941,6 @@ def _generate_range(start, end, periods, offset): current = next_date -def _translate_closed_to_inclusive(closed): - """Follows code added in pandas #43504.""" - emit_user_level_warning( - "Following pandas, the `closed` parameter is deprecated in " - "favor of the `inclusive` parameter, and will be removed in " - "a future version of xarray.", - FutureWarning, - ) - if closed is None: - inclusive = "both" - elif closed in ("left", "right"): - inclusive = closed - else: - raise ValueError( - f"Argument `closed` must be either 'left', 'right', or None. " - f"Got {closed!r}." - ) - return inclusive - - -def _infer_inclusive( - closed: NoDefault | SideOptions, inclusive: InclusiveOptions | None -) -> InclusiveOptions: - """Follows code added in pandas #43504.""" - if closed is not no_default and inclusive is not None: - raise ValueError( - "Following pandas, deprecated argument `closed` cannot be " - "passed if argument `inclusive` is not None." - ) - if closed is not no_default: - return _translate_closed_to_inclusive(closed) - if inclusive is None: - return "both" - return inclusive - - def cftime_range( start=None, end=None, @@ -986,8 +948,7 @@ def cftime_range( freq=None, normalize=False, name=None, - closed: NoDefault | SideOptions = no_default, - inclusive: None | InclusiveOptions = None, + inclusive: InclusiveOptions = "both", calendar="standard", ) -> CFTimeIndex: """Return a fixed frequency CFTimeIndex. @@ -1006,16 +967,7 @@ def cftime_range( Normalize start/end dates to midnight before generating date range. name : str, default: None Name of the resulting index - closed : {None, "left", "right"}, default: "NO_DEFAULT" - Make the interval closed with respect to the given frequency to the - "left", "right", or both sides (None). - - .. 
deprecated:: 2023.02.0 - Following pandas, the ``closed`` parameter is deprecated in favor - of the ``inclusive`` parameter, and will be removed in a future - version of xarray. - - inclusive : {None, "both", "neither", "left", "right"}, default None + inclusive : {"both", "neither", "left", "right"}, default "both" Include boundaries; whether to set each bound as closed or open. .. versionadded:: 2023.02.0 @@ -1193,8 +1145,6 @@ def cftime_range( offset = to_offset(freq) dates = np.array(list(_generate_range(start, end, periods, offset))) - inclusive = _infer_inclusive(closed, inclusive) - if inclusive == "neither": left_closed = False right_closed = False @@ -1229,8 +1179,7 @@ def date_range( tz=None, normalize=False, name=None, - closed: NoDefault | SideOptions = no_default, - inclusive: None | InclusiveOptions = None, + inclusive: InclusiveOptions = "both", calendar="standard", use_cftime=None, ): @@ -1257,20 +1206,10 @@ def date_range( Normalize start/end dates to midnight before generating date range. name : str, default: None Name of the resulting index - closed : {None, "left", "right"}, default: "NO_DEFAULT" - Make the interval closed with respect to the given frequency to the - "left", "right", or both sides (None). - - .. deprecated:: 2023.02.0 - Following pandas, the `closed` parameter is deprecated in favor - of the `inclusive` parameter, and will be removed in a future - version of xarray. - - inclusive : {None, "both", "neither", "left", "right"}, default: None + inclusive : {"both", "neither", "left", "right"}, default: "both" Include boundaries; whether to set each bound as closed or open. .. versionadded:: 2023.02.0 - calendar : str, default: "standard" Calendar type for the datetimes. use_cftime : boolean, optional @@ -1294,8 +1233,6 @@ def date_range( if tz is not None: use_cftime = False - inclusive = _infer_inclusive(closed, inclusive) - if _is_standard_calendar(calendar) and use_cftime is not True: try: return pd.date_range( diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index f6f97108c1d..1ab6c611aac 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -1057,15 +1057,6 @@ def test_rollback(calendar, offset, initial_date_args, partial_expected_date_arg False, [(1, 1, 2), (1, 1, 3)], ), - ( - "0001-01-01", - "0001-01-04", - None, - "D", - None, - False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], - ), ( "0001-01-01", "0001-01-04", @@ -1294,13 +1285,6 @@ def test_invalid_cftime_range_inputs( cftime_range(start, end, periods, freq, inclusive=inclusive) # type: ignore[arg-type] -def test_invalid_cftime_arg() -> None: - with pytest.warns( - FutureWarning, match="Following pandas, the `closed` parameter is deprecated" - ): - cftime_range("2000", "2001", None, "YE", closed="left") - - _CALENDAR_SPECIFIC_MONTH_END_TESTS = [ ("noleap", [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), ("all_leap", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), @@ -1534,15 +1518,6 @@ def as_timedelta_not_implemented_error(): tick.as_timedelta() -@pytest.mark.parametrize("function", [cftime_range, date_range]) -def test_cftime_or_date_range_closed_and_inclusive_error(function: Callable) -> None: - if function == cftime_range and not has_cftime: - pytest.skip("requires cftime") - - with pytest.raises(ValueError, match="Following pandas, deprecated"): - function("2000", periods=3, closed=None, inclusive="right") - - @pytest.mark.parametrize("function", [cftime_range, date_range]) def 
test_cftime_or_date_range_invalid_inclusive_value(function: Callable) -> None: if function == cftime_range and not has_cftime: @@ -1552,29 +1527,6 @@ def test_cftime_or_date_range_invalid_inclusive_value(function: Callable) -> Non function("2000", periods=3, inclusive="foo") -@pytest.mark.parametrize( - "function", - [ - pytest.param(cftime_range, id="cftime", marks=requires_cftime), - pytest.param(date_range, id="date"), - ], -) -@pytest.mark.parametrize( - ("closed", "inclusive"), [(None, "both"), ("left", "left"), ("right", "right")] -) -def test_cftime_or_date_range_closed( - function: Callable, - closed: Literal["left", "right", None], - inclusive: Literal["left", "right", "both"], -) -> None: - with pytest.warns(FutureWarning, match="Following pandas"): - result_closed = function("2000-01-01", "2000-01-04", freq="D", closed=closed) - result_inclusive = function( - "2000-01-01", "2000-01-04", freq="D", inclusive=inclusive - ) - np.testing.assert_equal(result_closed.values, result_inclusive.values) - - @pytest.mark.parametrize("function", [cftime_range, date_range]) def test_cftime_or_date_range_inclusive_None(function) -> None: if function == cftime_range and not has_cftime: From f05c5ec799f144b8cc3b9355a702814be7285d8f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 13 Dec 2024 08:59:54 -0700 Subject: [PATCH 23/23] Fix upstream Zarr compatibility (#9884) Closes #9880 --- xarray/backends/zarr.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index cb3ab375c31..d7f056a209a 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1135,9 +1135,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No else: encoded_attrs[DIMENSION_KEY] = dims - encoding["exists_ok" if _zarr_v3() else "overwrite"] = ( - True if self._mode == "w" else False - ) + encoding["overwrite"] = True if self._mode == "w" else False zarr_array = self._create_new_array( name=name,