From 1d4c97446b67b0fc8df0dbaea08a25a2b5687f01 Mon Sep 17 00:00:00 2001 From: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Date: Tue, 19 Nov 2024 01:02:50 +0100 Subject: [PATCH] BUG: Copy attrs on pd.merge() This uses the same logic as `pd.concat()`: Copy `attrs` only if all input `attrs` are identical. I've refactored the handling in __finalize__ from special-casing based on th the method name (previously only "concat") to handling "other" parameters that have an `input_objs` attribute. This is a more scalable architecture compared to hard-coding method names in __finalize__. Tests added for `concat()` and `merge()`. Closes #60351. --- pandas/core/generic.py | 4 ++-- pandas/core/reshape/concat.py | 10 +++++++--- pandas/core/reshape/merge.py | 5 ++++- pandas/tests/frame/test_api.py | 26 +++++++++++++++++++++++++- 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 039bdf9c36ee7..2e08d075fcb45 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -6053,8 +6053,8 @@ def __finalize__(self, other, method: str | None = None, **kwargs) -> Self: assert isinstance(name, str) object.__setattr__(self, name, getattr(other, name, None)) - if method == "concat": - objs = other.objs + elif hasattr(other, "input_objs"): + objs = other.input_objs # propagate attrs only if all concat arguments have the same attrs if all(bool(obj.attrs) for obj in objs): # all concatenate arguments have non-empty attrs diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index e7cb7069bbc26..508d18f68671d 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -545,7 +545,7 @@ def _get_result( result = sample._constructor_from_mgr(mgr, axes=mgr.axes) result._name = name return result.__finalize__( - types.SimpleNamespace(objs=objs), method="concat" + types.SimpleNamespace(input_objs=objs), method="concat" ) # combine as columns in a frame @@ -566,7 +566,9 @@ def _get_result( ) df = cons(data, index=index, copy=False) df.columns = columns - return df.__finalize__(types.SimpleNamespace(objs=objs), method="concat") + return df.__finalize__( + types.SimpleNamespace(input_objs=objs), method="concat" + ) # combine block managers else: @@ -605,7 +607,9 @@ def _get_result( ) out = sample._constructor_from_mgr(new_data, axes=new_data.axes) - return out.__finalize__(types.SimpleNamespace(objs=objs), method="concat") + return out.__finalize__( + types.SimpleNamespace(input_objs=objs), method="concat" + ) def new_axes( diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 6f9bb8cb24f43..5a48f9fa98354 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -16,6 +16,7 @@ cast, final, ) +import types import uuid import warnings @@ -1115,7 +1116,9 @@ def get_result(self) -> DataFrame: self._maybe_restore_index_levels(result) - return result.__finalize__(self, method="merge") + return result.__finalize__( + types.SimpleNamespace(input_objs=[self.left, self.right]), method="merge" + ) @final @cache_readonly diff --git a/pandas/tests/frame/test_api.py b/pandas/tests/frame/test_api.py index 2b0bf1b0576f9..41fa6d5b0f07d 100644 --- a/pandas/tests/frame/test_api.py +++ b/pandas/tests/frame/test_api.py @@ -315,7 +315,7 @@ def test_attrs(self): result = df.rename(columns=str) assert result.attrs == {"version": 1} - def test_attrs_deepcopy(self): + def test_attrs_is_deepcopy(self): df = DataFrame({"A": [2, 3]}) assert df.attrs == {} df.attrs["tags"] = {"spam", "ham"} @@ -324,6 +324,30 @@ def test_attrs_deepcopy(self): assert result.attrs == df.attrs assert result.attrs["tags"] is not df.attrs["tags"] + def test_attrs_concat(self): + # concat propagates attrs if all input attrs are equal + df1 = DataFrame({"A": [2, 3]}) + df1.attrs = {'a': 1, 'b': 2} + df2 = DataFrame({"A": [4, 5]}) + df2.attrs = df1.attrs.copy() + df3 = DataFrame({"A": [6, 7]}) + df3.attrs = df1.attrs.copy() + assert pd.concat([df1, df2, df3]).attrs == df1.attrs + # concat does not propagate attrs if input attrs are different + df2.attrs = {'c': 3} + assert pd.concat([df1, df2, df3]).attrs == {} + + def test_attrs_merge(self): + # merge propagates attrs if all input attrs are equal + df1 = pd.DataFrame({"key": ['a', 'b'], 'val1': [1, 2]}) + df1.attrs = {'a': 1, 'b': 2} + df2 = DataFrame({"key": ['a', 'b'], 'val2': [3, 4]}) + df2.attrs = df1.attrs.copy() + assert pd.merge(df1, df2).attrs == df1.attrs + # merge does not propagate attrs if input attrs are different + df2.attrs = {'c': 3} + assert pd.merge(df1, df2).attrs == {} + @pytest.mark.parametrize("allows_duplicate_labels", [True, False, None]) def test_set_flags( self,