Skip to content

SubClassedDataFrame.groupby().mean() etc. use method of SubClassedDataFrame #51765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
df9b39a
a whole bunch of 1d constructors
AdamOrmondroyd Mar 2, 2023
de8c231
got the first three of Lukas' assertions working
AdamOrmondroyd Mar 2, 2023
d057cd0
remove print statements
AdamOrmondroyd Mar 2, 2023
e56b488
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
5338d3f
create test for overridden methods
AdamOrmondroyd Mar 3, 2023
5fe7f82
also attack median, std, var, sem, prod, sum, min and max
AdamOrmondroyd Mar 3, 2023
8f32b12
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
4aa2b85
add tests for test of methods
AdamOrmondroyd Mar 3, 2023
8d7346d
change 1d constructors to constructors
AdamOrmondroyd Mar 3, 2023
37ae233
tidy up
AdamOrmondroyd Mar 3, 2023
aa57cc2
change to np.all
AdamOrmondroyd Mar 3, 2023
b3df075
remove deliberate test failure
AdamOrmondroyd Mar 3, 2023
adc132a
check that self._obj_1d_constructor is Series
AdamOrmondroyd Mar 3, 2023
31868ff
add entry to docs
AdamOrmondroyd Mar 3, 2023
c0b2ad7
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
98b7986
Merge branch 'main' into groupby
AdamOrmondroyd Mar 6, 2023
053c865
Merge branch 'main' into groupby
AdamOrmondroyd Mar 6, 2023
2efa052
check for equality of mean methods
AdamOrmondroyd Mar 6, 2023
1505a1c
repeat for other methods
AdamOrmondroyd Mar 6, 2023
af9ac26
also test Series
AdamOrmondroyd Mar 6, 2023
f4bc548
pass through numeric_only
AdamOrmondroyd Mar 6, 2023
12a9fa8
reinstate type hinting
AdamOrmondroyd Mar 6, 2023
f46eea9
add type() to method comparison
AdamOrmondroyd Mar 6, 2023
185a3c1
test transform
AdamOrmondroyd Mar 7, 2023
bf9bde6
correct _constructor
AdamOrmondroyd Mar 7, 2023
036d662
Merge branch 'main' into groupby
AdamOrmondroyd Mar 7, 2023
f348648
Merge branch 'main' into groupby
AdamOrmondroyd Mar 13, 2023
48ceb0a
Merge branch 'main' into groupby
AdamOrmondroyd Mar 27, 2023
a6be1ea
remove unnecessary(?) if statement
AdamOrmondroyd Mar 27, 2023
f0ed14a
Merge branch 'main' into groupby
AdamOrmondroyd Apr 24, 2023
6631b1e
first pass at decorator
AdamOrmondroyd Apr 25, 2023
94dc186
add decorator to other methods
AdamOrmondroyd Apr 25, 2023
310c339
missed max()
AdamOrmondroyd Apr 25, 2023
27c4ed9
add @wraps
AdamOrmondroyd Apr 25, 2023
9bafd9a
Merge branch 'main' into groupby
AdamOrmondroyd Apr 25, 2023
963b3fe
Merge branch 'main' into groupby
AdamOrmondroyd Apr 26, 2023
8a9f30f
add tests for series example
AdamOrmondroyd Apr 26, 2023
518d42e
Merge branch 'main' into groupby
AdamOrmondroyd May 18, 2023
6320057
Merge branch 'main' into groupby
AdamOrmondroyd May 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ Plotting
- Bug in :meth:`Series.plot` when invoked with ``color=None`` (:issue:`51953`)
-

Groupby
- Bug in :meth:`GroupBy.mean`, :meth:`GroupBy.median`, :meth:`GroupBy.std`, :meth:`GroupBy.var`, :meth:`GroupBy.sem`, :meth:`GroupBy.prod`, :meth:`GroupBy.min`, :meth:`GroupBy.max` don't use corresponding methods of subclasses of :class:`Series` or :class:`DataFrame` (:issue:`51757`)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this belongs below in the Groupby/resample/rolling section?

-

Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^
- Bug in :meth:`DataFrameGroupBy.idxmin`, :meth:`SeriesGroupBy.idxmin`, :meth:`DataFrameGroupBy.idxmax`, :meth:`SeriesGroupBy.idxmax` return wrong dtype when used on empty DataFrameGroupBy or SeriesGroupBy (:issue:`51423`)
Expand Down
26 changes: 18 additions & 8 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
)

# result is a dict whose keys are the elements of result_index
result = Series(result, index=self.grouper.result_index)
result = self._obj_1d_constructor(
result, index=self.grouper.result_index
)
result = self._wrap_aggregated_output(result)
return result

Expand Down Expand Up @@ -681,14 +683,20 @@ def value_counts(

index_names = self.grouper.names + [self.obj.name]

constructor_1d = (
self._obj_1d_constructor
if isinstance(self._obj_1d_constructor, Series)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think issubclass rather than isinstance?

else Series
)

if is_categorical_dtype(val.dtype) or (
bins is not None and not np.iterable(bins)
):
# scalar bins cannot be done at top level
# in a backward compatible way
# GH38672 relates to categorical dtype
ser = self.apply(
Series.value_counts,
constructor_1d.value_counts,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this also be something like lambda group, **kwargs: group.value_counts(**kwargs)?

(didn't look in detail at how this code is working, so potentially this doesn't make sense at all)

normalize=normalize,
sort=sort,
ascending=ascending,
Expand All @@ -707,7 +715,7 @@ def value_counts(
llab = lambda lab, inc: lab[inc]
else:
# lab is a Categorical with categories an IntervalIndex
cat_ser = cut(Series(val), bins, include_lowest=True)
cat_ser = cut(self.obj._constructor(val), bins, include_lowest=True)
cat_obj = cast("Categorical", cat_ser._values)
lev = cat_obj.categories
lab = lev.take(
Expand Down Expand Up @@ -1308,9 +1316,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
elif relabeling:
# this should be the only (non-raising) case with relabeling
# used reordered index of columns
result = cast(DataFrame, result)
result = cast(self.obj._constructor, result)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these will be wrong if _constructor is not a class

result = result.iloc[:, order]
result = cast(DataFrame, result)
result = cast(self.obj._constructor, result)
# error: Incompatible types in assignment (expression has type
# "Optional[List[str]]", variable has type
# "Union[Union[Union[ExtensionArray, ndarray[Any, Any]],
Expand Down Expand Up @@ -1353,7 +1361,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
else:
# GH#32040, GH#35246
# e.g. test_groupby_as_index_select_column_sum_empty_df
result = cast(DataFrame, result)
result = cast(self._constructor, result)
result.columns = self._obj_with_exclusions.columns.copy()

if not self.as_index:
Expand Down Expand Up @@ -1481,7 +1489,7 @@ def _wrap_applied_output_series(
is_transform: bool,
) -> DataFrame | Series:
kwargs = first_not_none._construct_axes_dict()
backup = Series(**kwargs)
backup = self._obj_1d_constructor(**kwargs)
values = [x if (x is not None) else backup for x in values]

all_indexed_same = all_indexes_same(x.index for x in values)
Expand Down Expand Up @@ -1876,7 +1884,9 @@ def _apply_to_column_groupbys(self, func) -> DataFrame:

if not len(results):
# concat would raise
res_df = DataFrame([], columns=columns, index=self.grouper.result_index)
res_df = self.obj._constructor(
[], columns=columns, index=self.grouper.result_index
)
else:
res_df = concat(results, keys=columns, axis=1)

Expand Down
100 changes: 97 additions & 3 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,14 +1859,26 @@ def mean(
Name: B, dtype: float64
"""

if not (
type(self.obj).mean is Series.mean or type(self.obj).mean is DataFrame.mean
):
Copy link
Member

@rhshadrach rhshadrach Mar 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a bad idea, but what about a Boolean flag as a class attribute on Series / DataFrame that subclasses can override if they want to use their methods? Defaults to False.

Perhaps if only to introduce this with a deprecation warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does .groupby(...).agg(["sum", "mean"]) work with this?

No:

>>> udf.groupby('group').agg(['sum', 'mean'])
              a                   b
            sum      mean       sum      mean
group
1      0.322530  0.161265  0.769053  0.384527
2      1.714631  0.857315  0.407676  0.203838

Using my example where the mean should always be 1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However transform() does work:

>>> udf.groupby('group').transform('mean')
   a  b
0  1  1
1  1  1
2  1  1
3  1  1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking. In my opinion that isn't necessarily a blocker for this, but we are expanding the behavior here and introducing new bugs that should get fixed.

If we are to move forward with some form of implementation for this, I think the test coverage here should be expanded to include at least agg, apply, and transform (even if they are currently producing incorrect results).


def f(df, *args, **kwargs):
return self.obj._constructor(df).mean(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_mean

return self._numba_agg_general(sliding_mean, engine_kwargs)
else:
result = self._cython_agg_general(
"mean",
alt=lambda x: Series(x).mean(numeric_only=numeric_only),
alt=lambda x: self._obj_1d_constructor(x).mean(
numeric_only=numeric_only
),
numeric_only=numeric_only,
)
return result.__finalize__(self.obj, method="groupby")
Expand All @@ -1892,9 +1904,20 @@ def median(self, numeric_only: bool = False):
Series or DataFrame
Median of values within each group.
"""
if not (
type(self.obj).median is Series.median
or type(self.obj).median is DataFrame.median
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this pattern is going to show up a lot, does it merit a decorator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've had a first pass at making a decorator, not sure how to deal with engine and engine_kwargs


def f(df, *args, **kwargs):
return self.obj._constructor(df).median(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

result = self._cython_agg_general(
"median",
alt=lambda x: Series(x).median(numeric_only=numeric_only),
alt=lambda x: self._obj_1d_constructor(x).median(numeric_only=numeric_only),
numeric_only=numeric_only,
)
return result.__finalize__(self.obj, method="groupby")
Expand Down Expand Up @@ -1950,6 +1973,16 @@ def std(
Series or DataFrame
Standard deviation of values within each group.
"""
if not (
type(self.obj).std is Series.std or type(self.obj).std is DataFrame.std
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).std(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

Expand Down Expand Up @@ -2013,14 +2046,24 @@ def var(
Series or DataFrame
Variance of values within each group.
"""
if not (
type(self.obj).var is Series.var or type(self.obj).var is DataFrame.var
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).var(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
else:
return self._cython_agg_general(
"var",
alt=lambda x: Series(x).var(ddof=ddof),
alt=lambda x: self._obj_1d_constructor(x).var(ddof=ddof),
numeric_only=numeric_only,
ddof=ddof,
)
Expand Down Expand Up @@ -2184,6 +2227,17 @@ def sem(self, ddof: int = 1, numeric_only: bool = False):
Series or DataFrame
Standard error of the mean of values within each group.
"""
# TODO: think sem() needs considering more closely
if not (
type(self.obj).sem is Series.sem or type(self.obj).sem is DataFrame.sem
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).sem(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if numeric_only and self.obj.ndim == 1 and not is_numeric_dtype(self.obj.dtype):
raise TypeError(
f"{type(self).__name__}.sem called with "
Expand Down Expand Up @@ -2236,6 +2290,16 @@ def sum(
engine: str | None = None,
engine_kwargs: dict[str, bool] | None = None,
):
if not (
type(self.obj).sum is Series.sum or type(self.obj).sum is DataFrame.sum
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).sum(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_sum

Expand All @@ -2260,6 +2324,16 @@ def sum(
@final
@doc(_groupby_agg_method_template, fname="prod", no=False, mc=0)
def prod(self, numeric_only: bool = False, min_count: int = 0):
if not (
type(self.obj).prod is Series.prod or type(self.obj).prod is DataFrame.prod
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).prod(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
)
Expand All @@ -2273,6 +2347,16 @@ def min(
engine: str | None = None,
engine_kwargs: dict[str, bool] | None = None,
):
if not (
type(self.obj).min is Series.min or type(self.obj).min is DataFrame.min
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).min(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

Expand All @@ -2294,6 +2378,16 @@ def max(
engine: str | None = None,
engine_kwargs: dict[str, bool] | None = None,
):
if not (
type(self.obj).max is Series.max or type(self.obj).max is DataFrame.max
):

def f(df, *args, **kwargs):
return self.obj._constructor(df).max(numeric_only=numeric_only)

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

Expand Down
92 changes: 92 additions & 0 deletions pandas/tests/groupby/test_groupby_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,95 @@ def test_groupby_resample_preserves_subclass(obj):
# Confirm groupby.resample() preserves dataframe type
result = df.groupby("Buyer").resample("5D").sum()
assert isinstance(result, obj)


def test_groupby_overridden_methods():
class UnitSeries(Series):
@property
def _constructor(self):
return UnitSeries

@property
def _constructor_expanddim(self):
return UnitDataFrame

def mean(self, *args, **kwargs):
return 1

def median(self, *args, **kwargs):
return 2

def std(self, *args, **kwargs):
return 3

def var(self, *args, **kwargs):
return 4

def sem(self, *args, **kwargs):
return 5

def prod(self, *args, **kwargs):
return 6

def min(self, *args, **kwargs):
return 7

def max(self, *args, **kwargs):
return 8

class UnitDataFrame(DataFrame):
@property
def _constructor(self):
return UnitDataFrame

@property
def _constructor_expanddim(self):
return UnitSeries

def mean(self, *args, **kwargs):
return 1

def median(self, *args, **kwargs):
return 2

def std(self, *args, **kwargs):
return 3

def var(self, *args, **kwargs):
return 4

def sem(self, *args, **kwargs):
return 5

def prod(self, *args, **kwargs):
return 6

def min(self, *args, **kwargs):
return 7

def max(self, *args, **kwargs):
return 8

columns = ["a", "b"]
data = np.random.rand(4, 2)
udf = UnitDataFrame(data, columns=columns)
udf["group"] = np.ones(4, dtype=int)
udf.loc[2:, "group"] = 2

assert np.all(udf.groupby("group").mean() == 1)
assert np.all(udf.groupby("group").median() == 2)
assert np.all(udf.groupby("group").std() == 3)
assert np.all(udf.groupby("group").var() == 4)
assert np.all(udf.groupby("group").sem() == 5)
assert np.all(udf.groupby("group").prod() == 6)
assert np.all(udf.groupby("group").min() == 7)
assert np.all(udf.groupby("group").max() == 8)

assert np.all(udf.groupby("group").transform("mean") == 1)
assert np.all(udf.groupby("group").transform("median") == 2)
assert np.all(udf.groupby("group").transform("std") == 3)
assert np.all(udf.groupby("group").transform("var") == 4)
assert np.all(udf.groupby("group").transform("sem") == 5)
assert np.all(udf.groupby("group").transform("prod") == 6)
assert np.all(udf.groupby("group").transform("min") == 7)
assert np.all(udf.groupby("group").transform("max") == 8)