Skip to content

Commit

Permalink
add .oindex and .vindex to BackendArray (#8885)
Browse files Browse the repository at this point in the history
* add .oindex and .vindex to BackendArray

* Add support for .oindex and .vindex in H5NetCDFArrayWrapper

* Add support for .oindex and .vindex in NetCDF4ArrayWrapper, PydapArrayWrapper, NioArrayWrapper, and ZarrArrayWrapper

* add deprecation warning

* Fix deprecation warning message formatting

* add tests

* Update xarray/core/indexing.py

Co-authored-by: Deepak Cherian <[email protected]>

* Update ZarrArrayWrapper class in xarray/backends/zarr.py

Co-authored-by: Deepak Cherian <[email protected]>

---------

Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
andersy005 and dcherian authored Apr 17, 2024
1 parent b81b451 commit 10c133b
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 37 deletions.
18 changes: 18 additions & 0 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,24 @@ def get_duck_array(self, dtype: np.typing.DTypeLike = None):
key = indexing.BasicIndexer((slice(None),) * self.ndim)
return self[key] # type: ignore [index]

def _oindex_get(self, key: indexing.OuterIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
)

def _vindex_get(self, key: indexing.VectorizedIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_get method should be overridden"
)

@property
def oindex(self) -> indexing.IndexCallable:
return indexing.IndexCallable(self._oindex_get)

@property
def vindex(self) -> indexing.IndexCallable:
return indexing.IndexCallable(self._vindex_get)


class AbstractDataStore:
__slots__ = ()
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ def get_array(self, needs_lock=True):
ds = self.datastore._acquire(needs_lock)
return ds.variables[self.variable_name]

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)

def _vindex_get(self, key: indexing.VectorizedIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)

def __getitem__(self, key: indexing.BasicIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,17 @@ def get_array(self, needs_lock=True):
variable.set_auto_chartostring(False)
return variable

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
)

def _vindex_get(self, key: indexing.VectorizedIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
)

def __getitem__(self, key: indexing.BasicIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
)
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,17 @@ def shape(self) -> tuple[int, ...]:
def dtype(self):
return self.array.dtype

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
)

def _vindex_get(self, key: indexing.VectorizedIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
)

def __getitem__(self, key: indexing.BasicIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
)
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/pynio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@ def get_array(self, needs_lock=True):
ds = self.datastore._manager.acquire(needs_lock)
return ds.variables[self.variable_name]

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
)

def _vindex_get(self, key: indexing.VectorizedIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
)

def __getitem__(self, key: indexing.BasicIndexer):
return indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
)
Expand Down
33 changes: 24 additions & 9 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,7 @@ def get_variable(self, needs_lock=True):
ds = self.datastore._manager.acquire(needs_lock)
return ds.variables[self.variable_name]

def _getitem(self, key):
with self.datastore.lock:
data = self.get_variable(needs_lock=False).data
return data[key]

def __getitem__(self, key):
data = indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
def _finalize_result(self, data):
# Copy data if the source file is mmapped. This makes things consistent
# with the netCDF4 library by ensuring we can safely read arrays even
# after closing associated files.
Expand All @@ -88,6 +80,29 @@ def __getitem__(self, key):

return np.array(data, dtype=self.dtype, copy=copy)

def _getitem(self, key):
with self.datastore.lock:
data = self.get_variable(needs_lock=False).data
return data[key]

def _vindex_get(self, key: indexing.VectorizedIndexer):
data = indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
return self._finalize_result(data)

def _oindex_get(self, key: indexing.OuterIndexer):
data = indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
return self._finalize_result(data)

def __getitem__(self, key):
data = indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
return self._finalize_result(data)

def __setitem__(self, key, value):
with self.datastore.lock:
data = self.get_variable(needs_lock=False)
Expand Down
49 changes: 31 additions & 18 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,38 @@ def __init__(self, zarr_array):
def get_array(self):
return self._array

def _oindex(self, key):
return self._array.oindex[key]

def _vindex(self, key):
return self._array.vindex[key]

def _getitem(self, key):
return self._array[key]

def __getitem__(self, key):
array = self._array
if isinstance(key, indexing.BasicIndexer):
method = self._getitem
elif isinstance(key, indexing.VectorizedIndexer):
method = self._vindex
elif isinstance(key, indexing.OuterIndexer):
method = self._oindex
def _oindex_get(self, key: indexing.OuterIndexer):
def raw_indexing_method(key):
return self._array.oindex[key]

return indexing.explicit_indexing_adapter(
key,
self._array.shape,
indexing.IndexingSupport.VECTORIZED,
raw_indexing_method,
)

def _vindex_get(self, key: indexing.VectorizedIndexer):

def raw_indexing_method(key):
return self._array.vindex[key]

return indexing.explicit_indexing_adapter(
key,
self._array.shape,
indexing.IndexingSupport.VECTORIZED,
raw_indexing_method,
)

def __getitem__(self, key: indexing.BasicIndexer):
def raw_indexing_method(key):
return self._array[key]

return indexing.explicit_indexing_adapter(
key, array.shape, indexing.IndexingSupport.VECTORIZED, method
key,
self._array.shape,
indexing.IndexingSupport.VECTORIZED,
raw_indexing_method,
)

# if self.ndim == 0:
Expand Down
36 changes: 30 additions & 6 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import functools
import operator
import warnings
from collections import Counter, defaultdict
from collections.abc import Hashable, Iterable, Mapping
from contextlib import suppress
Expand Down Expand Up @@ -564,6 +565,14 @@ def __getitem__(self, key: Any):
return result


BackendArray_fallback_warning_message = (
"The array `{0}` does not support indexing using the .vindex and .oindex properties. "
"The __getitem__ method is being used instead. This fallback behavior will be "
"removed in a future version. Please ensure that the backend array `{1}` implements "
"support for the .vindex and .oindex properties to avoid potential issues."
)


class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array to make basic and outer indexing lazy."""

Expand Down Expand Up @@ -615,11 +624,18 @@ def shape(self) -> _Shape:
return tuple(shape)

def get_duck_array(self):
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
try:
array = apply_indexer(self.array, self.key)
else:
except NotImplementedError as _:
# If the array is not an ExplicitlyIndexedNDArrayMixin,
# it may wrap a BackendArray so use its __getitem__
# it may wrap a BackendArray subclass that doesn't implement .oindex and .vindex. so use its __getitem__
warnings.warn(
BackendArray_fallback_warning_message.format(
self.array.__class__.__name__, self.array.__class__.__name__
),
category=DeprecationWarning,
stacklevel=2,
)
array = self.array[self.key]

# self.array[self.key] is now a numpy array when
Expand Down Expand Up @@ -691,12 +707,20 @@ def shape(self) -> _Shape:
return np.broadcast(*self.key.tuple).shape

def get_duck_array(self):
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
try:
array = apply_indexer(self.array, self.key)
else:
except NotImplementedError as _:
# If the array is not an ExplicitlyIndexedNDArrayMixin,
# it may wrap a BackendArray so use its __getitem__
# it may wrap a BackendArray subclass that doesn't implement .oindex and .vindex. so use its __getitem__
warnings.warn(
BackendArray_fallback_warning_message.format(
self.array.__class__.__name__, self.array.__class__.__name__
),
category=PendingDeprecationWarning,
stacklevel=2,
)
array = self.array[self.key]

# self.array[self.key] is now a numpy array when
# self.array is a BackendArray subclass
# and self.key is BasicIndexer((slice(None, None, None),))
Expand Down
46 changes: 46 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5828,3 +5828,49 @@ def test_zarr_region_chunk_partial_offset(tmp_path):
# This write is unsafe, and should raise an error, but does not.
# with pytest.raises(ValueError):
# da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto")


def test_backend_array_deprecation_warning(capsys):
class CustomBackendArray(xr.backends.common.BackendArray):
def __init__(self):
array = self.get_array()
self.shape = array.shape
self.dtype = array.dtype

def get_array(self):
return np.arange(10)

def __getitem__(self, key):
return xr.core.indexing.explicit_indexing_adapter(
key, self.shape, xr.core.indexing.IndexingSupport.BASIC, self._getitem
)

def _getitem(self, key):
array = self.get_array()
return array[key]

cba = CustomBackendArray()
indexer = xr.core.indexing.VectorizedIndexer(key=(np.array([0]),))

la = xr.core.indexing.LazilyIndexedArray(cba, indexer)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
la.vindex[indexer].get_duck_array()

captured = capsys.readouterr()
assert len(w) == 1
assert issubclass(w[-1].category, PendingDeprecationWarning)
assert (
"The array `CustomBackendArray` does not support indexing using the .vindex and .oindex properties."
in str(w[-1].message)
)
assert "The __getitem__ method is being used instead." in str(w[-1].message)
assert "This fallback behavior will be removed in a future version." in str(
w[-1].message
)
assert (
"Please ensure that the backend array `CustomBackendArray` implements support for the .vindex and .oindex properties to avoid potential issues."
in str(w[-1].message)
)
assert captured.out == ""

0 comments on commit 10c133b

Please sign in to comment.