Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Namedarray with shapetype #9260

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ac6fc13
Update test_namedarray.py
Illviljan Jul 11, 2024
b51f168
Merge branch 'main' into namedarray_typing_np2
Illviljan Jul 11, 2024
184c211
Update _typing.py
Illviljan Jul 11, 2024
eee5cd3
Merge branch 'main' into namedarray_typing_np2
max-sixty Jul 13, 2024
8b649cd
Update test_namedarray.py
Illviljan Jul 16, 2024
bf09e92
Merge branch 'main' into namedarray_typing_np2
Illviljan Jul 16, 2024
7bbab19
Merge branch 'main' into namedarray_typing_np2
Illviljan Jul 18, 2024
bd2b719
Update test_namedarray.py
Illviljan Jul 18, 2024
06b8724
Update _typing.py
Illviljan Jul 19, 2024
25636cb
Merge branch 'main' into namedarray_typing_np2
Illviljan Jul 19, 2024
16b3403
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
0f3afdd
Update _typing.py
Illviljan Jul 19, 2024
ee4cd84
Add shapetypes
Illviljan Jul 19, 2024
c709245
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
23188b4
Fix __array__ missing copy parameter
Illviljan Jul 19, 2024
ca8c5f6
Update core.py
Illviljan Jul 19, 2024
f6d3db3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
65af4d5
Update core.py
Illviljan Jul 19, 2024
a54ff58
Merge branch 'namedarray_shapetype' of https://github.com/Illviljan/x…
Illviljan Jul 19, 2024
2d323e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
28f9871
Merge branch 'main' into namedarray_shapetype
Illviljan Jul 21, 2024
0494132
Update xarray/namedarray/_typing.py
Illviljan Jul 21, 2024
9add09d
Merge branch 'main' into namedarray_shapetype
Illviljan Aug 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def dtype(self) -> _DType_co: ...
_IntOrUnknown = int
_Shape = tuple[_IntOrUnknown, ...]
_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True)
_ShapeType = TypeVar("_ShapeType", bound=_Shape)
_ShapeType_co = TypeVar("_ShapeType_co", bound=_Shape, covariant=True)

_Axis = int
_Axes = tuple[_Axis, ...]
Expand Down Expand Up @@ -118,7 +118,7 @@ class _array(Protocol[_ShapeType_co, _DType_co]):
"""

@property
def shape(self) -> _Shape: ...
def shape(self) -> _ShapeType_co: ...

@property
def dtype(self) -> _DType_co: ...
Expand Down Expand Up @@ -218,6 +218,7 @@ def __array_namespace__(self) -> ModuleType: ...
_arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co]
]


Illviljan marked this conversation as resolved.
Show resolved Hide resolved
# Corresponds to np.typing.NDArray:
DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]]

Expand Down
111 changes: 105 additions & 6 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):

__slots__ = ("_data", "_dims", "_attrs")

_data: duckarray[Any, _DType_co]
_data: duckarray[_ShapeType_co, _DType_co]
_dims: _Dims
_attrs: dict[Any, Any] | None

def __init__(
self,
dims: _DimsLike,
data: duckarray[Any, _DType_co],
data: duckarray[_ShapeType_co, _DType_co],
attrs: _AttrsLike = None,
):
self._data = data
Expand Down Expand Up @@ -292,7 +292,7 @@ def _new(
def _new(
self,
dims: _DimsLike | Default = _default,
data: duckarray[Any, _DType] | Default = _default,
data: duckarray[_ShapeType, _DType] | Default = _default,
attrs: _AttrsLike | Default = _default,
) -> NamedArray[_ShapeType, _DType] | NamedArray[_ShapeType_co, _DType_co]:
"""
Expand Down Expand Up @@ -447,7 +447,7 @@ def dtype(self) -> _DType_co:
return self._data.dtype

@property
def shape(self) -> _Shape:
def shape(self) -> _ShapeType_co:
"""
Get the shape of the array.

Expand Down Expand Up @@ -850,9 +850,9 @@ def to_numpy(self) -> np.ndarray[Any, Any]:
# TODO an entrypoint so array libraries can choose coercion method?
return to_numpy(self._data)

def as_numpy(self) -> Self:
def as_numpy(self) -> NamedArray[Any, Any]:
"""Coerces wrapped data into a numpy array, returning a Variable."""
return self._replace(data=self.to_numpy())
return self._new(data=self.to_numpy())

def reduce(
self,
Expand Down Expand Up @@ -1163,3 +1163,102 @@ def _raise_if_any_duplicate_dimensions(
raise ValueError(
f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}"
)


# # %% function should pass

# data = np.array([1, 2, 3], dtype=np.dtype(np.int64))
# # data: duckarray[Any, np.dtype[np.int64]] = np.array([1, 2, 3], dtype=np.dtype(np.int64))
# reveal_type(data)


# def test(
# data: duckarray[_ShapeType_co, _DType_co]
# ) -> duckarray[_ShapeType_co, _DType_co]:
# return data


# def test2(
# data: _arrayfunction[_ShapeType, _DType]
# ) -> _arrayfunction[_ShapeType, _DType]:
# return data


# b = test(data)
# reveal_type(b)
# c = test2(data)
# reveal_type(c)
# a = NamedArray(("time",), data=data)
# reveal_type(a)


# # %% Class should pass
# from typing import Generic, TypeVar, Protocol, Union

# _ST = TypeVar("_ST", bound=Any, covariant=True)
# _DT = TypeVar("_DT", bound=Any, covariant=True)


# # Valid numpy protocol:
# class ArrayA(Protocol[_ST, _DT]):
# @property
# def dtype(self) -> _DT: ...
# @property
# def shape(self) -> _ST: ...


# class TestArray(Generic[_ST, _DT]):
# __slots__ = ("_data",)

# _data: ArrayA[_ST, _DT]

# def __init__(self, data: ArrayA[_ST, _DT]):
# self._data = data


# ta = TestArray(data)
# reveal_type(ta)


# # %% Class should pass
# # Not valid numpy protocol:
# class ArrayB(Protocol[_ST, _DT]):
# @property
# def dtype(self) -> _DT: ...
# @property
# def shape(self) -> _ST: ...
# def b(self) -> int: ...


# duckiearray = Union[ArrayA[_ST, _DT], ArrayB[_ST, _DT]]


# class TestArray2(Generic[_ST, _DT]):
# __slots__ = ("_data",)

# _data: duckiearray[_ST, _DT]

# def __init__(self, data: duckiearray[_ST, _DT]):
# self._data = data


# ta2 = TestArray2(data)
# reveal_type(ta2)


# # %% Class should pass
# class TestArray3(Generic[_ST, _DT]):
# __slots__ = ("_data",)

# _data: duckarray[_ST, _DT]

# def __init__(self, data: duckarray[_ST, _DT]):
# self._data = data


# ta3 = TestArray3(data)
# reveal_type(ta3)
# # %% Namedarray should pass

# narr = NamedArray(("time",), data)
# reveal_type(narr)
14 changes: 8 additions & 6 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@
_DType,
_IndexKeyLike,
_IntOrUnknown,
_Shape,
_ShapeLike,
_ShapeType,
duckarray,
)


class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):
def __init__(self, array: duckarray[Any, _DType_co]) -> None:
self.array: duckarray[Any, _DType_co] = array
def __init__(self, array: duckarray[_ShapeType_co, _DType_co]) -> None:
self.array: duckarray[_ShapeType_co, _DType_co] = array

@property
def dtype(self) -> _DType_co:
return self.array.dtype

@property
def shape(self) -> _Shape:
def shape(self) -> _ShapeType_co:
return self.array.shape


Expand Down Expand Up @@ -78,9 +78,11 @@ def __array_namespace__(self) -> ModuleType:
return np


def check_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]:
def check_duck_array_typevar(
a: duckarray[_ShapeType, _DType]
) -> duckarray[_ShapeType, _DType]:
# Mypy checks a is valid:
b: duckarray[Any, _DType] = a
b: duckarray[_ShapeType, _DType] = a

# Runtime check if valid:
if isinstance(b, _arrayfunction_or_api):
Expand Down
Loading