Skip to content

Commit

Permalink
revert to using is_duck_array: pydata#8696
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Feb 2, 2024
1 parent 541049f commit 01c3d24
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 55 deletions.
4 changes: 2 additions & 2 deletions xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.ops import IncludeNumpySameMethods, IncludeReduceMethods
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.utils import is_duck_array


class SupportsArithmetic:
Expand Down Expand Up @@ -45,7 +45,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
out = kwargs.get("out", ())
for x in inputs + out:
if not isinstance(x, _arrayfunction_or_api) and not isinstance(
if not is_duck_array(x) and not isinstance(
x, self._HANDLED_TYPES + (SupportsArithmetic,)
):
return NotImplemented
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
broadcast_variables,
calculate_dimensions,
)
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.namedarray.pycompat import array_type, is_chunked_array
Expand All @@ -124,6 +123,7 @@
either_dict_or_kwargs,
infix_dims,
is_dict_like,
is_duck_array,
is_duck_dask_array,
)
from xarray.plot.accessor import DatasetPlotAccessor
Expand Down Expand Up @@ -2746,7 +2746,7 @@ def _validate_indexers(
elif isinstance(v, Sequence) and len(v) == 0:
yield k, np.empty((0,), dtype="int64")
else:
if not isinstance(v, _arrayfunction_or_api):
if not is_duck_array(v):
v = np.asarray(v)

if v.dtype.kind in "US":
Expand Down
5 changes: 2 additions & 3 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@
from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
from xarray.namedarray import pycompat
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.namedarray.pycompat import array_type
from xarray.namedarray.utils import is_duck_dask_array, module_available
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array, module_available

# remove once numpy 2.0 is the oldest supported version
if module_available("numpy", minversion="2.0.0.dev0"):
Expand Down Expand Up @@ -216,7 +215,7 @@ def astype(data, dtype, **kwargs):


def asarray(data, xp=np):
return data if isinstance(data, _arrayfunction_or_api) else xp.asarray(data)
return data if is_duck_array(data) else xp.asarray(data)


def as_shared_dtype(scalars_or_arrays, xp=np):
Expand Down
8 changes: 3 additions & 5 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy
from xarray.namedarray.utils import is_duck_array

if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
Expand Down Expand Up @@ -630,7 +630,7 @@ def short_data_repr(array):
internal_data = getattr(array, "variable", array)._data
if isinstance(array, np.ndarray):
return short_array_repr(array)
elif isinstance(internal_data, _arrayfunction_or_api):
elif is_duck_array(internal_data):
return limit_lines(repr(array.data), limit=40)
elif getattr(array, "_in_memory", None):
return short_array_repr(array)
Expand Down Expand Up @@ -791,9 +791,7 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs):
is_variable = True
except AttributeError:
# compare attribute value
if isinstance(a_mapping[k], _arrayfunction_or_api) or isinstance(
b_mapping[k], _arrayfunction_or_api
):
if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]):
compatible = array_equiv(a_mapping[k], b_mapping[k])
else:
compatible = a_mapping[k] == b_mapping[k]
Expand Down
11 changes: 7 additions & 4 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
is_scalar,
to_0d_array,
)
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.namedarray.pycompat import array_type, integer_types
from xarray.namedarray.utils import either_dict_or_kwargs, is_duck_dask_array
from xarray.namedarray.utils import (
either_dict_or_kwargs,
is_duck_array,
is_duck_dask_array,
)

if TYPE_CHECKING:
from numpy.typing import DTypeLike
Expand Down Expand Up @@ -374,7 +377,7 @@ def __init__(self, key):
k = int(k)
elif isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, _arrayfunction_or_api):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand Down Expand Up @@ -421,7 +424,7 @@ def __init__(self, key):
"Please pass a numpy array by calling ``.compute``. "
"See https://github.com/dask/dask/issues/8958."
)
elif isinstance(k, _arrayfunction_or_api):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand Down
8 changes: 2 additions & 6 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from packaging.version import Version

from xarray.namedarray import pycompat
from xarray.namedarray.utils import module_available
from xarray.namedarray.utils import is_duck_array, module_available

# remove once numpy 2.0 is the oldest supported version
if module_available("numpy", minversion="2.0.0.dev0"):
Expand All @@ -27,7 +27,6 @@
from numpy import RankWarning # type: ignore[attr-defined,no-redef,unused-ignore]

from xarray.core.options import OPTIONS
from xarray.namedarray._typing import _arrayfunction_or_api

try:
import bottleneck as bn
Expand Down Expand Up @@ -147,10 +146,7 @@ def _advanced_indexer_subspaces(key):

non_slices = [k for k in key if not isinstance(k, slice)]
broadcasted_shape = np.broadcast_shapes(
*[
item.shape if isinstance(item, _arrayfunction_or_api) else (0,)
for item in non_slices
]
*[item.shape if is_duck_array(item) else (0,) for item in non_slices]
)
ndim = len(broadcasted_shape)
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
ensure_us_time_resolution,
maybe_coerce_to_str,
)
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, is_chunked_array
from xarray.namedarray.utils import (
either_dict_or_kwargs,
infix_dims,
is_dict_like,
is_duck_array,
is_duck_dask_array,
)

Expand Down Expand Up @@ -407,7 +407,7 @@ def data(self):
Variable.as_numpy
Variable.values
"""
if isinstance(self._data, _arrayfunction_or_api):
if is_duck_array(self._data):
return self._data
elif isinstance(self._data, indexing.ExplicitlyIndexed):
return self._data.get_duck_array()
Expand Down Expand Up @@ -632,7 +632,7 @@ def _validate_indexers(self, key):
for dim, k in zip(self.dims, key):
if not isinstance(k, BASIC_INDEXING_TYPES):
if not isinstance(k, Variable):
if not isinstance(k, _arrayfunction_or_api):
if not is_duck_array(k):
k = np.asarray(k)
if k.ndim > 1:
raise IndexError(
Expand Down Expand Up @@ -677,7 +677,7 @@ def _broadcast_indexes_outer(self, key):
if isinstance(k, Variable):
k = k.data
if not isinstance(k, BASIC_INDEXING_TYPES):
if not isinstance(k, _arrayfunction_or_api):
if not is_duck_array(k):
k = np.asarray(k)
if k.size == 0:
# Slice by empty list; numpy could not infer the dtype
Expand Down Expand Up @@ -936,7 +936,7 @@ def load(self, **kwargs):
self._data = as_compatible_data(loaded_data)
elif isinstance(self._data, ExplicitlyIndexed):
self._data = self._data.get_duck_array()
elif not isinstance(self._data, _arrayfunction_or_api):
elif not is_duck_array(self._data):
self._data = np.asarray(self._data)
return self

Expand Down
8 changes: 8 additions & 0 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
)

import numpy as np
from numpy.typing import NDArray

try:
from dask.array.core import Array as DaskArray
from dask.typing import DaskCollection
except ImportError:
DaskArray = NDArray # type: ignore
DaskCollection: Any = NDArray # type: ignore


# Singleton type, as per https://github.com/python/typing/pull/240
Expand Down
14 changes: 7 additions & 7 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from xarray.namedarray.utils import is_duck_dask_array, module_available

if TYPE_CHECKING:
from xarray.core.types import DaskArray, T_Chunks, T_NormalizedChunks
from xarray.namedarray._typing import duckarray
from xarray.core.types import T_Chunks
from xarray.namedarray._typing import DaskArray, _NormalizedChunks, duckarray


dask_available = module_available("dask")
Expand All @@ -32,17 +32,17 @@ def __init__(self) -> None:
def is_chunked_array(self, data: duckarray[Any, Any]) -> bool:
return is_duck_dask_array(data)

def chunks(self, data: DaskArray) -> T_NormalizedChunks:
def chunks(self, data: DaskArray) -> _NormalizedChunks:
return data.chunks

def normalize_chunks(
self,
chunks: T_Chunks | T_NormalizedChunks,
chunks: T_Chunks | _NormalizedChunks,
shape: tuple[int, ...] | None = None,
limit: int | None = None,
dtype: np.dtype | None = None,
previous_chunks: T_NormalizedChunks | None = None,
) -> T_NormalizedChunks:
previous_chunks: _NormalizedChunks | None = None,
) -> _NormalizedChunks:
"""Called by open_dataset"""
from dask.array.core import normalize_chunks

Expand Down Expand Up @@ -220,7 +220,7 @@ def unify_chunks(
self,
*args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
**kwargs,
) -> tuple[dict[str, T_NormalizedChunks], list[DaskArray]]:
) -> tuple[dict[str, _NormalizedChunks], list[DaskArray]]:
from dask.array.core import unify_chunks

return unify_chunks(*args, **kwargs)
Expand Down
11 changes: 4 additions & 7 deletions xarray/namedarray/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from packaging.version import Version

from xarray.core.utils import is_scalar
from xarray.namedarray._typing import _arrayfunction_or_api
from xarray.namedarray.utils import is_duck_dask_array
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array

integer_types = (int, np.integer)

Expand Down Expand Up @@ -90,9 +89,7 @@ def mod_version(mod: ModType) -> Version:


def is_chunked_array(x: duckarray[Any, Any]) -> bool:
return is_duck_dask_array(x) or (
isinstance(x, _arrayfunction_or_api) and hasattr(x, "chunks")
)
return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks"))


def is_0d_dask_array(x: duckarray[Any, Any]) -> bool:
Expand Down Expand Up @@ -122,12 +119,12 @@ def to_numpy(data: duckarray[Any, Any]) -> np.ndarray[_ShapeType, _DType]:
return data


def to_duck_array(data) -> duckarray[_ShapeType, _DType]:
def to_duck_array(data: Any) -> duckarray[_ShapeType, _DType]:
from xarray.core.indexing import ExplicitlyIndexed

if isinstance(data, ExplicitlyIndexed):
return data.get_duck_array()
elif isinstance(data, _arrayfunction_or_api):
elif is_duck_array(data):
return data
else:
return np.asarray(data)
33 changes: 19 additions & 14 deletions xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
import numpy as np
from packaging.version import Version

from xarray.namedarray._typing import (
ErrorOptionsWithWarn,
_arrayfunction_or_api,
_DimsLike,
)
from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike

if TYPE_CHECKING:
if sys.version_info >= (3, 10):
Expand All @@ -23,14 +19,7 @@

from numpy.typing import NDArray

from xarray.namedarray._typing import _Dim, duckarray

try:
from dask.array.core import Array as DaskArray
from dask.typing import DaskCollection
except ImportError:
DaskArray = NDArray # type: ignore
DaskCollection: Any = NDArray # type: ignore
from xarray.namedarray._typing import DaskArray, DaskCollection, _Dim, duckarray


K = TypeVar("K")
Expand Down Expand Up @@ -76,8 +65,24 @@ def is_dask_collection(x: object) -> TypeGuard[DaskCollection]:
return False


def is_duck_array(value: Any) -> TypeGuard[duckarray[Any, Any]]:
# TODO: replace is_duck_array with runtime checks via _arrayfunction_or_api protocol on
# python 3.12 and higher (see https://github.com/pydata/xarray/issues/8696#issuecomment-1924588981)
if isinstance(value, np.ndarray):
return True
return (
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and (
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__")
)
)


def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]:
return isinstance(x, _arrayfunction_or_api) and is_dask_collection(x)
return is_duck_array(x) and is_dask_collection(x)


def to_0d_object_array(
Expand Down

0 comments on commit 01c3d24

Please sign in to comment.