Skip to content

Commit

Permalink
Fix interpolation when non-numeric coords are present. (#9887)
Browse files Browse the repository at this point in the history
* Fix interpolation when non-numeric coords are present.

Closes #8099
Closes #9839

* fix

* Add basic 1d test

---------

Co-authored-by: Illviljan <[email protected]>
  • Loading branch information
dcherian and Illviljan authored Dec 14, 2024
1 parent f05c5ec commit 755581c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 5 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ Bug fixes
By `Bruce Merry <https://github.com/bmerry>`_.
- Fix unintended load on datasets when calling :py:meth:`DataArray.plot.scatter` (:pull:`9818`).
By `Jimmy Westling <https://github.com/illviljan>`_.
- Fix interpolation when non-numeric coordinate variables are present (:issue:`8099`, :issue:`9839`).
By `Deepak Cherian <https://github.com/dcherian>`_.


Documentation
~~~~~~~~~~~~~
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4185,7 +4185,7 @@ def _validate_interp_indexer(x, new_x):
}

variables: dict[Hashable, Variable] = {}
reindex: bool = False
reindex_vars: list[Hashable] = []
for name, var in obj._variables.items():
if name in indexers:
continue
Expand All @@ -4207,19 +4207,20 @@ def _validate_interp_indexer(x, new_x):
# booleans and objects and retains the dtype but inside
# this loop there might be some duplicate code that slows it
# down, therefore collect these signals and run it later:
reindex = True
reindex_vars.append(name)
elif all(d not in indexers for d in var.dims):
# For anything else we can only keep variables if they
# are not dependent on any coords that are being
# interpolated along:
variables[name] = var

if reindex:
reindex_indexers = {
if reindex_vars and (
reindex_indexers := {
k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)
}
):
reindexed = alignment.reindex(
obj,
obj[reindex_vars],
indexers=reindex_indexers,
method=method_non_numeric,
exclude_vars=variables.keys(),
Expand Down
38 changes: 38 additions & 0 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,41 @@ def test_interp1d_complex_out_of_bounds() -> None:
expected = da.interp(time=3.5, kwargs=dict(fill_value=np.nan + np.nan * 1j))
actual = da.interp(time=3.5)
assert_identical(actual, expected)


@requires_scipy
def test_interp_non_numeric_1d() -> None:
    """Check 1-D ``Dataset.interp`` on a dataset mixing numeric and string data.

    Both variables live on the ``time`` dimension. The dataset is interpolated
    onto a grid with a sample at every half step. The numeric variable is
    expected to follow the line ``1 + time``; the string variable is expected
    to stay a string array, with each half-way point carrying the label of the
    sample that follows it.
    """
    source_time = np.arange(0, 4, 1)
    target_time = np.linspace(0, 3, 7)
    source_labels = np.array(["a", "b", "c", "d"])
    expected_labels = np.array(["a", "b", "b", "c", "c", "d", "d"])

    ds = xr.Dataset(
        {
            "numeric": ("time", 1 + source_time),
            "non_numeric": ("time", source_labels),
        },
        coords={"time": source_time},
    )
    expected = xr.Dataset(
        {
            "numeric": ("time", 1 + target_time),
            "non_numeric": ("time", expected_labels),
        },
        coords={"time": target_time},
    )

    actual = ds.interp(time=target_time)

    xr.testing.assert_identical(actual, expected)


@requires_scipy
def test_interp_non_numeric_nd() -> None:
    """Regression test for GH8099 and GH9839.

    Interpolates a dataset along ``a`` using a 2-D indexer with dims
    ``("r", "s")`` while a boolean variable ``m`` is present, and checks
    that the numeric variable ``x`` comes out identical to interpolating a
    dataset that contains only ``x``.
    """
    ds = xr.Dataset({"x": ("a", np.arange(4))}, coords={"a": (np.arange(4) - 1.5)})

    # Draw the interpolation targets from a seeded generator so that a
    # failure can be reproduced with exactly the same target points.
    rng = np.random.default_rng(0)
    t = xr.DataArray(
        rng.standard_normal((2, 3)) * 0.5,
        dims=["r", "s"],
        coords={"r": np.arange(2) - 0.5, "s": np.arange(3) - 1},
    )
    # Boolean (non-numeric) variable that shares the interpolated dimension.
    ds["m"] = ds.x > 1

    actual = ds.interp(a=t, method="linear")
    # Reference result computed from the numeric variable alone.
    expected = ds[["x"]].interp(a=t, method="linear")
    assert_identical(actual[["x"]], expected)

0 comments on commit 755581c

Please sign in to comment.