Skip to content

Commit

Permalink
Merge branch 'perf/cache-array-keys' of https://github.com/d-v-b/xarray
Browse files Browse the repository at this point in the history
… into perf/cache-array-keys
  • Loading branch information
d-v-b committed Dec 10, 2024
2 parents 6578caa + bdb20a6 commit 9596bcf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
14 changes: 8 additions & 6 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def open_store(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_members=cache_members
cache_members=cache_members,
)
for group in group_paths
}
Expand Down Expand Up @@ -774,16 +774,18 @@ def _fetch_members(self) -> dict[str, ZarrArray]:
def _update_members(self, data: dict[str, ZarrArray]):
    """Merge ``data`` into the cached members mapping.

    Parameters
    ----------
    data : dict[str, ZarrArray]
        Entries to add to the cache, keyed by member name. On a key clash
        the entry from ``data`` replaces the existing one.

    Raises
    ------
    ValueError
        If ``self._cache_members`` is falsy.
    """
    if not self._cache_members:
        # Adjacent string literals are joined with no separator, so every
        # fragment that is followed by another must end in a space.
        msg = (
            "Updating the members cache is only valid if this object was created "
            "with cache_members=True, but this object has `cache_members=False`. "
            "You should update the zarr group directly."
        )
        raise ValueError(msg)
    # Rebind to a new dict instead of mutating the mapping that
    # ``self.members`` returned.
    self._members = {**self.members, **data}

def array_keys(self) -> tuple[str, ...]:
    """Return the names of all members that are ``ZarrArray`` instances.

    Names appear in the iteration order of ``self.members``; members of
    any other type are left out.
    """
    return tuple(
        key for (key, node) in self.members.items() if isinstance(node, ZarrArray)
    )

def arrays(self) -> tuple[tuple[str, ZarrArray], ...]:
return tuple(
Expand Down
21 changes: 12 additions & 9 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,7 +2282,10 @@ def create_zarr_target(self):
def create_store(self, cache_members: bool = True):
    """Yield a ``ZarrStore`` opened in write mode (``mode="w"``).

    The store is opened on the target produced by
    ``self.create_zarr_target()`` and is yielded while that context is
    still active.

    Parameters
    ----------
    cache_members : bool, default: True
        Passed through unchanged to ``ZarrStore.open_group``.
    """
    with self.create_zarr_target() as store_target:
        yield backends.ZarrStore.open_group(
            store_target,
            mode="w",
            cache_members=cache_members,
            **self.version_kwargs,
        )

def save(self, dataset, store_target, **kwargs): # type: ignore[override]
Expand Down Expand Up @@ -2598,7 +2601,7 @@ def test_hidden_zarr_keys(self) -> None:

# put it back and try removing from a variable
del zarr_group["var2"].attrs[self.DIMENSION_KEY]

with pytest.raises(KeyError):
with xr.decode_cf(store):
pass
Expand Down Expand Up @@ -3266,10 +3269,10 @@ def test_chunked_cftime_datetime(self) -> None:
def test_cache_members(self) -> None:
"""
Ensure that if `ZarrStore` is created with `cache_members` set to `True`,
a `ZarrStore` only inspects the underlying zarr group once,
a `ZarrStore` only inspects the underlying zarr group once,
and that the results of that inspection are cached.
Otherwise, `ZarrStore.members` should inspect the underlying zarr group each time it is
Otherwise, `ZarrStore.members` should inspect the underlying zarr group each time it is
invoked
"""
with self.create_zarr_target() as store_target:
Expand All @@ -3294,16 +3297,14 @@ def test_cache_members(self) -> None:
# create a new array
new_key = "baz"
zstore_mut.zarr_group.create(name=new_key, shape=(1,), dtype="uint8")

observed_keys_1 = sorted(zstore_stat.array_keys())
assert observed_keys_1 == array_keys

observed_keys_2 = sorted(zstore_mut.array_keys())
assert observed_keys_2 == sorted(array_keys + [new_key])




@requires_zarr
@pytest.mark.skipif(
KVStore is None, reason="zarr-python 2.x or ZARR_V3_EXPERIMENTAL_API is unset."
Expand Down Expand Up @@ -3568,7 +3569,9 @@ def create_zarr_target(self):
@contextlib.contextmanager
def create_store(self, cache_members: bool = True):
    """Yield a ``ZarrStore`` opened in append mode (``mode="a"``).

    The store is opened on the target produced by
    ``self.create_zarr_target()`` and is yielded while that context is
    still active.

    Parameters
    ----------
    cache_members : bool, default: True
        Passed through unchanged to ``ZarrStore.open_group``.
    """
    with self.create_zarr_target() as store_target:
        # Open the group exactly once per context entry.
        group = backends.ZarrStore.open_group(
            store_target, mode="a", cache_members=cache_members
        )
        yield group


Expand Down

0 comments on commit 9596bcf

Please sign in to comment.