From 2d3c13a82d31952b32cf71c7120ec490af28816d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 10 Dec 2024 10:24:23 +0100 Subject: [PATCH] cache members instead of just array keys --- xarray/backends/zarr.py | 74 ++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 03e96985251..3124cf62769 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -9,6 +9,8 @@ import numpy as np import pandas as pd +from zarr import Array as ZarrArray +from zarr import Group as ZarrGroup from xarray import coding, conventions from xarray.backends.common import ( @@ -38,9 +40,6 @@ from xarray.namedarray.utils import module_available if TYPE_CHECKING: - from zarr import Array as ZarrArray - from zarr import Group as ZarrGroup - from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree @@ -602,11 +601,11 @@ class ZarrStore(AbstractWritableDataStore): __slots__ = ( "_append_dim", - "_cache", "_cache_array_keys", "_close_store_on_close", "_consolidate_on_close", "_group", + "_members", "_mode", "_read_only", "_safe_chunks", @@ -690,7 +689,7 @@ def open_group( zarr_format=None, use_zarr_fill_value_as_mask=None, write_empty: bool | None = None, - cache_array_keys: bool = True, + cache_members: bool = True, ): ( zarr_group, @@ -721,7 +720,7 @@ def open_group( write_empty, close_store_on_close, use_zarr_fill_value_as_mask, - cache_array_keys, + cache_members, ) def __init__( @@ -735,7 +734,7 @@ def __init__( write_empty: bool | None = None, close_store_on_close: bool = False, use_zarr_fill_value_as_mask=None, - cache_array_keys: bool = True, + cache_members: bool = True, ): self.zarr_group = zarr_group self._read_only = self.zarr_group.read_only @@ -750,9 +749,37 @@ def __init__( self._close_store_on_close = close_store_on_close self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask - self._cache: dict[str, Any] = {} - if cache_array_keys: - self._cache["array_keys"] = None + self._members: tuple[bool, dict[str, ZarrArray | ZarrGroup] | None] + if cache_members: + self._members = (True, None) + else: + self._members = (False, None) + + @property + def members(self) -> dict[str, ZarrArray | ZarrGroup]: + """ + Model the arrays and groups contained in self.zarr_group as a dict + """ + do_cache, old_members = self._members + if not do_cache or old_members is None: + # we explicitly only care about the arrays, which saves some IO + # in zarr v2 + members = dict(self.zarr_group.arrays()) + if do_cache: + self._members = (do_cache, members) + return members + else: + return old_members + + def array_keys(self) -> tuple[str, ...]: + return tuple(key for (key, _) in self.arrays()) + + def arrays(self) -> tuple[tuple[str, ZarrArray], ...]: + return tuple( + (key, node) + for (key, node) in self.members.items() + if isinstance(node, ZarrArray) + ) @property def ds(self): @@ -761,7 +788,7 @@ def ds(self): def open_store_variable(self, name, zarr_array=None): if zarr_array is None: - zarr_array = self.zarr_group[name] + zarr_array = self.members[name] data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( @@ -809,20 +836,7 @@ def open_store_variable(self, name, zarr_array=None): return Variable(dimensions, data, attributes, encoding) def get_variables(self): - return FrozenDict( - (k, self.open_store_variable(k, v)) for k, v in self.zarr_group.arrays() - ) - - def get_array_keys(self) -> tuple[str, ...]: - key = "array_keys" - if key not in self._cache: - result = tuple(self.zarr_group.array_keys()) - elif self._cache[key] is None: - result = tuple(self.zarr_group.array_keys()) - self._cache[key] = result - else: - result = self._cache[key] - return result + return FrozenDict((k, self.open_store_variable(k, v)) for k, v in self.arrays()) def get_attrs(self): return { @@ -834,7 +848,7 @@ def get_attrs(self): def get_dimensions(self): try_nczarr = self._mode == "r" dimensions = {} - for _k, v in self.zarr_group.arrays(): + for _k, v in self.arrays(): dim_names, _ = _get_zarr_dims_and_attrs(v, DIMENSION_KEY, try_nczarr) for d, s in zip(dim_names, v.shape, strict=True): if d in dimensions and dimensions[d] != s: @@ -903,7 +917,7 @@ def store( existing_keys = {} existing_variable_names = {} else: - existing_keys = self.get_array_keys() + existing_keys = self.array_keys() existing_variable_names = { vn for vn in variables if _encode_variable_name(vn) in existing_keys } @@ -1081,7 +1095,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No dimensions. """ - existing_keys = self.get_array_keys() + existing_keys = self.array_keys() is_zarr_v3_format = _zarr_v3() and self.zarr_group.metadata.zarr_format == 3 for vn, v in variables.items(): @@ -1273,7 +1287,7 @@ def _validate_and_autodetect_region(self, ds: Dataset) -> Dataset: def _validate_encoding(self, encoding) -> None: if encoding and self._mode in ["a", "a-", "r+"]: - existing_var_names = self.get_array_keys() + existing_var_names = self.array_keys() for var_name in existing_var_names: if var_name in encoding: raise ValueError( @@ -1545,7 +1559,7 @@ def open_dataset( zarr_version=zarr_version, use_zarr_fill_value_as_mask=None, zarr_format=zarr_format, - cache_array_keys=cache_array_keys, + cache_members=cache_array_keys, ) store_entrypoint = StoreBackendEntrypoint()