Skip to content

Commit 5357f79

Browse files
BUG: correctly instantiate subclassed DataFrame/Series in groupby apply (#45363)
1 parent be20b2d commit 5357f79

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

pandas/core/groupby/ops.py

+4-12
Original file line number | Diff line number | Diff line change
@@ -753,6 +753,7 @@ def apply(
753753
zipped = zip(group_keys, splitter)
754754

755755
for key, group in zipped:
756+
group = group.__finalize__(data, method="groupby")
756757
object.__setattr__(group, "name", key)
757758

758759
# group might be modified
@@ -1000,6 +1001,7 @@ def _aggregate_series_pure_python(
10001001
splitter = get_splitter(obj, ids, ngroups, axis=0)
10011002

10021003
for i, group in enumerate(splitter):
1004+
group = group.__finalize__(obj, method="groupby")
10031005
res = func(group)
10041006
res = libreduction.extract_result(res)
10051007

@@ -1243,13 +1245,7 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series:
12431245
# fastpath equivalent to `sdata.iloc[slice_obj]`
12441246
mgr = sdata._mgr.get_slice(slice_obj)
12451247
# __finalize__ not called here, must be applied by caller if applicable
1246-
1247-
# fastpath equivalent to:
1248-
# `return sdata._constructor(mgr, name=sdata.name, fastpath=True)`
1249-
obj = type(sdata)._from_mgr(mgr)
1250-
object.__setattr__(obj, "_flags", sdata._flags)
1251-
object.__setattr__(obj, "_name", sdata._name)
1252-
return obj
1248+
return sdata._constructor(mgr, name=sdata.name, fastpath=True)
12531249

12541250

12551251
class FrameSplitter(DataSplitter):
@@ -1261,11 +1257,7 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
12611257
# return sdata.iloc[:, slice_obj]
12621258
mgr = sdata._mgr.get_slice(slice_obj, axis=1 - self.axis)
12631259
# __finalize__ not called here, must be applied by caller if applicable
1264-
1265-
# fastpath equivalent to `return sdata._constructor(mgr)`
1266-
obj = type(sdata)._from_mgr(mgr)
1267-
object.__setattr__(obj, "_flags", sdata._flags)
1268-
return obj
1260+
return sdata._constructor(mgr)
12691261

12701262

12711263
def get_splitter(

pandas/tests/generic/test_finalize.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -746,6 +746,7 @@ def test_categorical_accessor(method):
746746
"method",
747747
[
748748
operator.methodcaller("sum"),
749+
lambda x: x.apply(lambda y: y),
749750
lambda x: x.agg("sum"),
750751
lambda x: x.agg("mean"),
751752
lambda x: x.agg("median"),
@@ -764,7 +765,6 @@ def test_groupby_finalize(obj, method):
764765
"method",
765766
[
766767
lambda x: x.agg(["sum", "count"]),
767-
lambda x: x.apply(lambda y: y),
768768
lambda x: x.agg("std"),
769769
lambda x: x.agg("var"),
770770
lambda x: x.agg("sem"),

pandas/tests/groupby/test_groupby_subclass.py

+23
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
from pandas import (
77
DataFrame,
8+
Index,
89
Series,
910
)
1011
import pandas._testing as tm
@@ -64,6 +65,28 @@ def test_groupby_preserves_metadata():
6465
for _, group_df in custom_df.groupby("c"):
6566
assert group_df.testattr == "hello"
6667

68+
# GH-45314
69+
def func(group):
70+
assert isinstance(group, tm.SubclassedDataFrame)
71+
assert hasattr(group, "testattr")
72+
return group.testattr
73+
74+
result = custom_df.groupby("c").apply(func)
75+
expected = tm.SubclassedSeries(["hello"] * 3, index=Index([7, 8, 9], name="c"))
76+
tm.assert_series_equal(result, expected)
77+
78+
def func2(group):
79+
assert isinstance(group, tm.SubclassedSeries)
80+
assert hasattr(group, "testattr")
81+
return group.testattr
82+
83+
custom_series = tm.SubclassedSeries([1, 2, 3])
84+
custom_series.testattr = "hello"
85+
result = custom_series.groupby(custom_df["c"]).apply(func2)
86+
tm.assert_series_equal(result, expected)
87+
result = custom_series.groupby(custom_df["c"]).agg(func2)
88+
tm.assert_series_equal(result, expected)
89+
6790

6891
@pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame])
6992
def test_groupby_resample_preserves_subclass(obj):

0 commit comments

Comments
 (0)