Skip to content

Commit

Permalink
Merge branch 'main' into namedarray_chunkmanager
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty authored Jul 13, 2024
2 parents 2f1a6ca + d8b7644 commit d5d3e35
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 16 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Dataset contents
Dataset.drop_duplicates
Dataset.drop_dims
Dataset.drop_encoding
Dataset.drop_attrs
Dataset.set_coords
Dataset.reset_coords
Dataset.convert_calendar
Expand Down Expand Up @@ -306,6 +307,7 @@ DataArray contents
DataArray.drop_indexes
DataArray.drop_duplicates
DataArray.drop_encoding
DataArray.drop_attrs
DataArray.reset_coords
DataArray.copy
DataArray.convert_calendar
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ New Features
By `Martin Raspaud <https://github.com/mraspaud>`_.
- Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`).
By `Justus Magin <https://github.com/keewis>`_.
- Add :py:meth:`DataArray.drop_attrs` & :py:meth:`Dataset.drop_attrs` methods,
to return an object without ``attrs``. A ``deep`` parameter controls whether
variables' ``attrs`` are also dropped.
By `Maximilian Roos <https://github.com/max-sixty>`_. (:pull:`8288`)

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
Empty file added properties/__init__.py
Empty file.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
[tool.mypy]
enable_error_code = "redundant-self"
exclude = [
'build',
'xarray/util/generate_.*\.py',
'xarray/datatree_/doc/.*\.py',
]
Expand Down
17 changes: 17 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7456,3 +7456,20 @@ def to_dask_dataframe(
# this needs to be at the end, or mypy will confuse with `str`
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
str = utils.UncachedAccessor(StringAccessor["DataArray"])

def drop_attrs(self, *, deep: bool = True) -> Self:
    """
    Return a new DataArray with all attributes removed.

    Parameters
    ----------
    deep : bool, default True
        If True, also remove attributes from coordinates.

    Returns
    -------
    DataArray

    See Also
    --------
    Dataset.drop_attrs
    """
    # Delegate to Dataset.drop_attrs via the temp-dataset round trip so
    # coordinate/index variables are handled in one place.
    stripped = self._to_temp_dataset().drop_attrs(deep=deep)
    return self._from_temp_dataset(stripped)
42 changes: 42 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10681,3 +10681,45 @@ def resample(
restore_coord_dims=restore_coord_dims,
**indexer_kwargs,
)

def drop_attrs(self, *, deep: bool = True) -> Self:
    """
    Return a new Dataset with all attributes removed.

    Parameters
    ----------
    deep : bool, default True
        If True, also remove attributes from all variables and coordinates.

    Returns
    -------
    Dataset
    """
    # Strip the dataset-level attrs first; `_replace` returns a new object,
    # so the caller's dataset is left untouched.
    ds = self._replace(attrs={})

    if not deep:
        return ds

    # Strip attrs from every non-index variable. Variables lack a
    # `._replace` method, so copy each one and then clear its attrs in
    # place on the copy.
    for name in ds.variables:
        if name not in ds.indexes:
            ds[name] = ds[name].copy()
            ds[name].attrs = {}

    # Index coordinate variables must be rebuilt through their index so the
    # index object stays wired to its coordinate variables. (Not the most
    # elegant approach — a general "map over all variables, including
    # indexes" helper might be nicer.)
    stripped_index_vars: dict = {}
    for index, index_vars in ds.xindexes.group_by_index():
        # Copy each coordinate variable of the index and drop its attrs,
        # then re-wrap the index object in the new coordinate variables.
        copies = {k: v.copy() for k, v in index_vars.items()}
        for variable in copies.values():
            variable.attrs = {}
        stripped_index_vars.update(index.create_variables(copies))

    return ds.assign(stripped_index_vars)
5 changes: 5 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2980,6 +2980,11 @@ def test_assign_attrs(self) -> None:
assert_identical(new_actual, expected)
assert actual.attrs == {"a": 1, "b": 2}

def test_drop_attrs(self) -> None:
    # Mostly tested in test_dataset.py; this is a small smoke test that the
    # DataArray method delegates correctly.
    arr = DataArray([], attrs={"a": 1, "b": 2})
    result = arr.drop_attrs()
    assert result.attrs == {}

@pytest.mark.parametrize(
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
)
Expand Down
48 changes: 48 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4450,6 +4450,54 @@ def test_assign_attrs(self) -> None:
assert_identical(new_actual, expected)
assert actual.attrs == dict(a=1, b=2)

def test_drop_attrs(self) -> None:
    """Test `Dataset.drop_attrs` with `deep=True` (default) and `deep=False`."""
    # Simple example
    ds = Dataset().assign_attrs(a=1, b=2)
    original = ds.copy()
    expected = Dataset()
    result = ds.drop_attrs()
    assert_identical(result, expected)

    # Doesn't change original
    assert_identical(ds, original)

    # Example with variables and coords with attrs, and a multiindex. (Arguably
    # this should use a canonical dataset with all the features we
    # support...)
    var = Variable("x", [1, 2, 3], attrs=dict(x=1, y=2))
    idx = IndexVariable("y", [1, 2, 3], attrs=dict(c=1, d=2))
    mx = xr.Coordinates.from_pandas_multiindex(
        pd.MultiIndex.from_tuples([(1, 2), (3, 4)], names=["d", "e"]), "z"
    )
    ds = Dataset(dict(var1=var), coords=dict(y=idx, z=mx)).assign_attrs(a=1, b=2)
    assert ds.attrs != {}
    assert ds["var1"].attrs != {}
    assert ds["y"].attrs != {}
    assert ds.coords["y"].attrs != {}

    original = ds.copy(deep=True)
    result = ds.drop_attrs()

    assert result.attrs == {}
    assert result["var1"].attrs == {}
    assert result["y"].attrs == {}
    assert list(result.data_vars) == list(ds.data_vars)
    assert list(result.coords) == list(ds.coords)

    # Doesn't change original
    assert_identical(ds, original)
    # Specifically test that the attrs on the coords are still there. (The index
    # can't currently contain `attrs`, so we can't test those.)
    assert ds.coords["y"].attrs != {}

    # Test for deep=False: only dataset-level attrs are dropped; variable and
    # coordinate attrs are retained.
    result_shallow = ds.drop_attrs(deep=False)
    assert result_shallow.attrs == {}
    assert result_shallow["var1"].attrs != {}
    assert result_shallow["y"].attrs != {}
    # Fix: these two assertions previously checked `result` (the deep-drop
    # output) — a copy-paste from the deep case — so the shallow result's
    # structure was never actually verified.
    assert list(result_shallow.data_vars) == list(ds.data_vars)
    assert list(result_shallow.coords) == list(ds.coords)

def test_assign_multiindex_level(self) -> None:
data = create_test_multiindex()
with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "):
Expand Down
37 changes: 21 additions & 16 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ def setup(self) -> Generator:
plt.close("all")

def pass_in_axis(self, plotmethod, subplot_kw=None) -> None:
fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw)
plotmethod(ax=axs[0])
assert axs[0].has_data()
fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw, squeeze=False)
ax = axs[0, 0]
plotmethod(ax=ax)
assert ax.has_data()

@pytest.mark.slow
def imshow_called(self, plotmethod) -> bool:
Expand Down Expand Up @@ -240,9 +241,9 @@ def test_1d_x_y_kw(self) -> None:

xy: list[list[None | str]] = [[None, None], [None, "z"], ["z", None]]

f, ax = plt.subplots(3, 1)
f, axs = plt.subplots(3, 1, squeeze=False)
for aa, (x, y) in enumerate(xy):
da.plot(x=x, y=y, ax=ax.flat[aa])
da.plot(x=x, y=y, ax=axs.flat[aa])

with pytest.raises(ValueError, match=r"Cannot specify both"):
da.plot(x="z", y="z")
Expand Down Expand Up @@ -1566,7 +1567,9 @@ def test_colorbar_kwargs(self) -> None:
assert "MyLabel" in alltxt
assert "testvar" not in alltxt
# change cbar ax
fig, (ax, cax) = plt.subplots(1, 2)
fig, axs = plt.subplots(1, 2, squeeze=False)
ax = axs[0, 0]
cax = axs[0, 1]
self.plotmethod(
ax=ax, cbar_ax=cax, add_colorbar=True, cbar_kwargs={"label": "MyBar"}
)
Expand All @@ -1576,7 +1579,9 @@ def test_colorbar_kwargs(self) -> None:
assert "MyBar" in alltxt
assert "testvar" not in alltxt
# note that there are two ways to achieve this
fig, (ax, cax) = plt.subplots(1, 2)
fig, axs = plt.subplots(1, 2, squeeze=False)
ax = axs[0, 0]
cax = axs[0, 1]
self.plotmethod(
ax=ax, add_colorbar=True, cbar_kwargs={"label": "MyBar", "cax": cax}
)
Expand Down Expand Up @@ -3371,16 +3376,16 @@ def test_plot1d_default_rcparams() -> None:
# see overlapping markers:
fig, ax = plt.subplots(1, 1)
ds.plot.scatter(x="A", y="B", marker="o", ax=ax)
np.testing.assert_allclose(
ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w")
)
actual: np.ndarray = mpl.colors.to_rgba_array("w")
expected: np.ndarray = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error?
np.testing.assert_allclose(actual, expected)

# Facetgrids should have the default value as well:
fg = ds.plot.scatter(x="A", y="B", col="x", marker="o")
ax = fg.axs.ravel()[0]
np.testing.assert_allclose(
ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w")
)
actual = mpl.colors.to_rgba_array("w")
expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error?
np.testing.assert_allclose(actual, expected)

# scatter should not emit any warnings when using unfilled markers:
with assert_no_warnings():
Expand All @@ -3390,9 +3395,9 @@ def test_plot1d_default_rcparams() -> None:
# Prioritize edgecolor argument over default plot1d values:
fig, ax = plt.subplots(1, 1)
ds.plot.scatter(x="A", y="B", marker="o", ax=ax, edgecolor="k")
np.testing.assert_allclose(
ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k")
)
actual = mpl.colors.to_rgba_array("k")
expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error?
np.testing.assert_allclose(actual, expected)


@requires_matplotlib
Expand Down

0 comments on commit d5d3e35

Please sign in to comment.