Skip to content

Commit 1d4c974

Browse files
committed
BUG: Copy attrs on pd.merge()
This uses the same logic as `pd.concat()`: copy `attrs` only if all input `attrs` are identical. I've refactored the handling in `__finalize__` from special-casing based on the method name (previously only "concat") to handling "other" parameters that have an `input_objs` attribute. This is a more scalable architecture than hard-coding method names in `__finalize__`. Tests added for `concat()` and `merge()`. Closes pandas-dev#60351.
1 parent 6a7685f commit 1d4c974

File tree

4 files changed

+38
-7
lines changed

4 files changed

+38
-7
lines changed

pandas/core/generic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6053,8 +6053,8 @@ def __finalize__(self, other, method: str | None = None, **kwargs) -> Self:
60536053
assert isinstance(name, str)
60546054
object.__setattr__(self, name, getattr(other, name, None))
60556055

6056-
if method == "concat":
6057-
objs = other.objs
6056+
elif hasattr(other, "input_objs"):
6057+
objs = other.input_objs
60586058
# propagate attrs only if all concat arguments have the same attrs
60596059
if all(bool(obj.attrs) for obj in objs):
60606060
# all concatenate arguments have non-empty attrs

pandas/core/reshape/concat.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def _get_result(
545545
result = sample._constructor_from_mgr(mgr, axes=mgr.axes)
546546
result._name = name
547547
return result.__finalize__(
548-
types.SimpleNamespace(objs=objs), method="concat"
548+
types.SimpleNamespace(input_objs=objs), method="concat"
549549
)
550550

551551
# combine as columns in a frame
@@ -566,7 +566,9 @@ def _get_result(
566566
)
567567
df = cons(data, index=index, copy=False)
568568
df.columns = columns
569-
return df.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
569+
return df.__finalize__(
570+
types.SimpleNamespace(input_objs=objs), method="concat"
571+
)
570572

571573
# combine block managers
572574
else:
@@ -605,7 +607,9 @@ def _get_result(
605607
)
606608

607609
out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
608-
return out.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
610+
return out.__finalize__(
611+
types.SimpleNamespace(input_objs=objs), method="concat"
612+
)
609613

610614

611615
def new_axes(

pandas/core/reshape/merge.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
cast,
1717
final,
1818
)
19+
import types
1920
import uuid
2021
import warnings
2122

@@ -1115,7 +1116,9 @@ def get_result(self) -> DataFrame:
11151116

11161117
self._maybe_restore_index_levels(result)
11171118

1118-
return result.__finalize__(self, method="merge")
1119+
return result.__finalize__(
1120+
types.SimpleNamespace(input_objs=[self.left, self.right]), method="merge"
1121+
)
11191122

11201123
@final
11211124
@cache_readonly

pandas/tests/frame/test_api.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_attrs(self):
315315
result = df.rename(columns=str)
316316
assert result.attrs == {"version": 1}
317317

318-
def test_attrs_deepcopy(self):
318+
def test_attrs_is_deepcopy(self):
319319
df = DataFrame({"A": [2, 3]})
320320
assert df.attrs == {}
321321
df.attrs["tags"] = {"spam", "ham"}
@@ -324,6 +324,30 @@ def test_attrs_deepcopy(self):
324324
assert result.attrs == df.attrs
325325
assert result.attrs["tags"] is not df.attrs["tags"]
326326

327+
def test_attrs_concat(self):
328+
# concat propagates attrs if all input attrs are equal
329+
df1 = DataFrame({"A": [2, 3]})
330+
df1.attrs = {'a': 1, 'b': 2}
331+
df2 = DataFrame({"A": [4, 5]})
332+
df2.attrs = df1.attrs.copy()
333+
df3 = DataFrame({"A": [6, 7]})
334+
df3.attrs = df1.attrs.copy()
335+
assert pd.concat([df1, df2, df3]).attrs == df1.attrs
336+
# concat does not propagate attrs if input attrs are different
337+
df2.attrs = {'c': 3}
338+
assert pd.concat([df1, df2, df3]).attrs == {}
339+
340+
def test_attrs_merge(self):
341+
# merge propagates attrs if all input attrs are equal
342+
df1 = pd.DataFrame({"key": ['a', 'b'], 'val1': [1, 2]})
343+
df1.attrs = {'a': 1, 'b': 2}
344+
df2 = DataFrame({"key": ['a', 'b'], 'val2': [3, 4]})
345+
df2.attrs = df1.attrs.copy()
346+
assert pd.merge(df1, df2).attrs == df1.attrs
347+
# merge does not propagate attrs if input attrs are different
348+
df2.attrs = {'c': 3}
349+
assert pd.merge(df1, df2).attrs == {}
350+
327351
@pytest.mark.parametrize("allows_duplicate_labels", [True, False, None])
328352
def test_set_flags(
329353
self,

0 commit comments

Comments
 (0)