Skip to content

Commit

Permalink
cache members instead of just array keys
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Dec 10, 2024
1 parent 96ddcf4 commit 2d3c13a
Showing 1 changed file with 44 additions and 30 deletions.
74 changes: 44 additions & 30 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np
import pandas as pd
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup

from xarray import coding, conventions
from xarray.backends.common import (
Expand Down Expand Up @@ -38,9 +40,6 @@
from xarray.namedarray.utils import module_available

if TYPE_CHECKING:
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup

from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
Expand Down Expand Up @@ -602,11 +601,11 @@ class ZarrStore(AbstractWritableDataStore):

__slots__ = (
"_append_dim",
"_cache",
"_cache_array_keys",
"_close_store_on_close",
"_consolidate_on_close",
"_group",
"_members",
"_mode",
"_read_only",
"_safe_chunks",
Expand Down Expand Up @@ -690,7 +689,7 @@ def open_group(
zarr_format=None,
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_array_keys: bool = True,
cache_members: bool = True,
):
(
zarr_group,
Expand Down Expand Up @@ -721,7 +720,7 @@ def open_group(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_array_keys,
cache_members,
)

def __init__(
Expand All @@ -735,7 +734,7 @@ def __init__(
write_empty: bool | None = None,
close_store_on_close: bool = False,
use_zarr_fill_value_as_mask=None,
cache_array_keys: bool = True,
cache_members: bool = True,
):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
Expand All @@ -750,9 +749,37 @@ def __init__(
self._close_store_on_close = close_store_on_close
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask

self._cache: dict[str, Any] = {}
if cache_array_keys:
self._cache["array_keys"] = None
self._members: tuple[bool, dict[str, ZarrArray | ZarrGroup] | None]
if cache_members:
self._members = (True, None)
else:
self._members = (False, None)

@property
def members(self) -> dict[str, ZarrArray | ZarrGroup]:
"""
Model the arrays and groups contained in self.zarr_group as a dict
"""
do_cache, old_members = self._members
if not do_cache or old_members is None:
# we explicitly only care about the arrays, which saves some IO
# in zarr v2
members = dict(self.zarr_group.arrays())
if do_cache:
self._members = (do_cache, members)
return members
else:
return old_members

def array_keys(self) -> tuple[str, ...]:
return tuple(key for (key, _) in self.arrays())

def arrays(self) -> tuple[tuple[str, ZarrArray], ...]:
return tuple(
(key, node)
for (key, node) in self.members.items()
if isinstance(node, ZarrArray)
)

@property
def ds(self):
Expand All @@ -761,7 +788,7 @@ def ds(self):

def open_store_variable(self, name, zarr_array=None):
if zarr_array is None:
zarr_array = self.zarr_group[name]
zarr_array = self.members[name]
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
Expand Down Expand Up @@ -809,20 +836,7 @@ def open_store_variable(self, name, zarr_array=None):
return Variable(dimensions, data, attributes, encoding)

def get_variables(self):
return FrozenDict(
(k, self.open_store_variable(k, v)) for k, v in self.zarr_group.arrays()
)

def get_array_keys(self) -> tuple[str, ...]:
key = "array_keys"
if key not in self._cache:
result = tuple(self.zarr_group.array_keys())
elif self._cache[key] is None:
result = tuple(self.zarr_group.array_keys())
self._cache[key] = result
else:
result = self._cache[key]
return result
return FrozenDict((k, self.open_store_variable(k, v)) for k, v in self.arrays())

def get_attrs(self):
return {
Expand All @@ -834,7 +848,7 @@ def get_attrs(self):
def get_dimensions(self):
try_nczarr = self._mode == "r"
dimensions = {}
for _k, v in self.zarr_group.arrays():
for _k, v in self.arrays():
dim_names, _ = _get_zarr_dims_and_attrs(v, DIMENSION_KEY, try_nczarr)
for d, s in zip(dim_names, v.shape, strict=True):
if d in dimensions and dimensions[d] != s:
Expand Down Expand Up @@ -903,7 +917,7 @@ def store(
existing_keys = {}
existing_variable_names = {}
else:
existing_keys = self.get_array_keys()
existing_keys = self.array_keys()
existing_variable_names = {
vn for vn in variables if _encode_variable_name(vn) in existing_keys
}
Expand Down Expand Up @@ -1081,7 +1095,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
dimensions.
"""

existing_keys = self.get_array_keys()
existing_keys = self.array_keys()
is_zarr_v3_format = _zarr_v3() and self.zarr_group.metadata.zarr_format == 3

for vn, v in variables.items():
Expand Down Expand Up @@ -1273,7 +1287,7 @@ def _validate_and_autodetect_region(self, ds: Dataset) -> Dataset:

def _validate_encoding(self, encoding) -> None:
if encoding and self._mode in ["a", "a-", "r+"]:
existing_var_names = self.get_array_keys()
existing_var_names = self.array_keys()
for var_name in existing_var_names:
if var_name in encoding:
raise ValueError(
Expand Down Expand Up @@ -1545,7 +1559,7 @@ def open_dataset(
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=None,
zarr_format=zarr_format,
cache_array_keys=cache_array_keys,
cache_members=cache_array_keys,
)

store_entrypoint = StoreBackendEntrypoint()
Expand Down

0 comments on commit 2d3c13a

Please sign in to comment.