diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 231834ec1ba..c06bdecebb6 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -601,7 +601,7 @@ class ZarrStore(AbstractWritableDataStore): __slots__ = ( "_append_dim", - "_cache_array_keys", + "_cache_members", "_close_store_on_close", "_consolidate_on_close", "_group", @@ -634,7 +634,7 @@ def open_store( zarr_format=None, use_zarr_fill_value_as_mask=None, write_empty: bool | None = None, - cache_array_keys: bool = False, + cache_members: bool = False, ): ( zarr_group, @@ -666,7 +666,7 @@ def open_store( write_empty, close_store_on_close, use_zarr_fill_value_as_mask, - cache_array_keys=cache_array_keys, + cache_members=cache_members ) for group in group_paths } @@ -748,31 +748,42 @@ def __init__( self._write_empty = write_empty self._close_store_on_close = close_store_on_close self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask + self._cache_members: bool = cache_members + self._members: dict[str, ZarrArray | ZarrGroup] = {} - self._members: tuple[bool, dict[str, ZarrArray | ZarrGroup] | None] - if cache_members: - self._members = (True, None) - else: - self._members = (False, None) + if self._cache_members: + # initialize the cache + self._members = self._fetch_members() @property - def members(self) -> dict[str, ZarrArray | ZarrGroup]: + def members(self) -> dict[str, ZarrArray]: """ Model the arrays and groups contained in self.zarr_group as a dict """ - do_cache, old_members = self._members - if not do_cache or old_members is None: - # we explicitly only care about the arrays, which saves some IO - # in zarr v2 - members = dict(self.zarr_group.arrays()) - if do_cache: - self._members = (do_cache, members) - return members + if not self._cache_members: + return self._fetch_members() else: - return old_members + return self._members + + def _fetch_members(self) -> dict[str, ZarrArray]: + """ + Get the arrays and groups defined in the zarr group modelled by this Store + """ + return dict(self.zarr_group.items()) + + def _update_members(self, data: dict[str, ZarrArray]): + if not self._cache_members: + msg = ( + 'Updating the members cache is only valid if this object was created ' + 'with cache_members=True, but this object has `cache_members=False`.' + f'You should update the zarr group directly.' + ) + raise ValueError(msg) + else: + self._members = {**self.members, **data} def array_keys(self) -> tuple[str, ...]: - return tuple(key for (key, _) in self.arrays()) + return tuple(key for (key, node) in self.members.items() if isinstance(node, ZarrArray)) def arrays(self) -> tuple[tuple[str, ZarrArray], ...]: return tuple( @@ -1047,6 +1058,10 @@ def _open_existing_array(self, *, name) -> ZarrArray: else: zarr_array = self.zarr_group[name] + # update the model of the underlying zarr group + if self._cache_members: + self._update_members({name: zarr_array}) + self._update_members({name: zarr_array}) return zarr_array def _create_new_array( @@ -1075,6 +1090,9 @@ def _create_new_array( **encoding, ) zarr_array = _put_attrs(zarr_array, attrs) + # update the model of the underlying zarr group + if self._cache_members: + self._update_members({name: zarr_array}) return zarr_array def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): @@ -1143,7 +1161,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No zarr_array.resize(new_shape) zarr_shape = zarr_array.shape - + # update the model of the members of the zarr group + self.members[name] = zarr_array region = tuple(write_region[dim] for dim in dims) # We need to do this for both new and existing variables to ensure we're not diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 6e5d7170553..14756c19bb0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2275,10 +2275,10 @@ def create_zarr_target(self): raise NotImplementedError @contextlib.contextmanager - def create_store(self): + def create_store(self, cache_members: bool = True): with self.create_zarr_target() as store_target: yield backends.ZarrStore.open_group( - store_target, mode="w", **self.version_kwargs + store_target, mode="w", cache_members=cache_members, **self.version_kwargs ) def save(self, dataset, store_target, **kwargs): # type: ignore[override] @@ -2572,7 +2572,7 @@ def test_hidden_zarr_keys(self) -> None: skip_if_zarr_format_3("This test is unnecessary; no hidden Zarr keys") expected = create_test_data() - with self.create_store() as store: + with self.create_store(cache_members=False) as store: expected.dump_to_store(store) zarr_group = store.ds @@ -2594,6 +2594,7 @@ def test_hidden_zarr_keys(self) -> None: # put it back and try removing from a variable del zarr_group["var2"].attrs[self.DIMENSION_KEY] + with pytest.raises(KeyError): with xr.decode_cf(store): pass @@ -3258,19 +3259,18 @@ def test_chunked_cftime_datetime(self) -> None: assert original[name].chunks == actual_var.chunks assert original.chunks == actual.chunks - @pytest.mark.parametrize("cache_array_keys", [True, False]) - def test_get_array_keys(self, cache_array_keys: bool) -> None: + def test_cache_members(self) -> None: """ - Ensure that if `ZarrStore` is created with `cache_array_keys` set to `True`, - a `ZarrStore.get_array_keys` only invokes the `array_keys` function on the - `ZarrStore.zarr_group` instance once, and that the results of that call are cached. + Ensure that if `ZarrStore` is created with `cache_members` set to `True`, + a `ZarrStore` only inspects the underlying zarr group once, + and that the results of that inspection are cached. - Otherwise, `ZarrStore.get_array_keys` instance should invoke the `array_keys` - each time it is called. + Otherwise, `ZarrStore.members` should inspect the underlying zarr group each time it is + invoked """ with self.create_zarr_target() as store_target: - zstore = backends.ZarrStore.open_group( - store_target, mode="w", cache_members=cache_array_keys + zstore_mut = backends.ZarrStore.open_group( + store_target, mode="w", cache_members=False ) # ensure that the keys are sorted @@ -3278,20 +3278,26 @@ def test_get_array_keys(self, cache_array_keys: bool) -> None: # create some arrays for ak in array_keys: - zstore.zarr_group.create(name=ak, shape=(1,), dtype="uint8") + zstore_mut.zarr_group.create(name=ak, shape=(1,), dtype="uint8") + + zstore_stat = backends.ZarrStore.open_group( + store_target, mode="r", cache_members=True + ) - observed_keys_0 = sorted(zstore.array_keys()) + observed_keys_0 = sorted(zstore_stat.array_keys()) assert observed_keys_0 == array_keys # create a new array new_key = "baz" - zstore.zarr_group.create(name=new_key, shape=(1,), dtype="uint8") - observed_keys_1 = sorted(zstore.array_keys()) + zstore_mut.zarr_group.create(name=new_key, shape=(1,), dtype="uint8") + + observed_keys_1 = sorted(zstore_stat.array_keys()) + assert observed_keys_1 == array_keys + + observed_keys_2 = sorted(zstore_mut.array_keys()) + assert observed_keys_2 == sorted(array_keys + [new_key]) + - if cache_array_keys: - assert observed_keys_1 == array_keys - else: - assert observed_keys_1 == sorted(array_keys + [new_key]) @requires_zarr @@ -3556,9 +3562,9 @@ def create_zarr_target(self): yield tmp @contextlib.contextmanager - def create_store(self): + def create_store(self, cache_members: bool = True): with self.create_zarr_target() as store_target: - group = backends.ZarrStore.open_group(store_target, mode="a") + group = backends.ZarrStore.open_group(store_target, mode="a", cache_members=cache_members) yield group