From dafcde2d603b1ce9e80e9b02cdfd238e34f21cfa Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 26 Nov 2024 07:45:05 +0100 Subject: [PATCH] Use compute instead of load in plot (#9818) * Only compute this array, load computes in place * Add test * Update whats-new.rst * Update whats-new.rst --- doc/whats-new.rst | 3 ++- xarray/plot/dataarray_plot.py | 2 +- xarray/tests/test_plot.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 906fd0a25b2..4fb23123f4b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,7 +36,8 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Fix unintended load on datasets when calling :py:meth:`DataArray.plot.scatter` (:pull:`9818`). + By `Jimmy Westling `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index c668d78660c..cca9fe4f561 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -946,7 +946,7 @@ def newplotfunc( # Remove any nulls, .where(m, drop=True) doesn't work when m is # a dask array, so load the array to memory. # It will have to be loaded to memory at some point anyway: - darray = darray.load() + darray = darray.compute() darray = darray.where(darray.notnull(), drop=True) else: size_ = kwargs.pop("_size", linewidth) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 50254ef4198..1e07459061f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -33,6 +33,7 @@ assert_no_warnings, requires_cartopy, requires_cftime, + requires_dask, requires_matplotlib, requires_seaborn, ) @@ -3326,6 +3327,24 @@ def test_datarray_scatter( ) +@requires_dask +@requires_matplotlib +@pytest.mark.parametrize( + "plotfunc", + ["scatter"], +) +def test_dataarray_not_loading_inplace(plotfunc: str) -> None: + ds = xr.tutorial.scatter_example_dataset() + ds = ds.chunk() + + with figure_context(): + getattr(ds.A.plot, plotfunc)(x="x") + + from dask.array import Array + + assert isinstance(ds.A.data, Array) + + @requires_matplotlib def test_assert_valid_xy() -> None: ds = xr.tutorial.scatter_example_dataset()