Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update __array__ signatures with copy #9529

Merged
merged 11 commits into from
Sep 25, 2024
2 changes: 1 addition & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __complex__(self: Any) -> complex:
return complex(self.values)

def __array__(
self: Any, dtype: DTypeLike | None = None, copy: bool | None = None
self: Any, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if not copy:
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from xarray.core.dataset import calculate_dimensions

if TYPE_CHECKING:
import numpy as np
import pandas as pd

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
Expand Down Expand Up @@ -737,7 +738,9 @@ def __bool__(self) -> bool:
def __iter__(self) -> Iterator[str]:
return itertools.chain(self._data_variables, self._children) # type: ignore[arg-type]

def __array__(self, dtype=None, copy=None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise TypeError(
"cannot directly convert a DataTree into a "
"numpy array. Instead, create an xarray.DataArray "
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def values(self) -> range:
def data(self) -> range:
return range(self.size)

def __array__(self) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if copy is False:
raise NotImplementedError(f"An array copy is necessary, got {copy = }.")
return np.arange(self.size)

@property
Expand Down
43 changes: 27 additions & 16 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
import pandas as pd
from packaging.version import Version

from xarray.core import duck_array_ops
from xarray.core.nputils import NumpyVIndexAdapter
Expand Down Expand Up @@ -505,9 +506,14 @@ class ExplicitlyIndexed:

__slots__ = ()

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
# Leave casting to an array up to the underlying array type.
return np.asarray(self.get_duck_array(), dtype=dtype)
if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
else:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
return self.array
Expand All @@ -520,11 +526,6 @@ def get_duck_array(self):
key = BasicIndexer((slice(None),) * self.ndim)
return self[key]

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
dcherian marked this conversation as resolved.
Show resolved Hide resolved
# This is necessary because we apply the indexing key in self.get_duck_array()
# Note this is the base class for all lazy indexing classes
return np.asarray(self.get_duck_array(), dtype=dtype)

def _oindex_get(self, indexer: OuterIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
Expand Down Expand Up @@ -570,8 +571,13 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer):
self.array = as_indexable(array)
self.indexer_cls = indexer_cls

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
else:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
return self.array.get_duck_array()
Expand Down Expand Up @@ -830,9 +836,6 @@ def __init__(self, array):
def _ensure_cached(self):
self.array = as_indexable(self.array.get_duck_array())

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
self._ensure_cached()
return self.array.get_duck_array()
Expand Down Expand Up @@ -1674,15 +1677,21 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
def dtype(self) -> np.dtype:
return self._dtype

def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if dtype is None:
dtype = self.dtype
array = self.array
if isinstance(array, pd.PeriodIndex):
with suppress(AttributeError):
# this might not be public API
array = array.astype("object")
return np.asarray(array.values, dtype=dtype)

if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(array.values, dtype=dtype, copy=copy)
else:
return np.asarray(array.values, dtype=dtype)

def get_duck_array(self) -> np.ndarray:
return np.asarray(self)
Expand Down Expand Up @@ -1831,15 +1840,17 @@ def __init__(
super().__init__(array, dtype)
self.level = level

def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if dtype is None:
dtype = self.dtype
if self.level is not None:
return np.asarray(
self.array.get_level_values(self.level).values, dtype=dtype
)
else:
return super().__array__(dtype)
return super().__array__(dtype, copy=copy)

def _convert_scalar(self, item):
if isinstance(item, tuple) and self.level is not None:
Expand Down
6 changes: 3 additions & 3 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def __getitem__(

@overload
def __array__(
self, dtype: None = ..., /, *, copy: None | bool = ...
self, dtype: None = ..., /, *, copy: bool | None = ...
) -> np.ndarray[Any, _DType_co]: ...
@overload
def __array__(
self, dtype: _DType, /, *, copy: None | bool = ...
self, dtype: _DType, /, *, copy: bool | None = ...
) -> np.ndarray[Any, _DType]: ...

def __array__(
self, dtype: _DType | None = ..., /, *, copy: None | bool = ...
self, dtype: _DType | None = ..., /, *, copy: bool | None = ...
) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ...

# TODO: Should return the same subclass but with a new dtype generic.
Expand Down
12 changes: 9 additions & 3 deletions xarray/tests/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, array):
def get_duck_array(self):
raise UnexpectedDataAccess("Tried accessing data")

def __array__(self, dtype: np.typing.DTypeLike = None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")

def __getitem__(self, key):
Expand All @@ -49,7 +51,9 @@ def __init__(self, array: np.ndarray):
def __getitem__(self, key):
return type(self)(self.array[key])

def __array__(self, dtype: np.typing.DTypeLike = None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")

def __array_namespace__(self):
Expand Down Expand Up @@ -140,7 +144,9 @@ def __repr__(self: Any) -> str:
def get_duck_array(self):
raise UnexpectedDataAccess("Tried accessing data")

def __array__(self, dtype: np.typing.DTypeLike = None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")

def __getitem__(self, key) -> "ConcatenatableArray":
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ def dims(self):
warnings.warn("warning in test", stacklevel=2)
return super().dims

def __array__(self, dtype=None, copy=None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
warnings.warn("warning in test", stacklevel=2)
return super().__array__()
return super().__array__(dtype, copy=copy)

a = WarningVariable("x", [1])
b = WarningVariable("x", [2])
Expand Down
4 changes: 3 additions & 1 deletion xarray/tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ def test_lazy_array_wont_compute() -> None:
from xarray.core.indexing import LazilyIndexedArray

class LazilyIndexedArrayNotComputable(LazilyIndexedArray):
def __array__(self, dtype=None, copy=None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise NotImplementedError("Computing this array is not possible.")

arr = LazilyIndexedArrayNotComputable(np.array([1, 2]))
Expand Down
11 changes: 9 additions & 2 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
from packaging.version import Version

from xarray.core.indexing import ExplicitlyIndexed
from xarray.namedarray._typing import (
Expand Down Expand Up @@ -53,8 +54,14 @@ def shape(self) -> _Shape:
class CustomArray(
CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
):
def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]:
return np.array(self.array)
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray[Any, np.dtype[np.generic]]:

if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(self.array, dtype=dtype, copy=copy)
else:
return np.asarray(self.array, dtype=dtype)


class CustomArrayIndexable(
Expand Down
Loading