Skip to content

Commit

Permalink
refactor members cache
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Dec 10, 2024
1 parent 1eaf3ea commit 9d147b9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
59 changes: 39 additions & 20 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ class ZarrStore(AbstractWritableDataStore):

__slots__ = (
"_append_dim",
"_cache_array_keys",
"_cache_members",
"_close_store_on_close",
"_consolidate_on_close",
"_group",
Expand Down Expand Up @@ -634,7 +634,7 @@ def open_store(
zarr_format=None,
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_array_keys: bool = False,
cache_members: bool = False,
):
(
zarr_group,
Expand Down Expand Up @@ -666,7 +666,7 @@ def open_store(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_array_keys=cache_array_keys,
cache_members=cache_members
)
for group in group_paths
}
Expand Down Expand Up @@ -748,31 +748,42 @@ def __init__(
self._write_empty = write_empty
self._close_store_on_close = close_store_on_close
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
self._cache_members: bool = cache_members
self._members: dict[str, ZarrArray | ZarrGroup] = {}

self._members: tuple[bool, dict[str, ZarrArray | ZarrGroup] | None]
if cache_members:
self._members = (True, None)
else:
self._members = (False, None)
if self._cache_members:
# initialize the cache
self._members = self._fetch_members()

@property
def members(self) -> dict[str, ZarrArray | ZarrGroup]:
def members(self) -> dict[str, ZarrArray]:
"""
Model the arrays and groups contained in self.zarr_group as a dict
"""
do_cache, old_members = self._members
if not do_cache or old_members is None:
# we explicitly only care about the arrays, which saves some IO
# in zarr v2
members = dict(self.zarr_group.arrays())
if do_cache:
self._members = (do_cache, members)
return members
if not self._cache_members:
return self._fetch_members()
else:
return old_members
return self._members

def _fetch_members(self) -> dict[str, ZarrArray]:
"""
Get the arrays and groups defined in the zarr group modelled by this Store
"""
return dict(self.zarr_group.items())

def _update_members(self, data: dict[str, ZarrArray]):
if not self._cache_members:
msg = (
'Updating the members cache is only valid if this object was created '
'with cache_members=True, but this object has `cache_members=False`.'
f'You should update the zarr group directly.'
)
raise ValueError(msg)
else:
self._members = {**self.members, **data}

def array_keys(self) -> tuple[str, ...]:
return tuple(key for (key, _) in self.arrays())
return tuple(key for (key, node) in self.members.items() if isinstance(node, ZarrArray))

def arrays(self) -> tuple[tuple[str, ZarrArray], ...]:
return tuple(
Expand Down Expand Up @@ -1047,6 +1058,10 @@ def _open_existing_array(self, *, name) -> ZarrArray:
else:
zarr_array = self.zarr_group[name]

# update the model of the underlying zarr group
if self._cache_members:
self._update_members({name: zarr_array})
self._update_members({name: zarr_array})
return zarr_array

def _create_new_array(
Expand Down Expand Up @@ -1075,6 +1090,9 @@ def _create_new_array(
**encoding,
)
zarr_array = _put_attrs(zarr_array, attrs)
# update the model of the underlying zarr group
if self._cache_members:
self._update_members({name: zarr_array})
return zarr_array

def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
Expand Down Expand Up @@ -1143,7 +1161,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
zarr_array.resize(new_shape)

zarr_shape = zarr_array.shape

# update the model of the members of the zarr group
self.members[name] = zarr_array
region = tuple(write_region[dim] for dim in dims)

# We need to do this for both new and existing variables to ensure we're not
Expand Down
50 changes: 28 additions & 22 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2275,10 +2275,10 @@ def create_zarr_target(self):
raise NotImplementedError

@contextlib.contextmanager
def create_store(self):
def create_store(self, cache_members: bool = True):
with self.create_zarr_target() as store_target:
yield backends.ZarrStore.open_group(
store_target, mode="w", **self.version_kwargs
store_target, mode="w", cache_members=cache_members, **self.version_kwargs
)

def save(self, dataset, store_target, **kwargs): # type: ignore[override]
Expand Down Expand Up @@ -2572,7 +2572,7 @@ def test_hidden_zarr_keys(self) -> None:
skip_if_zarr_format_3("This test is unnecessary; no hidden Zarr keys")

expected = create_test_data()
with self.create_store() as store:
with self.create_store(cache_members=False) as store:
expected.dump_to_store(store)
zarr_group = store.ds

Expand All @@ -2594,6 +2594,7 @@ def test_hidden_zarr_keys(self) -> None:

# put it back and try removing from a variable
del zarr_group["var2"].attrs[self.DIMENSION_KEY]

with pytest.raises(KeyError):
with xr.decode_cf(store):
pass
Expand Down Expand Up @@ -3258,40 +3259,45 @@ def test_chunked_cftime_datetime(self) -> None:
assert original[name].chunks == actual_var.chunks
assert original.chunks == actual.chunks

@pytest.mark.parametrize("cache_array_keys", [True, False])
def test_get_array_keys(self, cache_array_keys: bool) -> None:
def test_cache_members(self) -> None:
"""
Ensure that if `ZarrStore` is created with `cache_array_keys` set to `True`,
a `ZarrStore.get_array_keys` only invokes the `array_keys` function on the
`ZarrStore.zarr_group` instance once, and that the results of that call are cached.
Ensure that if `ZarrStore` is created with `cache_members` set to `True`,
a `ZarrStore` only inspects the underlying zarr group once,
and that the results of that inspection are cached.
Otherwise, `ZarrStore.get_array_keys` instance should invoke the `array_keys`
each time it is called.
Otherwise, `ZarrStore.members` should inspect the underlying zarr group each time it is
invoked
"""
with self.create_zarr_target() as store_target:
zstore = backends.ZarrStore.open_group(
store_target, mode="w", cache_members=cache_array_keys
zstore_mut = backends.ZarrStore.open_group(
store_target, mode="w", cache_members=False
)

# ensure that the keys are sorted
array_keys = sorted(("foo", "bar"))

# create some arrays
for ak in array_keys:
zstore.zarr_group.create(name=ak, shape=(1,), dtype="uint8")
zstore_mut.zarr_group.create(name=ak, shape=(1,), dtype="uint8")

zstore_stat = backends.ZarrStore.open_group(
store_target, mode="r", cache_members=True
)

observed_keys_0 = sorted(zstore.array_keys())
observed_keys_0 = sorted(zstore_stat.array_keys())
assert observed_keys_0 == array_keys

# create a new array
new_key = "baz"
zstore.zarr_group.create(name=new_key, shape=(1,), dtype="uint8")
observed_keys_1 = sorted(zstore.array_keys())
zstore_mut.zarr_group.create(name=new_key, shape=(1,), dtype="uint8")

observed_keys_1 = sorted(zstore_stat.array_keys())
assert observed_keys_1 == array_keys

observed_keys_2 = sorted(zstore_mut.array_keys())
assert observed_keys_2 == sorted(array_keys + [new_key])


if cache_array_keys:
assert observed_keys_1 == array_keys
else:
assert observed_keys_1 == sorted(array_keys + [new_key])


@requires_zarr
Expand Down Expand Up @@ -3556,9 +3562,9 @@ def create_zarr_target(self):
yield tmp

@contextlib.contextmanager
def create_store(self):
def create_store(self, cache_members: bool = True):
with self.create_zarr_target() as store_target:
group = backends.ZarrStore.open_group(store_target, mode="a")
group = backends.ZarrStore.open_group(store_target, mode="a", cache_members=cache_members)
yield group


Expand Down

0 comments on commit 9d147b9

Please sign in to comment.