From 3024655e3689c11908d221913cdb922bcfc69037 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 9 Jul 2024 09:09:18 +0200 Subject: [PATCH] Only use necessary dims when creating temporary dataarray (#9206) * Only use necessary dims when creating temporary dataarray * Update dataset_plot.py * Can't check only data_vars all corrds are no longer added by default * Update dataset_plot.py * Add tests * Update whats-new.rst * Update dataset_plot.py --- doc/whats-new.rst | 2 ++ xarray/plot/dataset_plot.py | 15 +++++++++----- xarray/tests/test_plot.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8c6b3a099c2..0c401c2348e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix scatter plot broadcasting unneccesarily. (:issue:`9129`, :pull:`9206`) + By `Jimmy Westling `_. - Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`) By `Justus Magin `_. - Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`). diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index edc2bf43629..96b59f6174e 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -721,8 +721,8 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr """Create a temporary datarray with extra coords.""" from xarray.core.dataarray import DataArray - # Base coords: - coords = dict(ds.coords) + coords = dict(ds[y].coords) + dims = set(ds[y].dims) # Add extra coords to the DataArray from valid kwargs, if using all # kwargs there is a risk that we add unnecessary dataarrays as @@ -732,12 +732,17 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr coord_kwargs = locals_.keys() & valid_coord_kwargs for k in coord_kwargs: key = locals_[k] - if ds.data_vars.get(key) is not None: - coords[key] = ds[key] + darray = ds.get(key) + if darray is not None: + coords[key] = darray + dims.update(darray.dims) + + # Trim dataset from unneccessary dims: + ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future # The dataarray has to include all the dims. Broadcast to that shape # and add the additional coords: - _y = ds[y].broadcast_like(ds) + _y = ds[y].broadcast_like(ds_trimmed) return DataArray(_y, coords=coords) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b302ad3af93..fa08e9975ab 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -3416,3 +3416,43 @@ def test_9155() -> None: data = xr.DataArray([1, 2, 3], dims=["x"]) fig, ax = plt.subplots(ncols=1, nrows=1) data.plot(ax=ax) + + +@requires_matplotlib +def test_temp_dataarray() -> None: + from xarray.plot.dataset_plot import _temp_dataarray + + x = np.arange(1, 4) + y = np.arange(4, 6) + var1 = np.arange(x.size * y.size).reshape((x.size, y.size)) + var2 = np.arange(x.size * y.size).reshape((x.size, y.size)) + ds = xr.Dataset( + { + "var1": (["x", "y"], var1), + "var2": (["x", "y"], 2 * var2), + "var3": (["x"], 3 * x), + }, + coords={ + "x": x, + "y": y, + "model": np.arange(7), + }, + ) + + # No broadcasting: + y_ = "var1" + locals_ = {"x": "var2"} + da = _temp_dataarray(ds, y_, locals_) + assert da.shape == (3, 2) + + # Broadcast from 1 to 2dim: + y_ = "var3" + locals_ = {"x": "var1"} + da = _temp_dataarray(ds, y_, locals_) + assert da.shape == (3, 2) + + # Ignore non-valid coord kwargs: + y_ = "var3" + locals_ = dict(x="x", extend="var2") + da = _temp_dataarray(ds, y_, locals_) + assert da.shape == (3,)