Skip to content

Commit

Permalink
Update test_parallelcompat.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Jul 10, 2024
1 parent 3d48d44 commit e7041e0
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions xarray/tests/test_parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import numpy as np
import pytest

from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
from xarray.namedarray._typing import _Chunks
from xarray.namedarray._typing import (
_Chunks,
_ChunksLike,
_DType,
_Shape,
chunkedduckarray,
duckarray,
)
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import (
ChunkManagerEntrypoint,
Expand All @@ -27,7 +33,7 @@ class DummyChunkedArray(np.ndarray):
https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
"""

chunks: T_NormalizedChunks
chunks: _Chunks

def __new__(
cls,
Expand Down Expand Up @@ -63,32 +69,36 @@ def __init__(self):
def is_chunked_array(self, data: Any) -> bool:
return isinstance(data, DummyChunkedArray)

def chunks(self, data: DummyChunkedArray) -> T_NormalizedChunks:
def chunks(self, data: chunkedduckarray[Any, Any]) -> _Chunks:
return data.chunks

def normalize_chunks(
self,
chunks: T_Chunks | T_NormalizedChunks,
shape: tuple[int, ...] | None = None,
chunks: _ChunksLike,
shape: _Shape | None = None,
limit: int | None = None,
dtype: np.dtype | None = None,
previous_chunks: T_NormalizedChunks | None = None,
) -> T_NormalizedChunks:
dtype: _DType | None = None,
previous_chunks: _Chunks | None = None,
) -> _Chunks:
from dask.array.core import normalize_chunks

return normalize_chunks(chunks, shape, limit, dtype, previous_chunks)

def from_array(
self, data: T_DuckArray | np.typing.ArrayLike, chunks: _Chunks, **kwargs
) -> DummyChunkedArray:
self, data: duckarray[Any, _DType], chunks: _ChunksLike, **kwargs
) -> chunkedduckarray[Any, _DType]:
from dask import array as da

return da.from_array(data, chunks, **kwargs)

def rechunk(self, data: DummyChunkedArray, chunks, **kwargs) -> DummyChunkedArray:
def rechunk(
self, data: chunkedduckarray[Any, _DType], chunks: _ChunksLike, **kwargs
) -> chunkedduckarray[Any, _DType]:
return data.rechunk(chunks, **kwargs)

def compute(self, *data: DummyChunkedArray, **kwargs) -> tuple[np.ndarray, ...]:
def compute(
self, *data: chunkedduckarray[Any, _DType], **kwargs
) -> tuple[duckarray[Any, _DType], ...]:
from dask.array import compute

return compute(*data, **kwargs)
Expand Down

0 comments on commit e7041e0

Please sign in to comment.