Skip to content

Commit

Permalink
Cache the result of DaskManager.normalize_chunks
Browse files Browse the repository at this point in the history
This is only used with the backends codepath, where the inputs
are guaranteed to be tuples. By contrast, `dask.array.normalize_chunks`
accepts dicts as inputs and so is harder to cache transparently.
  • Loading branch information
dcherian committed Dec 16, 2024
1 parent 0945e0e commit 0e34528
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
from functools import lru_cache
from typing import TYPE_CHECKING, Any

import numpy as np
Expand All @@ -19,12 +20,20 @@

# Bind ``DaskArray`` to dask's array class when dask can be imported.
try:
    from dask.array import Array as DaskArray

except ImportError:
    # dask could not be imported: fall back to a NumPy ndarray alias so the
    # ``DaskArray`` name still resolves for the references further down.
    DaskArray = np.ndarray[Any, Any]


dask_available = module_available("dask")

if dask_available:
    from dask.array.core import normalize_chunks as _dask_normalize_chunks

    # Memoized variant: repeated calls with equal, hashable arguments reuse
    # the previously computed result instead of recomputing it.
    _cached_normalize_chunks = lru_cache(_dask_normalize_chunks)

    def normalize_chunks(*args: Any, **kwargs: Any) -> Any:
        """Call dask's ``normalize_chunks``, memoizing results when possible.

        ``lru_cache`` can only key on hashable arguments. The underlying dask
        function also accepts unhashable inputs (for example dicts), so those
        calls are forwarded to the uncached function rather than failing with
        ``TypeError: unhashable type``.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``dask.array.core.normalize_chunks``.

        Returns
        -------
        Any
            Whatever ``dask.array.core.normalize_chunks`` returns for the
            given arguments.

        Raises
        ------
        TypeError
            If dask itself rejects the arguments; the uncached retry
            re-raises the genuine error.
        """
        # NOTE(review): cache entries are keyed on the call arguments only.
        # If dask derives part of the result from global configuration
        # (e.g. when resolving "auto" chunks without an explicit limit),
        # changing that configuration would not invalidate cached entries
        # -- TODO confirm.
        try:
            return _cached_normalize_chunks(*args, **kwargs)
        except TypeError:
            # Either an argument was unhashable, or dask raised TypeError
            # itself. Retrying uncached handles the former and re-raises
            # the latter.
            return _dask_normalize_chunks(*args, **kwargs)

    # Keep the cache-introspection helpers reachable on the public name, as
    # they were when it was bound directly to the lru_cache wrapper.
    normalize_chunks.cache_info = _cached_normalize_chunks.cache_info  # type: ignore[attr-defined]
    normalize_chunks.cache_clear = _cached_normalize_chunks.cache_clear  # type: ignore[attr-defined]
else:
    # dask is not installed; there is nothing to wrap.
    normalize_chunks = None


class DaskManager(ChunkManagerEntrypoint["DaskArray"]):
array_cls: type[DaskArray]
Expand Down Expand Up @@ -52,8 +61,6 @@ def normalize_chunks(
previous_chunks: _NormalizedChunks | None = None,
) -> Any:
"""Called by open_dataset"""
from dask.array.core import normalize_chunks

return normalize_chunks(
chunks,
shape=shape,
Expand Down

0 comments on commit 0e34528

Please sign in to comment.