Skip to content

Commit

Permalink
Improved duck array wrapping (pydata#9798)
Browse files Browse the repository at this point in the history
* lots more duck array compat, plus tests

* merge sliding_window_view

* namespaces constant

* revert dask allowed

* fix up some tests

* backwards compat sparse mask

* add as_array methods

* to_like_array helper

* only cast non-numpy

* better idxminmax approach

* fix mypy

* naming, add is_array_type

* add public doc and whats new

* update comments

* add support for chunked arrays in as_array_type

* revert array_type methods

* fix up whats new

* comment about bool_

* add jax to complete ci envs

* add pint and sparse to tests

* remove from windows

* mypy, xfail one more sparse

* add dask and a few other methods

* move whats new
  • Loading branch information
slevang authored Nov 26, 2024
1 parent a765ae0 commit caf62d3
Show file tree
Hide file tree
Showing 15 changed files with 730 additions and 103 deletions.
2 changes: 2 additions & 0 deletions ci/requirements/environment-3.13.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ dependencies:
- toolz
- typing_extensions
- zarr
- pip:
- jax # no way to get cpu-only jaxlib from conda if gpu is present
2 changes: 2 additions & 0 deletions ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ dependencies:
- toolz
- typing_extensions
- zarr
- pip:
- jax # no way to get cpu-only jaxlib from conda if gpu is present
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ v.2024.11.1 (unreleased)

New Features
~~~~~~~~~~~~
- Improved support for wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized
duck array operations throughout more xarray methods (:issue:`7848`, :pull:`9798`).
By `Sam Levang <https://github.com/slevang>`_.


Breaking changes
Expand Down
38 changes: 38 additions & 0 deletions xarray/core/array_api_compat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from xarray.namedarray.pycompat import array_type


def is_weak_scalar_type(t):
return isinstance(t, bool | int | float | complex | str | bytes)
Expand Down Expand Up @@ -42,3 +44,39 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype:
return xp.result_type(*arrays_and_dtypes)
else:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)


def get_array_namespace(*values):
    """Return the single array namespace shared by ``values``.

    Each value is mapped to a namespace: the result of calling its
    ``__array_namespace__()`` when it has that attribute, ``cupy`` for cupy
    arrays, and ``numpy`` for everything else. NumPy acts as the fallback, so
    numpy-backed values may be mixed with at most one other namespace.

    Parameters
    ----------
    *values
        Arrays or other objects to inspect.

    Returns
    -------
    The one non-numpy namespace found among ``values``, or ``numpy`` when
    there is none (including when no values are passed).

    Raises
    ------
    TypeError
        If more than one distinct non-numpy namespace is found.
    """

    def _get_single_namespace(x):
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        elif isinstance(x, array_type("cupy")):
            # cupy is fully compliant from xarray's perspective, but will not
            # expose __array_namespace__ until at least v14. Special-case it
            # for now.
            import cupy as cp

            return cp
        else:
            return np

    namespaces = {_get_single_namespace(t) for t in values}
    non_numpy = namespaces - {np}

    if len(non_numpy) > 1:
        # Sort so the error message does not depend on set iteration order.
        names = sorted(module.__name__ for module in non_numpy)
        raise TypeError(f"Mixed array types {names} are not supported.")
    elif non_numpy:
        [xp] = non_numpy
    else:
        xp = np

    return xp


def to_like_array(array, like):
    """Convert ``array`` to the array namespace of ``like`` unless that is numpy.

    Parameters
    ----------
    array
        The object to convert.
    like
        The object whose namespace (as reported by ``get_array_namespace``)
        decides whether a conversion happens.

    Returns
    -------
    ``xp.asarray(array)`` when the namespace ``xp`` of ``like`` is not numpy;
    otherwise ``array`` itself, unchanged.
    """
    # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays
    xp = get_array_namespace(like)
    if xp is not np:
        return xp.asarray(array)
    # avoid casting things like pint quantities to numpy arrays
    return array
4 changes: 2 additions & 2 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def clip(
keep_attrs = _get_keep_attrs(default=True)

return apply_ufunc(
np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
)

def get_index(self, key: Hashable) -> pd.Index:
Expand Down Expand Up @@ -1760,7 +1760,7 @@ def _full_like_variable(
**from_array_kwargs,
)
else:
data = np.full_like(other.data, fill_value, dtype=dtype)
data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype)

return Variable(dims=other.dims, data=data, attrs=other.attrs)

Expand Down
16 changes: 9 additions & 7 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from xarray.core import dtypes, duck_array_ops, utils
from xarray.core.alignment import align, deep_align
from xarray.core.array_api_compat import to_like_array
from xarray.core.common import zeros_like
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.formatting import limit_lines
Expand Down Expand Up @@ -1702,7 +1703,7 @@ def cross(
)

c = apply_ufunc(
np.cross,
duck_array_ops.cross,
a,
b,
input_core_dims=[[dim], [dim]],
Expand Down Expand Up @@ -2170,13 +2171,14 @@ def _calc_idxminmax(
chunks = dict(zip(array.dims, array.chunks, strict=True))
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
data = dask_coord[duck_array_ops.ravel(indx.data)]
res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
# we need to attach back the dim name
res.name = dim
else:
res = array[dim][(indx,)]
# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]
arr_coord = to_like_array(array[dim].data, array.data)
data = arr_coord[duck_array_ops.ravel(indx.data)]

# rebuild like the argmin/max output, and rename as the dim name
data = duck_array_ops.reshape(data, indx.shape)
res = indx.copy(data=data)
res.name = dim

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
Expand Down
10 changes: 6 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
align,
)
from xarray.core.arithmetic import DatasetArithmetic
from xarray.core.array_api_compat import to_like_array
from xarray.core.common import (
DataWithCoords,
_contains_datetime_like_objects,
Expand Down Expand Up @@ -127,7 +128,7 @@
calculate_dimensions,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.namedarray.pycompat import array_type, is_chunked_array
from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy
from xarray.plot.accessor import DatasetPlotAccessor
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

Expand Down Expand Up @@ -6620,7 +6621,7 @@ def dropna(
array = self._variables[k]
if dim in array.dims:
dims = [d for d in array.dims if d != dim]
count += np.asarray(array.count(dims))
count += to_numpy(array.count(dims).data)
size += math.prod([self.sizes[d] for d in dims])

if thresh is not None:
Expand Down Expand Up @@ -8734,16 +8735,17 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False):
coord_names.add(k)
else:
if k in self.data_vars and dim in v.dims:
coord_data = to_like_array(coord_var.data, like=v.data)
if _contains_datetime_like_objects(v):
v = datetime_to_numeric(v, datetime_unit=datetime_unit)
if cumulative:
integ = duck_array_ops.cumulative_trapezoid(
v.data, coord_var.data, axis=v.get_axis_num(dim)
v.data, coord_data, axis=v.get_axis_num(dim)
)
v_dims = v.dims
else:
integ = duck_array_ops.trapz(
v.data, coord_var.data, axis=v.get_axis_num(dim)
v.data, coord_data, axis=v.get_axis_num(dim)
)
v_dims = list(v.dims)
v_dims.remove(dim)
Expand Down
Loading

0 comments on commit caf62d3

Please sign in to comment.