Skip to content

Commit

Permalink
Use compute instead of load in plot (pydata#9818)
Browse files Browse the repository at this point in the history
* Only compute this array, load computes in place

* Add test

* Update whats-new.rst

* Update whats-new.rst
  • Loading branch information
Illviljan authored Nov 26, 2024
1 parent caf62d3 commit dafcde2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Deprecations

Bug fixes
~~~~~~~~~

- Fix unintended load on datasets when calling :py:meth:`DataArray.plot.scatter` (:pull:`9818`).
By `Jimmy Westling <https://github.com/illviljan>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def newplotfunc(
# Remove any nulls, .where(m, drop=True) doesn't work when m is
# a dask array, so load the array to memory.
# It will have to be loaded to memory at some point anyway:
darray = darray.load()
darray = darray.compute()
darray = darray.where(darray.notnull(), drop=True)
else:
size_ = kwargs.pop("_size", linewidth)
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
assert_no_warnings,
requires_cartopy,
requires_cftime,
requires_dask,
requires_matplotlib,
requires_seaborn,
)
Expand Down Expand Up @@ -3326,6 +3327,24 @@ def test_datarray_scatter(
)


@requires_dask
@requires_matplotlib
@pytest.mark.parametrize(
"plotfunc",
["scatter"],
)
def test_dataarray_not_loading_inplace(plotfunc: str) -> None:
ds = xr.tutorial.scatter_example_dataset()
ds = ds.chunk()

with figure_context():
getattr(ds.A.plot, plotfunc)(x="x")

from dask.array import Array

assert isinstance(ds.A.data, Array)


@requires_matplotlib
def test_assert_valid_xy() -> None:
ds = xr.tutorial.scatter_example_dataset()
Expand Down

0 comments on commit dafcde2

Please sign in to comment.