Skip to content

Commit

Permalink
Optimize broadcasting (#230)
Browse files Browse the repository at this point in the history
* Optimize broadcasting

xref pydata/xarray#7730

* reorder

* Fix tests

* Another optimization

* fixes

* fix
  • Loading branch information
dcherian authored May 4, 2023
1 parent aa358a5 commit 622ddb2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
20 changes: 13 additions & 7 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ def xarray_reduce(
more_drop.update(idx_other_names)
maybe_drop.update(more_drop)

ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

if dim is Ellipsis:
if nby > 1:
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
Expand All @@ -275,17 +273,23 @@ def xarray_reduce(
# broadcast to make sure grouper dimensions are present in the array.
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)

if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")

try:
xr.align(ds, *by_da, join="exact", copy=False)
except ValueError as e:
raise ValueError(
"Object being grouped must be exactly aligned with every array in `by`."
) from e

ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]

if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
needs_broadcast = any(
not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values()
)
if needs_broadcast:
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]
else:
ds_broad = ds

dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
Expand All @@ -305,6 +309,8 @@ def xarray_reduce(
else:
return result

ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

axis = tuple(range(-len(dim_tuple), 0))

# Set expected_groups and convert to index since we need coords, sizes
Expand Down Expand Up @@ -432,7 +438,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):

# restore non-dim coord variables without the core dimension
# TODO: shouldn't apply_ufunc handle this?
for var in set(ds_broad.variables) - set(ds_broad._indexes) - set(ds_broad.dims):
for var in set(ds_broad._coord_names) - set(ds_broad._indexes) - set(ds_broad.dims):
if all(d not in ds_broad[var].dims for d in dim_tuple):
actual[var] = ds_broad[var]

Expand Down
2 changes: 2 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ def test_multi_index_groupby_sum(engine):
expected = ds.sum("z")
stacked = ds.stack(space=["x", "y"])
actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine)
expected_xarray = stacked.groupby("space").sum("z")
assert_equal(expected_xarray, actual)
assert_equal(expected, actual.unstack("space"))

actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine)
Expand Down

0 comments on commit 622ddb2

Please sign in to comment.