Skip to content

Commit

Permalink
Rewrite interp to use apply_ufunc (#9881)
Browse files (browse the repository at this point in the history)
* Don't eagerly compute dask arrays in localize

* Clean up test

* Clean up Variable handling

* Silence test warning

* Use apply_ufunc instead

* Add test for #4463

Closes #4463

* complete tests

* Add comments

* Clear up broadcasting

* typing

* try a different warning filter

* one more fix

* types + more duck_array_ops

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Michael Niklas <[email protected]>

* Apply suggestions from code review

Co-authored-by: Illviljan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Illviljan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Revert "Apply suggestions from code review"

This reverts commit 1b9845d.

---------

Co-authored-by: Illviljan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michael Niklas <[email protected]>
  • Loading branch information
4 people authored Dec 19, 2024
1 parent a90fff9 commit 29fe679
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 204 deletions.
55 changes: 18 additions & 37 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2921,19 +2921,11 @@ def _validate_interp_indexers(
"""Variant of _validate_indexers to be used for interpolation"""
for k, v in self._validate_indexers(indexers):
if isinstance(v, Variable):
if v.ndim == 1:
yield k, v.to_index_variable()
else:
yield k, v
elif isinstance(v, int):
yield k, v
elif is_scalar(v):
yield k, Variable((), v, attrs=self.coords[k].attrs)
elif isinstance(v, np.ndarray):
if v.ndim == 0:
yield k, Variable((), v, attrs=self.coords[k].attrs)
elif v.ndim == 1:
yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs)
else:
raise AssertionError() # Already tested by _validate_indexers
yield k, Variable(dims=(k,), data=v, attrs=self.coords[k].attrs)
else:
raise TypeError(type(v))

Expand Down Expand Up @@ -4127,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
# This avoids broadcasting over coordinates that are both in
# the original array AND in the indexing array. It essentially
# forces interpolation along the shared coordinates.
sdims = (
set(self.dims)
.intersection(*[set(nx.dims) for nx in indexers.values()])
.difference(coords.keys())
)
indexers.update({d: self.variables[d] for d in sdims})

obj = self if assume_sorted else self.sortby(list(coords))

def maybe_variable(obj, k):
Expand Down Expand Up @@ -4169,16 +4149,18 @@ def _validate_interp_indexer(x, new_x):
for k, v in indexers.items()
}

# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]

# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
if obj.__dask_graph__():
has_chunked_array = bool(
any(is_chunked_array(v._data) for v in obj._variables.values())
)
if has_chunked_array:
# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]
# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
dask_indexers = {
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
for k, (index, dest) in validated_indexers.items()
Expand All @@ -4190,10 +4172,9 @@ def _validate_interp_indexer(x, new_x):
if name in indexers:
continue

if is_duck_dask_array(var.data):
use_indexers = dask_indexers
else:
use_indexers = validated_indexers
use_indexers = (
dask_indexers if is_duck_dask_array(var.data) else validated_indexers
)

dtype_kind = var.dtype.kind
if dtype_kind in "uifc":
Expand Down
Loading

0 comments on commit 29fe679

Please sign in to comment.