diff --git a/doc/api.rst b/doc/api.rst index a8f8ea7dd1c..4cf8f374d37 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -111,6 +111,7 @@ Dataset contents Dataset.drop_duplicates Dataset.drop_dims Dataset.drop_encoding + Dataset.drop_attrs Dataset.set_coords Dataset.reset_coords Dataset.convert_calendar @@ -306,6 +307,7 @@ DataArray contents DataArray.drop_indexes DataArray.drop_duplicates DataArray.drop_encoding + DataArray.drop_attrs DataArray.reset_coords DataArray.copy DataArray.convert_calendar diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e8369dc2f40..6a8e898c93c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,10 @@ New Features By `Martin Raspaud `_. - Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`). By `Justus Magin `_. +- Add :py:meth:`DataArray.drop_attrs` & :py:meth:`Dataset.drop_attrs` methods, + to return an object without ``attrs``. A ``deep`` parameter controls whether + variables' ``attrs`` are also dropped. + By `Maximilian Roos `_. (:pull:`8288`) Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/properties/__init__.py b/properties/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pyproject.toml b/pyproject.toml index 2ada0c1c171..4704751b445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] enable_error_code = "redundant-self" exclude = [ + 'build', 'xarray/util/generate_.*\.py', 'xarray/datatree_/doc/.*\.py', ] diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b67f8089eb2..47dc9d13ffc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -7456,3 +7456,20 @@ def to_dask_dataframe( # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) + + def drop_attrs(self, *, deep: bool = True) -> Self: + """ + Removes all attributes from the DataArray. + + Parameters + ---------- + deep : bool, default True + Removes attributes from coordinates. + + Returns + ------- + DataArray + """ + return ( + self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset) + ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7e0d0151d95..3e57b3630de 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10681,3 +10681,45 @@ def resample( restore_coord_dims=restore_coord_dims, **indexer_kwargs, ) + + def drop_attrs(self, *, deep: bool = True) -> Self: + """ + Removes all attributes from the Dataset and its variables. + + Parameters + ---------- + deep : bool, default True + Removes attributes from all variables. + + Returns + ------- + Dataset + """ + # Remove attributes from the dataset + self = self._replace(attrs={}) + + if not deep: + return self + + # Remove attributes from each variable in the dataset + for var in self.variables: + # variables don't have a `._replace` method, so we copy and then remove + # attrs. If we added a `._replace` method, we could use that instead. + if var not in self.indexes: + self[var] = self[var].copy() + self[var].attrs = {} + + new_idx_variables = {} + # Not sure this is the most elegant way of doing this, but it works. + # (Should we have a more general "map over all variables, including + # indexes" approach?) + for idx, idx_vars in self.xindexes.group_by_index(): + # copy each coordinate variable of an index and drop their attrs + temp_idx_variables = {k: v.copy() for k, v in idx_vars.items()} + for v in temp_idx_variables.values(): + v.attrs = {} + # re-wrap the index object in new coordinate variables + new_idx_variables.update(idx.create_variables(temp_idx_variables)) + self = self.assign(new_idx_variables) + + return self diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 659c7c168a5..44ef486e5d6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2980,6 +2980,11 @@ def test_assign_attrs(self) -> None: assert_identical(new_actual, expected) assert actual.attrs == {"a": 1, "b": 2} + def test_drop_attrs(self) -> None: + # Mostly tested in test_dataset.py, but adding a very small test here + da = DataArray([], attrs=dict(a=1, b=2)) + assert da.drop_attrs().attrs == {} + @pytest.mark.parametrize( "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] ) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f6829861776..fd511af0dfb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4450,6 +4450,54 @@ def test_assign_attrs(self) -> None: assert_identical(new_actual, expected) assert actual.attrs == dict(a=1, b=2) + def test_drop_attrs(self) -> None: + # Simple example + ds = Dataset().assign_attrs(a=1, b=2) + original = ds.copy() + expected = Dataset() + result = ds.drop_attrs() + assert_identical(result, expected) + + # Doesn't change original + assert_identical(ds, original) + + # Example with variables and coords with attrs, and a multiindex. (arguably + # should have used a canonical dataset with all the features we're should + # support...) + var = Variable("x", [1, 2, 3], attrs=dict(x=1, y=2)) + idx = IndexVariable("y", [1, 2, 3], attrs=dict(c=1, d=2)) + mx = xr.Coordinates.from_pandas_multiindex( + pd.MultiIndex.from_tuples([(1, 2), (3, 4)], names=["d", "e"]), "z" + ) + ds = Dataset(dict(var1=var), coords=dict(y=idx, z=mx)).assign_attrs(a=1, b=2) + assert ds.attrs != {} + assert ds["var1"].attrs != {} + assert ds["y"].attrs != {} + assert ds.coords["y"].attrs != {} + + original = ds.copy(deep=True) + result = ds.drop_attrs() + + assert result.attrs == {} + assert result["var1"].attrs == {} + assert result["y"].attrs == {} + assert list(result.data_vars) == list(ds.data_vars) + assert list(result.coords) == list(ds.coords) + + # Doesn't change original + assert_identical(ds, original) + # Specifically test that the attrs on the coords are still there. (The index + # can't currently contain `attrs`, so we can't test those.) + assert ds.coords["y"].attrs != {} + + # Test for deep=False + result_shallow = ds.drop_attrs(deep=False) + assert result_shallow.attrs == {} + assert result_shallow["var1"].attrs != {} + assert result_shallow["y"].attrs != {} + assert list(result.data_vars) == list(ds.data_vars) + assert list(result.coords) == list(ds.coords) + def test_assign_multiindex_level(self) -> None: data = create_test_multiindex() with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index fa08e9975ab..578e6bcc18e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -158,9 +158,10 @@ def setup(self) -> Generator: plt.close("all") def pass_in_axis(self, plotmethod, subplot_kw=None) -> None: - fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw) - plotmethod(ax=axs[0]) - assert axs[0].has_data() + fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw, squeeze=False) + ax = axs[0, 0] + plotmethod(ax=ax) + assert ax.has_data() @pytest.mark.slow def imshow_called(self, plotmethod) -> bool: @@ -240,9 +241,9 @@ def test_1d_x_y_kw(self) -> None: xy: list[list[None | str]] = [[None, None], [None, "z"], ["z", None]] - f, ax = plt.subplots(3, 1) + f, axs = plt.subplots(3, 1, squeeze=False) for aa, (x, y) in enumerate(xy): - da.plot(x=x, y=y, ax=ax.flat[aa]) + da.plot(x=x, y=y, ax=axs.flat[aa]) with pytest.raises(ValueError, match=r"Cannot specify both"): da.plot(x="z", y="z") @@ -1566,7 +1567,9 @@ def test_colorbar_kwargs(self) -> None: assert "MyLabel" in alltxt assert "testvar" not in alltxt # change cbar ax - fig, (ax, cax) = plt.subplots(1, 2) + fig, axs = plt.subplots(1, 2, squeeze=False) + ax = axs[0, 0] + cax = axs[0, 1] self.plotmethod( ax=ax, cbar_ax=cax, add_colorbar=True, cbar_kwargs={"label": "MyBar"} ) @@ -1576,7 +1579,9 @@ def test_colorbar_kwargs(self) -> None: assert "MyBar" in alltxt assert "testvar" not in alltxt # note that there are two ways to achieve this - fig, (ax, cax) = plt.subplots(1, 2) + fig, axs = plt.subplots(1, 2, squeeze=False) + ax = axs[0, 0] + cax = axs[0, 1] self.plotmethod( ax=ax, add_colorbar=True, cbar_kwargs={"label": "MyBar", "cax": cax} ) @@ -3371,16 +3376,16 @@ def test_plot1d_default_rcparams() -> None: # see overlapping markers: fig, ax = plt.subplots(1, 1) ds.plot.scatter(x="A", y="B", marker="o", ax=ax) - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") - ) + actual: np.ndarray = mpl.colors.to_rgba_array("w") + expected: np.ndarray = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) # Facetgrids should have the default value as well: fg = ds.plot.scatter(x="A", y="B", col="x", marker="o") ax = fg.axs.ravel()[0] - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") - ) + actual = mpl.colors.to_rgba_array("w") + expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) # scatter should not emit any warnings when using unfilled markers: with assert_no_warnings(): @@ -3390,9 +3395,9 @@ def test_plot1d_default_rcparams() -> None: # Prioritize edgecolor argument over default plot1d values: fig, ax = plt.subplots(1, 1) ds.plot.scatter(x="A", y="B", marker="o", ax=ax, edgecolor="k") - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k") - ) + actual = mpl.colors.to_rgba_array("k") + expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) @requires_matplotlib