Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use nested_duck_arrays to detect the innermost layer of duck arrays #9212

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@
normalize_axis_index,
)

try:
from nested_duck_arrays import first_layer
except ImportError:

def first_layer(x):
return type(x)


dask_available = module_available("dask")

Expand Down Expand Up @@ -266,9 +273,9 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
f" array types {[x.dtype for x in scalars_or_arrays]}"
)

# Avoid calling array_type("cupy") repeatidely in the any check
# Avoid calling array_type("cupy") repeatedly in the any check
array_type_cupy = array_type("cupy")
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
if any(issubclass(first_layer(x), array_type_cupy) for x in scalars_or_arrays):
import cupy as cp

xp = cp
Expand Down
3 changes: 3 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def _importorskip(
category=DeprecationWarning,
)
has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
has_nested_duck_arrays, requires_nested_duck_arrays = _importorskip(
"nested_duck_arrays"
)
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
Expand Down
20 changes: 20 additions & 0 deletions xarray/tests/test_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import xarray as xr
from xarray.tests import requires_dask, requires_nested_duck_arrays

cp = pytest.importorskip("cupy")

Expand Down Expand Up @@ -60,3 +61,22 @@ def test_where() -> None:
output = where(data < 1, 1, data).all()
assert output
assert isinstance(output, cp.ndarray)


@requires_nested_duck_arrays
@requires_dask
def test_where_dask() -> None:
import dask.array as da
import nested_duck_arrays.dask # noqa: F401

from xarray.core.duck_array_ops import where

data = cp.zeros(10)
chunked = da.from_array(data, chunks=(2,))

chunked_output = where(chunked < 1, 1, chunked).all()
assert isinstance(chunked_output, da.Array)

output = chunked_output.compute()
assert output
assert isinstance(output, cp.ndarray)
Loading