From 0e345286158ba568ee242e4842218fcb9a029697 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Dec 2024 16:11:45 -0700 Subject: [PATCH] Cache the result of `DaskManager.normalize_chunks` This is only used with the backends codepath, where the inputs are guaranteed to be tuples. By contrast, `dask.array.normalize_chunks` accepts dicts as inputs, and so is harder to cache transparently. --- xarray/namedarray/daskmanager.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 6485ba375f5..26bcc3e43e3 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence +from functools import lru_cache from typing import TYPE_CHECKING, Any import numpy as np @@ -19,12 +20,20 @@ try: from dask.array import Array as DaskArray + except ImportError: DaskArray = np.ndarray[Any, Any] dask_available = module_available("dask") +if dask_available: + from dask.array.core import normalize_chunks + + normalize_chunks = lru_cache(normalize_chunks) +else: + normalize_chunks = None + class DaskManager(ChunkManagerEntrypoint["DaskArray"]): array_cls: type[DaskArray] @@ -52,8 +61,6 @@ def normalize_chunks( previous_chunks: _NormalizedChunks | None = None, ) -> Any: """Called by open_dataset""" - from dask.array.core import normalize_chunks - return normalize_chunks( chunks, shape=shape,