(fix): handle case when chunks is None
ilan-gold committed Jul 9, 2024
1 parent ca6cf66 commit eabaf35
Showing 3 changed files with 30 additions and 7 deletions.
7 changes: 5 additions & 2 deletions src/anndata/_io/specs/lazy_methods.py
@@ -184,8 +184,11 @@ def make_dask_chunk(block_id: tuple[int, int]):
 
 @_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0"))
 def read_zarr_array(
-    elem, _reader, dataset_kwargs: Mapping[str, Any] = MappingProxyType({})
+    elem: ZarrArray,
+    _reader: Reader,
+    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
 ):
+    chunks: tuple[int, ...] = dataset_kwargs.get("chunks", elem.chunks)
     import dask.array as da
 
-    return da.from_zarr(elem)
+    return da.from_zarr(elem, chunks=chunks)
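For reference, a minimal sketch (not part of the commit) of the fallback that dataset_kwargs.get("chunks", elem.chunks) implements: when no chunks are requested, the dask array mirrors the zarr array's on-disk chunking. The array z and its shape/chunking below are illustrative.

# Minimal sketch of the fallback behaviour; `z` and its chunking are illustrative.
import dask.array as da
import numpy as np
import zarr

z = zarr.array(np.arange(1_000_000).reshape(1000, 1000), chunks=(250, 250))

dataset_kwargs = {}  # no "chunks" requested
chunks = dataset_kwargs.get("chunks", z.chunks)  # falls back to the on-disk chunks
x = da.from_zarr(z, chunks=chunks)
assert x.chunksize == (250, 250)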
4 changes: 3 additions & 1 deletion src/anndata/_io/specs/registry.py
@@ -387,7 +387,9 @@ def read_elem_as_dask(
     -------
     DaskArray
     """
-    return Reader(_LAZY_REGISTRY).read_elem(elem, dataset_kwargs={"chunks": chunks})
+    return Reader(_LAZY_REGISTRY).read_elem(
+        elem, dataset_kwargs={"chunks": chunks} if chunks is not None else {}
+    )
 
 
 def write_elem(
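A hedged usage sketch of the behaviour this change enables: passing chunks=None to read_elem_as_dask no longer forwards {"chunks": None} to the lazy readers, so each reader falls back to the element's own on-disk chunking. The store path "adata.zarr" and the anndata.experimental import location are assumptions, not part of the diff.

# Hedged usage sketch; "adata.zarr" and the import path are assumptions.
import zarr
from anndata.experimental import read_elem_as_dask

g = zarr.open_group("adata.zarr", mode="r")
# With chunks=None, dataset_kwargs stays empty and the reader uses the
# chunking of the stored element.
x = read_elem_as_dask(g["X"], chunks=None)
print(x.chunksize)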
26 changes: 22 additions & 4 deletions tests/test_io_elementwise.py
@@ -59,7 +59,7 @@ def store(request, tmp_path) -> H5Group | ZarrGroup:
 
 
 sparse_formats = ["csr", "csc"]
-SIZE = 1000
+SIZE = 2500
 
 
 @pytest.fixture(params=sparse_formats)
@@ -235,7 +235,15 @@ def test_read_lazy_2d_dask(sparse_format, store):
 
 @pytest.mark.parametrize(
     ("n_dims", "chunks"),
-    [(1, (100,)), (1, (400,)), (2, (100, 100)), (2, (400, 400)), (2, (200, 400))],
+    [
+        (1, (100,)),
+        (1, (400,)),
+        (2, (100, 100)),
+        (2, (400, 400)),
+        (2, (200, 400)),
+        (1, None),
+        (2, None),
+    ],
 )
 def test_read_lazy_nd_dask(store, n_dims, chunks):
     arr_store = create_dense_store(store, n_dims)
@@ -269,7 +277,13 @@ def test_read_lazy_h5_cluster(sparse_format, tmp_path):
 
 @pytest.mark.parametrize(
     ("arr_type", "chunks"),
-    [("dense", (100, 100)), ("csc", (SIZE, 10)), ("csr", (10, SIZE))],
+    [
+        ("dense", (100, 100)),
+        ("csc", (SIZE, 10)),
+        ("csr", (10, SIZE)),
+        ("csc", None),
+        ("csr", None),
+    ],
 )
 def test_read_lazy_h5_chunk_kwargs(arr_type, chunks, tmp_path):
     import dask.distributed as dd
@@ -282,7 +296,11 @@ def test_read_lazy_h5_chunk_kwargs(arr_type, chunks, tmp_path):
     else:
         arr_store = create_sparse_store(arr_type, store)
     X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks)
-    assert X_dask_from_disk.chunksize == chunks
+    if chunks is not None:
+        assert X_dask_from_disk.chunksize == chunks
+    else:
+        # assert that sparse chunks are set correctly by default
+        assert X_dask_from_disk.chunksize[bool(arr_type == "csr")] == SIZE
     X_from_disk = read_elem(arr_store["X"])
     file.close()
     with (
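For context on the new default-chunks assertion: with chunks=None, the test expects a CSR element to be chunked along rows (so axis 1 keeps the full SIZE) and a CSC element to be chunked along columns (so axis 0 keeps SIZE). A rough end-to-end sketch of that expectation follows; the import locations and file name are assumptions and may differ at this commit.

# Rough sketch of the sparse default-chunking expectation; imports and the
# file name are assumptions, not taken from the test file.
import h5py
from scipy import sparse
from anndata.experimental import read_elem_as_dask, write_elem

SIZE = 2500
with h5py.File("sparse_demo.h5", "w") as f:
    write_elem(f, "X", sparse.random(SIZE, SIZE, density=0.01, format="csr"))

with h5py.File("sparse_demo.h5", "r") as f:
    x = read_elem_as_dask(f["X"], chunks=None)
    # CSR defaults to row-wise chunks, so the column axis stays whole
    assert x.chunksize[1] == SIZE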
