Skip to content

Commit

Permalink
Optimize idxmin, idxmax with dask (#9800)
Browse files Browse the repository at this point in the history
* Optimize idxmin, idxmax with dask

Closes #9425

* use map_blocks instead

* small edits

* fix typing

* try again

* Migrate to DaskIndexingAdapter

* cleanup

* fix?

* finish

* fix types

* review comments
  • Loading branch information
dcherian authored Dec 18, 2024
1 parent 8afed74 commit 9fe816e
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 30 deletions.
23 changes: 13 additions & 10 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name
from xarray.core.utils import (
is_dict_like,
is_scalar,
parse_dims_as_set,
result_name,
)
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand Down Expand Up @@ -2166,19 +2171,17 @@ def _calc_idxminmax(
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

# Handle chunked arrays (e.g. dask).
coord = array[dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
chunks = dict(zip(array.dims, array.chunks, strict=True))
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
data = dask_coord[duck_array_ops.ravel(indx.data)]
coord_array = chunkmanager.from_array(
array[dim].data, chunks=((array.sizes[dim],),)
)
coord = coord.copy(data=coord_array)
else:
arr_coord = to_like_array(array[dim].data, array.data)
data = arr_coord[duck_array_ops.ravel(indx.data)]
coord = coord.copy(data=to_like_array(array[dim].data, array.data))

# rebuild like the argmin/max output, and rename as the dim name
data = duck_array_ops.reshape(data, indx.shape)
res = indx.copy(data=data)
res.name = dim
res = indx._replace(coord[(indx.variable,)]).rename(dim)

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
Expand Down
61 changes: 54 additions & 7 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import functools
import math
import operator
from collections import Counter, defaultdict
from collections.abc import Callable, Hashable, Iterable, Mapping
Expand Down Expand Up @@ -472,12 +473,6 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif is_duck_dask_array(k):
raise ValueError(
"Vectorized indexing with Dask arrays is not supported. "
"Please pass a numpy array by calling ``.compute``. "
"See https://github.com/dask/dask/issues/8958."
)
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
Expand Down Expand Up @@ -1508,6 +1503,7 @@ def _oindex_get(self, indexer: OuterIndexer):
return self.array[key]

def _vindex_get(self, indexer: VectorizedIndexer):
_assert_not_chunked_indexer(indexer.tuple)
array = NumpyVIndexAdapter(self.array)
return array[indexer.tuple]

Expand Down Expand Up @@ -1607,6 +1603,28 @@ def transpose(self, order):
return xp.permute_dims(self.array, order)


def _apply_vectorized_indexer_dask_wrapper(indices, coord):
from xarray.core.indexing import (
VectorizedIndexer,
apply_indexer,
as_indexable,
)

return apply_indexer(
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
)


def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None:
if any(is_chunked_array(i) for i in idxr):
raise ValueError(
"Cannot index with a chunked array indexer. "
"Please chunk the array you are indexing first, "
"and drop any indexed dimension coordinate variables. "
"Alternatively, call `.compute()` on any chunked arrays in the indexer."
)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand All @@ -1630,7 +1648,35 @@ def _oindex_get(self, indexer: OuterIndexer):
return value

def _vindex_get(self, indexer: VectorizedIndexer):
return self.array.vindex[indexer.tuple]
try:
return self.array.vindex[indexer.tuple]
except IndexError as e:
# TODO: upstream to dask
has_dask = any(is_duck_dask_array(i) for i in indexer.tuple)
# this only works for "small" 1d coordinate arrays with one chunk
# it is intended for idxmin, idxmax, and allows indexing with
# the nD array output of argmin, argmax
if (
not has_dask
or len(indexer.tuple) > 1
or math.prod(self.array.numblocks) > 1
or self.array.ndim > 1
):
raise e
(idxr,) = indexer.tuple
if idxr.ndim == 0:
return self.array[idxr.data]
else:
import dask.array

return dask.array.map_blocks(
_apply_vectorized_indexer_dask_wrapper,
idxr[..., np.newaxis],
self.array,
chunks=idxr.chunks,
drop_axis=-1,
dtype=self.array.dtype,
)

def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
Expand Down Expand Up @@ -1770,6 +1816,7 @@ def _vindex_get(
| np.datetime64
| np.timedelta64
):
_assert_not_chunked_indexer(indexer.tuple)
key = self._prepare_key(indexer.tuple)

if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,3 +1815,16 @@ def test_minimize_graph_size():
# all the other dimensions.
# e.g. previously for 'x', actual == numchunks['y'] * numchunks['z']
assert actual == numchunks[var], (actual, numchunks[var])


def test_idxmin_chunking():
# GH9425
x, y, t = 100, 100, 10
rang = np.arange(t * x * y)
da = xr.DataArray(
rang.reshape(t, x, y), coords={"time": range(t), "x": range(x), "y": range(y)}
)
da = da.chunk(dict(time=-1, x=25, y=25))
actual = da.idxmin("time")
assert actual.chunksizes == {k: da.chunksizes[k] for k in ["x", "y"]}
assert_identical(actual, da.compute().idxmin("time"))
17 changes: 10 additions & 7 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4949,7 +4949,15 @@ def test_argmax(

assert_identical(result2, expected2)

@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize(
"use_dask",
[
pytest.param(
True, marks=pytest.mark.skipif(not has_dask, reason="no dask")
),
False,
],
)
def test_idxmin(
self,
x: np.ndarray,
Expand All @@ -4958,16 +4966,11 @@ def test_idxmin(
nanindex: int | None,
use_dask: bool,
) -> None:
if use_dask and not has_dask:
pytest.skip("requires dask")
if use_dask and x.dtype.kind == "M":
pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)")
ar0_raw = xr.DataArray(
x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs
)

if use_dask:
ar0 = ar0_raw.chunk({})
ar0 = ar0_raw.chunk()
else:
ar0 = ar0_raw

Expand Down
31 changes: 25 additions & 6 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def test_indexing_1d_object_array() -> None:


@requires_dask
def test_indexing_dask_array():
def test_indexing_dask_array() -> None:
import dask.array

da = DataArray(
Expand All @@ -988,34 +988,53 @@ def test_indexing_dask_array():


@requires_dask
def test_indexing_dask_array_scalar():
def test_indexing_dask_array_scalar() -> None:
# GH4276
import dask.array

a = dask.array.from_array(np.linspace(0.0, 1.0))
da = DataArray(a, dims="x")
x_selector = da.argmax(dim=...)
assert not isinstance(x_selector, DataArray)
with raise_if_dask_computes():
actual = da.isel(x_selector)
expected = da.isel(x=-1)
assert_identical(actual, expected)


@requires_dask
def test_vectorized_indexing_dask_array():
def test_vectorized_indexing_dask_array() -> None:
# https://github.com/pydata/xarray/issues/2511#issuecomment-563330352
darr = DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
indexer = DataArray(
data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
coords={"y": range(4), "x": range(2)},
dims=("y", "x"),
)
with pytest.raises(ValueError, match="Vectorized indexing with Dask arrays"):
darr[indexer.chunk({"y": 2})]
expected = darr[indexer]

# fails because we can't index pd.Index lazily (yet).
# We could make this succeed by auto-chunking the values
# and constructing a lazy index variable, and not automatically
# create an index for it.
with pytest.raises(ValueError, match="Cannot index with"):
with raise_if_dask_computes():
darr.chunk()[indexer.chunk({"y": 2})]
with pytest.raises(ValueError, match="Cannot index with"):
with raise_if_dask_computes():
actual = darr[indexer.chunk({"y": 2})]

with raise_if_dask_computes():
actual = darr.drop_vars("z").chunk()[indexer.chunk({"y": 2})]
assert_identical(actual, expected.drop_vars("z"))

with raise_if_dask_computes():
actual_variable = darr.variable.chunk()[indexer.variable.chunk({"y": 2})]
assert_identical(actual_variable, expected.variable)


@requires_dask
def test_advanced_indexing_dask_array():
def test_advanced_indexing_dask_array() -> None:
# GH4663
import dask.array as da

Expand Down

0 comments on commit 9fe816e

Please sign in to comment.