diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 8993c136ba6..f23125ee113 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -51,6 +51,13 @@ normalize_axis_index, ) +try: + from nested_duck_arrays import first_layer +except ImportError: + + def first_layer(x): + return type(x) + dask_available = module_available("dask") @@ -266,9 +273,9 @@ def as_shared_dtype(scalars_or_arrays, xp=None): f" array types {[x.dtype for x in scalars_or_arrays]}" ) - # Avoid calling array_type("cupy") repeatidely in the any check + # Avoid calling array_type("cupy") repeatedly in the any check array_type_cupy = array_type("cupy") - if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): + if any(issubclass(first_layer(x), array_type_cupy) for x in scalars_or_arrays): import cupy as cp xp = cp diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 0caab6e8247..1f3ab79feed 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -113,6 +113,9 @@ def _importorskip( category=DeprecationWarning, ) has_dask_expr, requires_dask_expr = _importorskip("dask_expr") +has_nested_duck_arrays, requires_nested_duck_arrays = _importorskip( + "nested_duck_arrays" +) has_bottleneck, requires_bottleneck = _importorskip("bottleneck") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 94776902c11..3f9d1cfd98e 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -5,6 +5,7 @@ import pytest import xarray as xr +from xarray.tests import requires_dask, requires_nested_duck_arrays cp = pytest.importorskip("cupy") @@ -60,3 +61,22 @@ def test_where() -> None: output = where(data < 1, 1, data).all() assert output assert isinstance(output, cp.ndarray) + + +@requires_nested_duck_arrays +@requires_dask +def test_where_dask() -> None: + import dask.array as da + import nested_duck_arrays.dask # noqa: F401 + + from xarray.core.duck_array_ops import where + + data = cp.zeros(10) + chunked = da.from_array(data, chunks=(2,)) + + chunked_output = where(chunked < 1, 1, chunked).all() + assert isinstance(chunked_output, da.Array) + + output = chunked_output.compute() + assert output + assert isinstance(output, cp.ndarray)