Skip to content

Commit

Permalink
Switch to Variable
Browse files (browse the repository at this point in the history)
  • Loading branch information
dcherian committed Dec 18, 2024
1 parent 633b361 commit 18c96ea
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions xarray/core/groupby.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -191,25 +191,26 @@ def _vindex_like(
else:
da = da.chunk("auto")

array = da._variable._data
var = da._variable
array = var._data

import dask.array
from dask.array.core import slices_from_chunks
from dask.graph_manipulation import clone

from xarray.core.dask_array_compat import reshape_blockwise

array = clone(array) # FIXME: add to dask
# array = clone(array) # FIXME: add to dask

# dimensions for indexed result
out_dims = tuple(
itertools.chain(*(indexer.dims if this == dim else (this,) for this in da.dims))
itertools.chain(
*(indexer.dims if this == dim else (this,) for this in var.dims)
)
)
out_chunks = tuple(
da.chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
)
out_shape = tuple(da.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims)
out_shape = tuple(var.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims)
idxr = indexer._variable._data

# shuffle indices that can be reshaped blockwise to desired shape
Expand All @@ -221,17 +222,17 @@ def _vindex_like(
flat_indices = [
idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(core_dim_chunks)
]
shuffled = dask.array.shuffle(
array, flat_indices, axis=da.get_axis_num(dim), chunks="auto"
)
shuffled = var._shuffle(flat_indices, dim=dim, chunks="auto")
# shuffle with `chunks="auto"` could change chunks, so we recalculate out_chunks
new_chunksizes = dict(zip(da.dims, shuffled.chunks, strict=True))
new_chunksizes = dict(zip(var.dims, shuffled.chunks, strict=True))
out_chunks = tuple(
new_chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
)
if shuffled.shape != out_shape:
shuffled = reshape_blockwise(shuffled, shape=out_shape, chunks=out_chunks)
return Variable(dims=out_dims, data=shuffled, attrs=da.attrs)
out_data = reshape_blockwise(shuffled._data, shape=out_shape, chunks=out_chunks)
else:
out_data = shuffled._data
return Variable(out_dims, out_data, var.attrs)


class _DummyGroup(Generic[T_Xarray]):
Expand Down

0 comments on commit 18c96ea

Please sign in to comment.