Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use shuffle in groupby binary ops. #9896

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 73 additions & 17 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
peek_at,
)
from xarray.core.variable import IndexVariable, Variable
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array

if TYPE_CHECKING:
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -172,6 +172,68 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray
return newpositions[newpositions != -1]


def _vindex_like(
da: DataArray, dim: Hashable, indexer: DataArray, like_ds: Dataset | None
) -> Variable:
"""
Apply a vectorized indexer, optionally matching the chunks of a datarray
of the same name in `like_ds`. This is useful for GroupBy binary ops.
This function is intended to be used with Dataset.map.
"""
# we want to use the fact that we know the chunksizes for the output (matches obj)
# so we can't just use Variable's indexing directly
array = da._variable._data
like_da = like_ds.get(da.name)
if not is_duck_dask_array(array):
if like_da is None or not is_duck_dask_array(like_da._variable._data):
# TODO: we should instead check of `shuffle` and `reshape_blockwise`
return da.isel({dim: indexer})
else:
da = da.chunk("auto")

array = da._variable._data

import dask.array
from dask.array.core import slices_from_chunks
from dask.graph_manipulation import clone

from xarray.core.dask_array_compat import reshape_blockwise

array = clone(array) # FIXME: add to dask
# array = clone(array) # FIXME: add to dask

# dimensions for indexed result
out_dims = tuple(
itertools.chain(*(indexer.dims if this == dim else (this,) for this in da.dims))
)
out_chunks = tuple(
da.chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
)
out_shape = tuple(da.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims)
idxr = indexer._variable._data

# shuffle indices that can be reshaped blockwise to desired shape
core_dim_chunks = tuple(
chunks
for dim, chunks in zip(out_dims, out_chunks, strict=True)
if dim in indexer.dims
)
flat_indices = [
idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(core_dim_chunks)
]
shuffled = dask.array.shuffle(
array, flat_indices, axis=da.get_axis_num(dim), chunks="auto"
Copy link
Contributor Author

@dcherian dcherian Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phofl here it would be useful to pass chunks={da.get_axis_num(k): v for k, v in da.chunksizes.items() if k != dim}

)
# shuffle with `chunks="auto"` could change chunks, so we recalculate out_chunks
new_chunksizes = dict(zip(da.dims, shuffled.chunks, strict=True))
out_chunks = tuple(
new_chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
)
if shuffled.shape != out_shape:
shuffled = reshape_blockwise(shuffled, shape=out_shape, chunks=out_chunks)
return Variable(dims=out_dims, data=shuffled, attrs=da.attrs)


class _DummyGroup(Generic[T_Xarray]):
"""Class for keeping track of grouped dimensions without coordinates.

Expand Down Expand Up @@ -899,25 +961,19 @@ def _binary_op(self, other, f, reflexive=False):
group = group.where(~mask, drop=True)
codes = codes.where(~mask, drop=True).astype(int)

# if other is dask-backed, that's a hint that the
# "expanded" dataset is too big to hold in memory.
# this can be the case when `other` was read from disk
# and contains our lazy indexing classes
# We need to check for dask-backed Datasets
# so utils.is_duck_dask_array does not work for this check
if obj.chunks and not other.chunks:
# TODO: What about datasets with some dask vars, and others not?
# This handles dims other than `name``
chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
# a chunk size of 1 seems reasonable since we expect individual elements of
# other to be repeated multiple times across the reduced dimension(s)
chunks[name] = 1
other = other.chunk(chunks)

# codes are defined for coord, so we align `other` with `coord`
# before indexing
other, _ = align(other, coord, join="right", copy=False)
expanded = other.isel({name: codes})

other_as_dataset = (
other._to_temp_dataset() if isinstance(other, DataArray) else other
)
obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
expanded = other_as_dataset.map(
_vindex_like, dim=name, indexer=codes, like_ds=obj_as_dataset
)
if isinstance(other, DataArray):
expanded = other._from_temp_dataset(expanded)

result = g(obj, expanded)

Expand Down
4 changes: 3 additions & 1 deletion xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2654,12 +2654,14 @@
dims=("y", "x"),
coords={"label": ("x", [2, 2, 1])},
)
# da.groupby("label").min(...)
sub = xr.DataArray(
InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
)
chunked = da.chunk(x=1, y=2)
chunked.label.load()
actual = chunked.groupby("label") - sub
with raise_if_dask_computes():
actual = chunked.groupby("label") - sub

Check failure on line 2664 in xarray/tests/test_groupby.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 min-all-deps

test_groupby_math_auto_chunk AttributeError: module 'dask.array' has no attribute 'shuffle'
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}


Expand Down
Loading