diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index b70304789a1..96b59f6174e 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -722,7 +722,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr from xarray.core.dataarray import DataArray coords = dict(ds[y].coords) - dims = dict.fromkeys(ds[y].dims) + dims = set(ds[y].dims) # Add extra coords to the DataArray from valid kwargs, if using all # kwargs there is a risk that we add unnecessary dataarrays as @@ -735,11 +735,10 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr darray = ds.get(key) if darray is not None: coords[key] = darray - for d in darray.dims: - dims[d] = None + dims.update(darray.dims) # Trim dataset from unneccessary dims: - ds_trimmed = ds.drop_dims(ds.dims - dims.keys()) + ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future # The dataarray has to include all the dims. Broadcast to that shape # and add the additional coords: