Skip to content

Commit 6697172

Browse files
committed
BUG: Copy attrs on pd.merge()
This uses the same logic as `pd.concat()`: Copy `attrs` only if all input `attrs` are identical. I've refactored the handling in __finalize__ from special-casing based on the method name (previously only "concat") to handling "other" parameters that have an `input_objs` attribute. This is a more scalable architecture compared to hard-coding method names in __finalize__. Tests added for `concat()` and `merge()`. Closes pandas-dev#60351.
1 parent 6a7685f commit 6697172

File tree

6 files changed

+41
-11
lines changed

6 files changed

+41
-11
lines changed

pandas/core/generic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6053,8 +6053,8 @@ def __finalize__(self, other, method: str | None = None, **kwargs) -> Self:
60536053
assert isinstance(name, str)
60546054
object.__setattr__(self, name, getattr(other, name, None))
60556055

6056-
if method == "concat":
6057-
objs = other.objs
6056+
elif hasattr(other, "input_objs"):
6057+
objs = other.input_objs
60586058
# propagate attrs only if all concat arguments have the same attrs
60596059
if all(bool(obj.attrs) for obj in objs):
60606060
# all concatenate arguments have non-empty attrs

pandas/core/reshape/concat.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def _get_result(
545545
result = sample._constructor_from_mgr(mgr, axes=mgr.axes)
546546
result._name = name
547547
return result.__finalize__(
548-
types.SimpleNamespace(objs=objs), method="concat"
548+
types.SimpleNamespace(input_objs=objs), method="concat"
549549
)
550550

551551
# combine as columns in a frame
@@ -566,7 +566,9 @@ def _get_result(
566566
)
567567
df = cons(data, index=index, copy=False)
568568
df.columns = columns
569-
return df.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
569+
return df.__finalize__(
570+
types.SimpleNamespace(input_objs=objs), method="concat"
571+
)
570572

571573
# combine block managers
572574
else:
@@ -605,7 +607,7 @@ def _get_result(
605607
)
606608

607609
out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
608-
return out.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
610+
return out.__finalize__(types.SimpleNamespace(input_objs=objs), method="concat")
609611

610612

611613
def new_axes(

pandas/core/reshape/merge.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
import datetime
1212
from functools import partial
13+
import types
1314
from typing import (
1415
TYPE_CHECKING,
1516
Literal,
@@ -1115,7 +1116,9 @@ def get_result(self) -> DataFrame:
11151116

11161117
self._maybe_restore_index_levels(result)
11171118

1118-
return result.__finalize__(self, method="merge")
1119+
return result.__finalize__(
1120+
types.SimpleNamespace(input_objs=[self.left, self.right]), method="merge"
1121+
)
11191122

11201123
@final
11211124
@cache_readonly

pandas/tests/frame/test_api.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_attrs(self):
315315
result = df.rename(columns=str)
316316
assert result.attrs == {"version": 1}
317317

318-
def test_attrs_deepcopy(self):
318+
def test_attrs_is_deepcopy(self):
319319
df = DataFrame({"A": [2, 3]})
320320
assert df.attrs == {}
321321
df.attrs["tags"] = {"spam", "ham"}
@@ -324,6 +324,30 @@ def test_attrs_deepcopy(self):
324324
assert result.attrs == df.attrs
325325
assert result.attrs["tags"] is not df.attrs["tags"]
326326

327+
def test_attrs_concat(self):
328+
# concat propagates attrs if all input attrs are equal
329+
df1 = DataFrame({"A": [2, 3]})
330+
df1.attrs = {"a": 1, "b": 2}
331+
df2 = DataFrame({"A": [4, 5]})
332+
df2.attrs = df1.attrs.copy()
333+
df3 = DataFrame({"A": [6, 7]})
334+
df3.attrs = df1.attrs.copy()
335+
assert pd.concat([df1, df2, df3]).attrs == df1.attrs
336+
# concat does not propagate attrs if input attrs are different
337+
df2.attrs = {"c": 3}
338+
assert pd.concat([df1, df2, df3]).attrs == {}
339+
340+
def test_attrs_merge(self):
341+
# merge propagates attrs if all input attrs are equal
342+
df1 = DataFrame({"key": ["a", "b"], "val1": [1, 2]})
343+
df1.attrs = {"a": 1, "b": 2}
344+
df2 = DataFrame({"key": ["a", "b"], "val2": [3, 4]})
345+
df2.attrs = df1.attrs.copy()
346+
assert pd.merge(df1, df2).attrs == df1.attrs
347+
# merge does not propagate attrs if input attrs are different
348+
df2.attrs = {"c": 3}
349+
assert pd.merge(df1, df2).attrs == {}
350+
327351
@pytest.mark.parametrize("allows_duplicate_labels", [True, False, None])
328352
def test_set_flags(
329353
self,

pandas/tests/generic/test_duplicate_labels.py

-3
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ def test_concat(self, objs, kwargs):
164164
allows_duplicate_labels=False
165165
),
166166
False,
167-
marks=not_implemented,
168167
),
169168
# false true false
170169
pytest.param(
@@ -173,7 +172,6 @@ def test_concat(self, objs, kwargs):
173172
),
174173
pd.DataFrame({"B": [0, 1]}, index=["a", "d"]),
175174
False,
176-
marks=not_implemented,
177175
),
178176
# true true true
179177
(
@@ -296,7 +294,6 @@ def test_concat_raises(self):
296294
with pytest.raises(pd.errors.DuplicateLabelError, match=msg):
297295
pd.concat(objs, axis=1)
298296

299-
@not_implemented
300297
def test_merge_raises(self):
301298
a = pd.DataFrame({"A": [0, 1, 2]}, index=["a", "b", "c"]).set_flags(
302299
allows_duplicate_labels=False

pandas/tests/generic/test_frame.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def finalize(self, other, method=None, **kwargs):
8585
object.__setattr__(self, name, value)
8686
elif method == "concat":
8787
value = "+".join(
88-
[getattr(o, name) for o in other.objs if getattr(o, name, None)]
88+
[
89+
getattr(o, name)
90+
for o in other.input_objs
91+
if getattr(o, name, None)
92+
]
8993
)
9094
object.__setattr__(self, name, value)
9195
else:

0 commit comments

Comments
 (0)