SubClassedDataFrame.groupby().mean() etc. use method of SubClassedDataFrame #51765

Closed
wants to merge 39 commits into from
39 commits
df9b39a
a whole bunch of 1d constructors
AdamOrmondroyd Mar 2, 2023
de8c231
got the first three of Lukas' assertions working
AdamOrmondroyd Mar 2, 2023
d057cd0
remove print statements
AdamOrmondroyd Mar 2, 2023
e56b488
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
5338d3f
create test for overridden methods
AdamOrmondroyd Mar 3, 2023
5fe7f82
also attack median, std, var, sem, prod, sum, min and max
AdamOrmondroyd Mar 3, 2023
8f32b12
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
4aa2b85
add tests for test of methods
AdamOrmondroyd Mar 3, 2023
8d7346d
change 1d constructors to constructors
AdamOrmondroyd Mar 3, 2023
37ae233
tidy up
AdamOrmondroyd Mar 3, 2023
aa57cc2
change to np.all
AdamOrmondroyd Mar 3, 2023
b3df075
remove deliberate test failure
AdamOrmondroyd Mar 3, 2023
adc132a
check that self._obj_1d_constructor is Series
AdamOrmondroyd Mar 3, 2023
31868ff
add entry to docs
AdamOrmondroyd Mar 3, 2023
c0b2ad7
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
98b7986
Merge branch 'main' into groupby
AdamOrmondroyd Mar 6, 2023
053c865
Merge branch 'main' into groupby
AdamOrmondroyd Mar 6, 2023
2efa052
check for equality of mean methods
AdamOrmondroyd Mar 6, 2023
1505a1c
repeat for other methods
AdamOrmondroyd Mar 6, 2023
af9ac26
also test Series
AdamOrmondroyd Mar 6, 2023
f4bc548
pass through numeric_only
AdamOrmondroyd Mar 6, 2023
12a9fa8
reinstate type hinting
AdamOrmondroyd Mar 6, 2023
f46eea9
add type() to method comparison
AdamOrmondroyd Mar 6, 2023
185a3c1
test transform
AdamOrmondroyd Mar 7, 2023
bf9bde6
correct _constructor
AdamOrmondroyd Mar 7, 2023
036d662
Merge branch 'main' into groupby
AdamOrmondroyd Mar 7, 2023
f348648
Merge branch 'main' into groupby
AdamOrmondroyd Mar 13, 2023
48ceb0a
Merge branch 'main' into groupby
AdamOrmondroyd Mar 27, 2023
a6be1ea
remove unnecessary(?) if statement
AdamOrmondroyd Mar 27, 2023
f0ed14a
Merge branch 'main' into groupby
AdamOrmondroyd Apr 24, 2023
6631b1e
first pass at decorator
AdamOrmondroyd Apr 25, 2023
94dc186
add decorator to other methods
AdamOrmondroyd Apr 25, 2023
310c339
missed max()
AdamOrmondroyd Apr 25, 2023
27c4ed9
add @wraps
AdamOrmondroyd Apr 25, 2023
9bafd9a
Merge branch 'main' into groupby
AdamOrmondroyd Apr 25, 2023
963b3fe
Merge branch 'main' into groupby
AdamOrmondroyd Apr 26, 2023
8a9f30f
add tests for series example
AdamOrmondroyd Apr 26, 2023
518d42e
Merge branch 'main' into groupby
AdamOrmondroyd May 18, 2023
6320057
Merge branch 'main' into groupby
AdamOrmondroyd May 26, 2023
4 changes: 4 additions & 0 deletions doc/source/whatsnew/v2.1.0.rst
@@ -417,6 +417,10 @@ Plotting
- Bug in :meth:`Series.plot` when invoked with ``color=None`` (:issue:`51953`)
-

Groupby
- Bug in :meth:`GroupBy.mean`, :meth:`GroupBy.median`, :meth:`GroupBy.std`, :meth:`GroupBy.var`, :meth:`GroupBy.sem`, :meth:`GroupBy.prod`, :meth:`GroupBy.min`, :meth:`GroupBy.max` not using the corresponding methods of subclasses of :class:`Series` or :class:`DataFrame` (:issue:`51757`)
Member

I think this belongs below, in the Groupby/resample/rolling section?

-

Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` in incorrectly allowing non-fixed ``freq`` when resampling on a :class:`TimedeltaIndex` (:issue:`51896`)
22 changes: 14 additions & 8 deletions pandas/core/groupby/generic.py
@@ -290,7 +290,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
)

# result is a dict whose keys are the elements of result_index
result = Series(result, index=self.grouper.result_index)
result = self._obj_1d_constructor(
result, index=self.grouper.result_index
)
result = self._wrap_aggregated_output(result)
return result

@@ -703,7 +705,7 @@ def value_counts(
# in a backward compatible way
# GH38672 relates to categorical dtype
ser = self.apply(
Series.value_counts,
self._obj_1d_constructor.value_counts,
Member

This is trouble because, in general, we can't assume that _constructor is a class.

Contributor Author

@AdamOrmondroyd AdamOrmondroyd Mar 3, 2023

What else could it be (practically speaking, I know it's Callable)?

Member

GeoPandas has a callable that can dispatch to different classes. @jorisvandenbossche has argued against deprecating support for this.

Contributor Author

I've added a check whether self._obj_1d_constructor is a Series
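
The concern raised in this thread can be illustrated without pandas. The following is a minimal sketch, not the PR's actual check: `Base`, `Sub`, `dispatching_constructor` and `resolve_method` are hypothetical stand-ins, not pandas or GeoPandas names. It shows that attribute lookup on a factory function cannot find a method, so a guard has to fall back to the base class when the constructor is not a class.

```python
class Base:
    """Hypothetical stand-in for pandas.Series."""

    def value_counts(self):
        return "Base.value_counts"


class Sub(Base):
    """Hypothetical subclass that overrides value_counts."""

    def value_counts(self):
        return "Sub.value_counts"


def dispatching_constructor(*args, **kwargs):
    """Hypothetical factory: a callable that picks a class at call time."""
    return Sub() if kwargs.get("special") else Base()


def resolve_method(constructor, name, base=Base):
    """Return the method `name` from `constructor` if it is a subclass of
    `base`; otherwise fall back to the method defined on `base`."""
    if isinstance(constructor, type) and issubclass(constructor, base):
        return getattr(constructor, name)
    return getattr(base, name)


# A class-valued constructor exposes the overriding method directly.
method_from_class = resolve_method(Sub, "value_counts")
# A function-valued constructor has no such attribute, so fall back to Base.
method_from_factory = resolve_method(dispatching_constructor, "value_counts")
```

With the fallback, the subclass override is silently skipped for factory-style constructors, which may or may not be the desired behavior.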

normalize=normalize,
sort=sort,
ascending=ascending,
@@ -722,7 +724,9 @@ def value_counts(
llab = lambda lab, inc: lab[inc]
else:
# lab is a Categorical with categories an IntervalIndex
cat_ser = cut(Series(val, copy=False), bins, include_lowest=True)
cat_ser = cut(
self.obj._constructor(val, copy=False), bins, include_lowest=True
)
cat_obj = cast("Categorical", cat_ser._values)
lev = cat_obj.categories
lab = lev.take(
@@ -1406,9 +1410,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
elif relabeling:
# this should be the only (non-raising) case with relabeling
# used reordered index of columns
result = cast(DataFrame, result)
result = cast(self.obj._constructor, result)
Member

These will be wrong if _constructor is not a class.
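
For context on why such casts are fragile: `typing.cast` does nothing at runtime, and static type checkers expect its first argument to be a type expression rather than a value computed at runtime. A small sketch using only the standard library (`make_value` is a hypothetical name standing in for a callable `_constructor`):

```python
from typing import cast


def make_value():
    """Hypothetical factory standing in for a callable _constructor."""
    return [1, 2, 3]


data = {"a": 1}

# cast() returns its second argument unchanged; no conversion or check occurs.
as_list = cast(list, data)

# Passing a non-type (here, a function) is also accepted at runtime, so a
# mistake like this can only surface in a static type checker, if at all.
as_factory = cast(make_value, data)
```

Because nothing is validated at runtime, the cast target only matters to the type checker, which cannot evaluate a property such as `_constructor`.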

result = result.iloc[:, order]
result = cast(DataFrame, result)
result = cast(self.obj._constructor, result)
# error: Incompatible types in assignment (expression has type
# "Optional[List[str]]", variable has type
# "Union[Union[Union[ExtensionArray, ndarray[Any, Any]],
@@ -1451,7 +1455,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
else:
# GH#32040, GH#35246
# e.g. test_groupby_as_index_select_column_sum_empty_df
result = cast(DataFrame, result)
result = cast(self._constructor, result)
result.columns = self._obj_with_exclusions.columns.copy()

if not self.as_index:
@@ -1586,7 +1590,7 @@ def _wrap_applied_output_series(
is_transform: bool,
) -> DataFrame | Series:
kwargs = first_not_none._construct_axes_dict()
backup = Series(**kwargs)
backup = self._obj_1d_constructor(**kwargs)
values = [x if (x is not None) else backup for x in values]

all_indexed_same = all_indexes_same(x.index for x in values)
@@ -1981,7 +1985,9 @@ def _apply_to_column_groupbys(self, func) -> DataFrame:

if not len(results):
# concat would raise
res_df = DataFrame([], columns=columns, index=self.grouper.result_index)
res_df = self.obj._constructor(
[], columns=columns, index=self.grouper.result_index
)
else:
res_df = concat(results, keys=columns, axis=1)

40 changes: 37 additions & 3 deletions pandas/core/groupby/groupby.py
@@ -1450,6 +1450,29 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
res.index = default_index(len(res))
return res

    def _use_subclass_method(func):
        """
        Use the corresponding func method in the case of a
        subclassed Series or DataFrame.
        """

        @wraps(func)
        def inner(self, *args, **kwargs):
            if not (
                getattr(type(self.obj), func.__name__) is getattr(Series, func.__name__)
                or getattr(type(self.obj), func.__name__)
                is getattr(DataFrame, func.__name__)
            ):
                result = self.agg(
                    lambda df: getattr(self.obj._constructor(df), func.__name__)(
                        *args, **kwargs
                    )
                )
                return result.__finalize__(self.obj, method="groupby")
            return func(self, *args, **kwargs)

        return inner
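
The dispatch idea behind this decorator can be sketched in plain Python. This is an illustration under stated assumptions, not the pandas implementation: `Base`, `Unit`, `Grouped` and `use_subclass_method` are hypothetical stand-ins, and the override is called on the whole object rather than per group through `agg`.

```python
from functools import wraps


class Base:
    """Hypothetical stand-in for Series/DataFrame."""

    def __init__(self, values):
        self.values = list(values)

    def mean(self):
        return sum(self.values) / len(self.values)


class Unit(Base):
    """Hypothetical subclass overriding mean, like the PR's test classes."""

    def mean(self):
        return 1


def use_subclass_method(func):
    """Dispatch to the wrapped object's own method when it is overridden."""

    @wraps(func)
    def inner(self, *args, **kwargs):
        name = func.__name__
        # Looking a method up on the class yields the plain function, so an
        # identity comparison detects whether a subclass replaced it.
        if getattr(type(self.obj), name) is not getattr(Base, name):
            return getattr(self.obj, name)(*args, **kwargs)
        return func(self, *args, **kwargs)

    return inner


class Grouped:
    """Hypothetical stand-in for a GroupBy object wrapping ``obj``."""

    def __init__(self, obj):
        self.obj = obj

    @use_subclass_method
    def mean(self):
        """Fast path used when nothing is overridden."""
        return "fast path"


plain_result = Grouped(Base([2, 4])).mean()
unit_result = Grouped(Unit([2, 4])).mean()
```

Here `plain_result` comes from the fast path and `unit_result` from the override. `@wraps` keeps the decorated method's name and docstring, which matters in the PR because other decorators build the docstrings of these methods.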

# -----------------------------------------------------------------
# apply/agg/transform

@@ -1879,6 +1902,7 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
return self._reindex_output(result, fill_value=0)

@final
@_use_subclass_method
@Substitution(name="groupby")
@Substitution(see_also=_common_see_also)
def mean(
@@ -1962,12 +1986,15 @@ def mean(
else:
result = self._cython_agg_general(
"mean",
alt=lambda x: Series(x).mean(numeric_only=numeric_only),
alt=lambda x: self._obj_1d_constructor(x).mean(
numeric_only=numeric_only
),
numeric_only=numeric_only,
)
return result.__finalize__(self.obj, method="groupby")

@final
@_use_subclass_method
def median(self, numeric_only: bool = False):
"""
Compute median of groups, excluding missing values.
@@ -1990,12 +2017,13 @@ def median(self, numeric_only: bool = False):
"""
result = self._cython_agg_general(
"median",
alt=lambda x: Series(x).median(numeric_only=numeric_only),
alt=lambda x: self._obj_1d_constructor(x).median(numeric_only=numeric_only),
numeric_only=numeric_only,
)
return result.__finalize__(self.obj, method="groupby")

@final
@_use_subclass_method
@Substitution(name="groupby")
@Appender(_common_see_also)
def std(
@@ -2059,6 +2087,7 @@ def std(
)

@final
@_use_subclass_method
@Substitution(name="groupby")
@Appender(_common_see_also)
def var(
@@ -2116,7 +2145,7 @@ def var(
else:
return self._cython_agg_general(
"var",
alt=lambda x: Series(x).var(ddof=ddof),
alt=lambda x: self._obj_1d_constructor(x).var(ddof=ddof),
numeric_only=numeric_only,
ddof=ddof,
)
@@ -2255,6 +2284,7 @@ def _value_counts(
return result.__finalize__(self.obj, method="value_counts")

@final
@_use_subclass_method
def sem(self, ddof: int = 1, numeric_only: bool = False):
"""
Compute standard error of the mean of groups, excluding missing values.
@@ -2324,6 +2354,7 @@ def size(self) -> DataFrame | Series:
return result

@final
@_use_subclass_method
@doc(_groupby_agg_method_template, fname="sum", no=False, mc=0)
def sum(
self,
@@ -2354,13 +2385,15 @@ def sum(
return self._reindex_output(result, fill_value=0)

@final
@_use_subclass_method
@doc(_groupby_agg_method_template, fname="prod", no=False, mc=0)
def prod(self, numeric_only: bool = False, min_count: int = 0):
return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
)

@final
@_use_subclass_method
@doc(_groupby_agg_method_template, fname="min", no=False, mc=-1)
def min(
self,
@@ -2382,6 +2415,7 @@ def min(
)

@final
@_use_subclass_method
@doc(_groupby_agg_method_template, fname="max", no=False, mc=-1)
def max(
self,
112 changes: 112 additions & 0 deletions pandas/tests/groupby/test_groupby_subclass.py
@@ -103,3 +103,115 @@ def test_groupby_resample_preserves_subclass(obj):
# Confirm groupby.resample() preserves dataframe type
result = df.groupby("Buyer").resample("5D").sum()
assert isinstance(result, obj)


def test_groupby_overridden_methods():
    class UnitSeries(Series):
        @property
        def _constructor(self):
            return UnitSeries

        @property
        def _constructor_expanddim(self):
            return UnitDataFrame

        def mean(self, *args, **kwargs):
            return 1

        def median(self, *args, **kwargs):
            return 2

        def std(self, *args, **kwargs):
            return 3

        def var(self, *args, **kwargs):
            return 4

        def sem(self, *args, **kwargs):
            return 5

        def prod(self, *args, **kwargs):
            return 6

        def min(self, *args, **kwargs):
            return 7

        def max(self, *args, **kwargs):
            return 8

    class UnitDataFrame(DataFrame):
        @property
        def _constructor(self):
            return UnitDataFrame

        # A DataFrame reduces to a Series through _constructor_sliced,
        # not _constructor_expanddim.
        @property
        def _constructor_sliced(self):
            return UnitSeries

        def mean(self, *args, **kwargs):
            return 1

        def median(self, *args, **kwargs):
            return 2

        def std(self, *args, **kwargs):
            return 3

        def var(self, *args, **kwargs):
            return 4

        def sem(self, *args, **kwargs):
            return 5

        def prod(self, *args, **kwargs):
            return 6

        def min(self, *args, **kwargs):
            return 7

        def max(self, *args, **kwargs):
            return 8

    columns = ["a", "b"]
    data = np.random.rand(4, 2)
    udf = UnitDataFrame(data, columns=columns)
    udf["group"] = np.ones(4, dtype=int)
    udf.loc[2:, "group"] = 2

    us = udf[["a", "group"]]

    assert np.all(udf.groupby("group").mean() == 1)
    assert np.all(udf.groupby("group").median() == 2)
    assert np.all(udf.groupby("group").std() == 3)
    assert np.all(udf.groupby("group").var() == 4)
    assert np.all(udf.groupby("group").sem() == 5)
    assert np.all(udf.groupby("group").prod() == 6)
    assert np.all(udf.groupby("group").min() == 7)
    assert np.all(udf.groupby("group").max() == 8)

    assert np.all(us.groupby("group").mean() == 1)
    assert np.all(us.groupby("group").median() == 2)
    assert np.all(us.groupby("group").std() == 3)
    assert np.all(us.groupby("group").var() == 4)
    assert np.all(us.groupby("group").sem() == 5)
    assert np.all(us.groupby("group").prod() == 6)
    assert np.all(us.groupby("group").min() == 7)
    assert np.all(us.groupby("group").max() == 8)

    assert np.all(udf.groupby("group").transform("mean") == 1)
    assert np.all(udf.groupby("group").transform("median") == 2)
    assert np.all(udf.groupby("group").transform("std") == 3)
    assert np.all(udf.groupby("group").transform("var") == 4)
    assert np.all(udf.groupby("group").transform("sem") == 5)
    assert np.all(udf.groupby("group").transform("prod") == 6)
    assert np.all(udf.groupby("group").transform("min") == 7)
    assert np.all(udf.groupby("group").transform("max") == 8)

    assert np.all(us.groupby("group").transform("mean") == 1)
    assert np.all(us.groupby("group").transform("median") == 2)
    assert np.all(us.groupby("group").transform("std") == 3)
    assert np.all(us.groupby("group").transform("var") == 4)
    assert np.all(us.groupby("group").transform("sem") == 5)
    assert np.all(us.groupby("group").transform("prod") == 6)
    assert np.all(us.groupby("group").transform("min") == 7)
    assert np.all(us.groupby("group").transform("max") == 8)