SubClassedDataFrame.groupby().mean() etc. use method of SubClassedDataFrame #51765

Closed
wants to merge 39 commits into from
Changes from 9 commits
Commits
39 commits
df9b39a
a whole bunch of 1d constructors
AdamOrmondroyd Mar 2, 2023
de8c231
got the first three of Lukas' assertions working
AdamOrmondroyd Mar 2, 2023
d057cd0
remove print statements
AdamOrmondroyd Mar 2, 2023
e56b488
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
5338d3f
create test for overridden methods
AdamOrmondroyd Mar 3, 2023
5fe7f82
also attack median, std, var, sem, prod, sum, min and max
AdamOrmondroyd Mar 3, 2023
8f32b12
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
4aa2b85
add tests for test of methods
AdamOrmondroyd Mar 3, 2023
8d7346d
change 1d constructors to constructors
AdamOrmondroyd Mar 3, 2023
37ae233
tidy up
AdamOrmondroyd Mar 3, 2023
aa57cc2
change to np.all
AdamOrmondroyd Mar 3, 2023
b3df075
remove deliberate test failure
AdamOrmondroyd Mar 3, 2023
adc132a
check that self._obj_1d_constructor is Series
AdamOrmondroyd Mar 3, 2023
31868ff
add entry to docs
AdamOrmondroyd Mar 3, 2023
c0b2ad7
Merge branch 'main' into groupby
AdamOrmondroyd Mar 3, 2023
98b7986
Merge branch 'main' into groupby
AdamOrmondroyd Mar 6, 2023
053c865
Merge branch 'main' into groupby
AdamOrmondroyd Mar 6, 2023
2efa052
check for equality of mean methods
AdamOrmondroyd Mar 6, 2023
1505a1c
repeat for other methods
AdamOrmondroyd Mar 6, 2023
af9ac26
also test Series
AdamOrmondroyd Mar 6, 2023
f4bc548
pass through numeric_only
AdamOrmondroyd Mar 6, 2023
12a9fa8
reinstate type hinting
AdamOrmondroyd Mar 6, 2023
f46eea9
add type() to method comparison
AdamOrmondroyd Mar 6, 2023
185a3c1
test transform
AdamOrmondroyd Mar 7, 2023
bf9bde6
correct _constructor
AdamOrmondroyd Mar 7, 2023
036d662
Merge branch 'main' into groupby
AdamOrmondroyd Mar 7, 2023
f348648
Merge branch 'main' into groupby
AdamOrmondroyd Mar 13, 2023
48ceb0a
Merge branch 'main' into groupby
AdamOrmondroyd Mar 27, 2023
a6be1ea
remove unnecessary(?) if statement
AdamOrmondroyd Mar 27, 2023
f0ed14a
Merge branch 'main' into groupby
AdamOrmondroyd Apr 24, 2023
6631b1e
first pass at decorator
AdamOrmondroyd Apr 25, 2023
94dc186
add decorator to other methods
AdamOrmondroyd Apr 25, 2023
310c339
missed max()
AdamOrmondroyd Apr 25, 2023
27c4ed9
add @wraps
AdamOrmondroyd Apr 25, 2023
9bafd9a
Merge branch 'main' into groupby
AdamOrmondroyd Apr 25, 2023
963b3fe
Merge branch 'main' into groupby
AdamOrmondroyd Apr 26, 2023
8a9f30f
add tests for series example
AdamOrmondroyd Apr 26, 2023
518d42e
Merge branch 'main' into groupby
AdamOrmondroyd May 18, 2023
6320057
Merge branch 'main' into groupby
AdamOrmondroyd May 26, 2023
20 changes: 12 additions & 8 deletions pandas/core/groupby/generic.py
@@ -283,7 +283,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
)

# result is a dict whose keys are the elements of result_index
result = Series(result, index=self.grouper.result_index)
result = self._obj_1d_constructor(
result, index=self.grouper.result_index
)
result = self._wrap_aggregated_output(result)
return result

@@ -687,7 +689,7 @@ def value_counts(
# in a backward compatible way
# GH38672 relates to categorical dtype
ser = self.apply(
Series.value_counts,
self._obj_1d_constructor.value_counts,
Member

This is trouble because, in general, we can't assume that _constructor is a class.

Contributor Author (@AdamOrmondroyd, Mar 3, 2023)

What else could it be (practically speaking, I know it's Callable)?

Member

Geopandas has a callable that can dispatch to different classes. @jorisvandenbossche has argued against deprecating allowing this.

Contributor Author

I've added a check for whether self._obj_1d_constructor is Series.
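
To make the concern in this thread concrete, here is a minimal sketch of why attribute access on `_constructor` is fragile. The names `PlainSeries` and `dispatching_constructor` are hypothetical and are not taken from pandas or Geopandas; the factory only imitates the idea of a callable that chooses a class at call time.

```python
import pandas as pd


class PlainSeries(pd.Series):
    """Hypothetical subclass whose _constructor is a class."""

    @property
    def _constructor(self):
        return PlainSeries


def dispatching_constructor(*args, **kwargs):
    """Hypothetical factory: callable, but not itself a class.

    A real dispatching constructor might pick between several classes here.
    """
    return PlainSeries(*args, **kwargs)


# A class exposes its methods as attributes...
class_has_method = hasattr(PlainSeries, "value_counts")

# ...but a plain function does not, so `constructor.value_counts`
# fails when the constructor is a factory function.
factory_has_method = hasattr(dispatching_constructor, "value_counts")

# Calling the factory still works, so the method is reachable via an instance.
ser = dispatching_constructor([1, 1, 2])
counts = ser.value_counts()
```

Reaching the method through an instance, rather than through the constructor object, sidesteps the question of whether the constructor is a class.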

normalize=normalize,
sort=sort,
ascending=ascending,
@@ -706,7 +708,7 @@ def value_counts(
llab = lambda lab, inc: lab[inc]
else:
# lab is a Categorical with categories an IntervalIndex
cat_ser = cut(Series(val), bins, include_lowest=True)
cat_ser = cut(self.obj._constructor(val), bins, include_lowest=True)
cat_obj = cast("Categorical", cat_ser._values)
lev = cat_obj.categories
lab = lev.take(
@@ -1289,9 +1291,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
elif relabeling:
# this should be the only (non-raising) case with relabeling
# used reordered index of columns
result = cast(DataFrame, result)
result = cast(self.obj._constructor, result)
Member

these will be wrong if _constructor is not a class

result = result.iloc[:, order]
result = cast(DataFrame, result)
result = cast(self.obj._constructor, result)
# error: Incompatible types in assignment (expression has type
# "Optional[List[str]]", variable has type
# "Union[Union[Union[ExtensionArray, ndarray[Any, Any]],
@@ -1334,7 +1336,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
else:
# GH#32040, GH#35246
# e.g. test_groupby_as_index_select_column_sum_empty_df
result = cast(DataFrame, result)
result = cast(self._obj_1d_constructor, result)
Member

_constructor, not _1d_constructor. Also, I think this is OK as is.

Contributor Author

I have also wrestled with whether results should stick to the subclass or just be DataFrames

Member

What does mypy resolve self._constructor to? I'd be surprised if it wasn't Any. You can check by adding reveal_type(self._obj_1d_constructor) on its own line and then running mypy.
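
As background for the `cast` discussion above, the following sketch shows that `typing.cast` does nothing at runtime: it returns its second argument unchanged, and its first argument matters only to a type checker. The `factory` function is a hypothetical stand-in for a callable constructor.

```python
from typing import cast

import pandas as pd

df = pd.DataFrame({"a": [1, 2]})

# cast() returns the value untouched; no conversion or check happens.
same = cast(pd.DataFrame, df)


def factory(*args, **kwargs):
    """Hypothetical callable constructor that is not a class."""
    return pd.DataFrame(*args, **kwargs)


# A non-class target is not rejected at runtime either, so a wrong cast
# target goes unnoticed unless a type checker flags it.
also_same = cast(factory, df)
```

Because nothing fails at runtime, only the type checker's view changes when the cast target does, which is why inspecting the inferred type with mypy is the relevant check.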

result.columns = self._obj_with_exclusions.columns.copy()

if not self.as_index:
@@ -1462,7 +1464,7 @@ def _wrap_applied_output_series(
is_transform: bool,
) -> DataFrame | Series:
kwargs = first_not_none._construct_axes_dict()
backup = Series(**kwargs)
backup = self._obj_1d_constructor(**kwargs)
values = [x if (x is not None) else backup for x in values]

all_indexed_same = all_indexes_same(x.index for x in values)
@@ -1857,7 +1859,9 @@ def _apply_to_column_groupbys(self, func) -> DataFrame:

if not len(results):
# concat would raise
res_df = DataFrame([], columns=columns, index=self.grouper.result_index)
res_df = self.obj._constructor(
[], columns=columns, index=self.grouper.result_index
)
else:
res_df = concat(results, keys=columns, axis=1)

83 changes: 79 additions & 4 deletions pandas/core/groupby/groupby.py
@@ -1629,7 +1629,7 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray:

@final
@property
def _obj_1d_constructor(self) -> Callable:
def _obj_1d_constructor(self):
# GH28330 preserve subclassed Series/DataFrames
if isinstance(self.obj, DataFrame):
return self.obj._constructor_sliced
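
For readers unfamiliar with the constructor properties used here, this is a small sketch of the subclassing pattern that `_obj_1d_constructor` relies on. `MySeries` and `MyFrame` are hypothetical names; the point is only that `_constructor_sliced` determines the 1-D type obtained from a subclassed DataFrame.

```python
import pandas as pd


class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        # 1-D -> 2-D: the type used when a Series expands to a frame.
        return MyFrame


class MyFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyFrame

    @property
    def _constructor_sliced(self):
        # 2-D -> 1-D: the type used when a frame is sliced to a column.
        return MySeries


frame = MyFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
column = frame["a"]
```

Selecting a single column goes through `_constructor_sliced`, so `column` is a `MySeries` rather than a plain `Series`.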
@@ -1837,14 +1837,24 @@ def mean(
Name: B, dtype: float64
"""

if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).mean()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_mean

return self._numba_agg_general(sliding_mean, engine_kwargs)
else:
result = self._cython_agg_general(
"mean",
alt=lambda x: Series(x).mean(numeric_only=numeric_only),
alt=lambda x: self._obj_1d_constructor(x).mean(
numeric_only=numeric_only
),
numeric_only=numeric_only,
)
return result.__finalize__(self.obj, method="groupby")
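
The idea behind the new subclass branch can be sketched with public API only. This is an illustration under assumptions, not the code path proposed in the diff: `ScaledFrame` and its unit-scaling `mean` are hypothetical, and the per-group loop stands in for `self.agg(f)`.

```python
import pandas as pd


class ScaledFrame(pd.DataFrame):
    """Hypothetical subclass whose mean() reports values in other units."""

    @property
    def _constructor(self):
        return ScaledFrame

    def mean(self, *args, **kwargs):
        # Delegate to the plain DataFrame mean, then rescale.
        return pd.DataFrame(self).mean(*args, **kwargs) * 100


sf = ScaledFrame({"a": [1.0, 2.0, 3.0, 4.0], "group": [1, 1, 2, 2]})

# Re-wrap each group in the subclass and call the subclass's own mean(),
# which is roughly what the fallback function in the diff does per group.
rows = {
    key: ScaledFrame(group[["a"]]).mean()
    for key, group in sf.groupby("group")
}
result = pd.DataFrame(rows).T
```

The trade-off is that every group is processed in Python, so the Cython and numba fast paths are bypassed whenever the grouped object is a subclass.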
@@ -1870,9 +1880,17 @@ def median(self, numeric_only: bool = False):
Series or DataFrame
Median of values within each group.
"""
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).median()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

result = self._cython_agg_general(
"median",
alt=lambda x: Series(x).median(numeric_only=numeric_only),
alt=lambda x: self._obj_1d_constructor(x).median(numeric_only=numeric_only),
numeric_only=numeric_only,
)
return result.__finalize__(self.obj, method="groupby")
@@ -1928,6 +1946,14 @@ def std(
Series or DataFrame
Standard deviation of values within each group.
"""
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).std()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

@@ -2011,14 +2037,22 @@ def var(
Series or DataFrame
Variance of values within each group.
"""
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).var()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
else:
return self._cython_agg_general(
"var",
alt=lambda x: Series(x).var(ddof=ddof),
alt=lambda x: self._obj_1d_constructor(x).var(ddof=ddof),
numeric_only=numeric_only,
ddof=ddof,
)
@@ -2180,6 +2214,15 @@ def sem(self, ddof: int = 1, numeric_only: bool = False):
Series or DataFrame
Standard error of the mean of values within each group.
"""
# TODO: think sem() needs considering more closely
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).sem()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if numeric_only and self.obj.ndim == 1 and not is_numeric_dtype(self.obj.dtype):
raise TypeError(
f"{type(self).__name__}.sem called with "
@@ -2238,6 +2281,14 @@ def sum(
engine: str | None = None,
engine_kwargs: dict[str, bool] | None = None,
):
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).sum()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_sum

@@ -2262,6 +2313,14 @@ def sum(
@final
@doc(_groupby_agg_method_template, fname="prod", no=False, mc=0)
def prod(self, numeric_only: bool = False, min_count: int = 0):
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).prod()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

return self._agg_general(
numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
)
@@ -2275,6 +2334,14 @@ def min(
engine: str | None = None,
engine_kwargs: dict[str, bool] | None = None,
):
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).min()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

@@ -2296,6 +2363,14 @@ def max(
engine: str | None = None,
engine_kwargs: dict[str, bool] | None = None,
):
if not (type(self.obj) == Series or type(self.obj) == DataFrame):

def f(df, *args, **kwargs):
return self.obj._constructor(df).max()

result = self.agg(f)
return result.__finalize__(self.obj, method="groupby")

if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

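
The same type check is repeated at the top of each reduction in this file, and the later commit messages mention a decorator with `@wraps`. One way such a decorator could look is sketched below. Everything here is an assumption for illustration: `dispatch_to_subclass`, `ToyGroupBy`, and `UnitFrame` are hypothetical, and `ToyGroupBy` is a stand-in, not the pandas groupby class.

```python
from functools import wraps

import pandas as pd


def dispatch_to_subclass(method):
    """Hypothetical decorator: use the subclass's own method when the
    grouped object is a subclass, otherwise run the wrapped fast path."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if type(self.obj) not in (pd.Series, pd.DataFrame):
            return self.fallback(method.__name__, *args, **kwargs)
        return method(self, *args, **kwargs)

    return wrapper


class ToyGroupBy:
    """Stand-in for a groupby object (not the pandas class)."""

    def __init__(self, obj, by):
        self.obj = obj
        self.by = by

    def fallback(self, name, *args, **kwargs):
        # Rebuild each group as the subclass and call its own method.
        # Results are kept in a plain dict for brevity.
        cls = type(self.obj)
        values = self.obj.drop(columns=self.by)
        return {
            key: getattr(cls(group), name)(*args, **kwargs)
            for key, group in values.groupby(self.obj[self.by])
        }

    @dispatch_to_subclass
    def mean(self):
        return self.obj.groupby(self.by).mean()


class UnitFrame(pd.DataFrame):
    """Hypothetical subclass with an overridden mean, as in the PR's test."""

    @property
    def _constructor(self):
        return UnitFrame

    def mean(self, *args, **kwargs):
        return 1


data = {"a": [1.0, 2.0, 3.0, 4.0], "group": [1, 1, 2, 2]}
plain_result = ToyGroupBy(pd.DataFrame(data), "group").mean()
subclass_result = ToyGroupBy(UnitFrame(data), "group").mean()
```

Using `@wraps` keeps the decorated method's name and docstring, which matters for methods whose docstrings are generated from templates.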
86 changes: 86 additions & 0 deletions pandas/tests/groupby/test_groupby_subclass.py
@@ -103,3 +103,89 @@ def test_groupby_resample_preserves_subclass(obj):
# Confirm groupby.resample() preserves dataframe type
result = df.groupby("Buyer").resample("5D").sum()
assert isinstance(result, obj)


def test_groupby_overridden_methods():
class UnitSeries(Series):
@property
def _constructor(self):
return UnitSeries

@property
def _constructor_expanddim(self):
return UnitDataFrame

def mean(self, *args, **kwargs):
return 1

def median(self, *args, **kwargs):
return 2

def std(self, *args, **kwargs):
return 3

def var(self, *args, **kwargs):
return 4

def sem(self, *args, **kwargs):
return 5

def prod(self, *args, **kwargs):
return 6

def min(self, *args, **kwargs):
return 7

def max(self, *args, **kwargs):
return 8

class UnitDataFrame(DataFrame):
@property
def _constructor(self):
return UnitDataFrame

@property
def _constructor_sliced(self):
return UnitSeries

def mean(self, *args, **kwargs):
print("UnitDataFrame mean")
return 1

def median(self, *args, **kwargs):
return 2

def std(self, *args, **kwargs):
return 3

def var(self, *args, **kwargs):
return 4

def sem(self, *args, **kwargs):
return 5

def prod(self, *args, **kwargs):
return 6

def min(self, *args, **kwargs):
return 7

def max(self, *args, **kwargs):
return 8

params = ["a", "b"]
data = np.random.rand(4, 2)
udf = UnitDataFrame(data, columns=params)
udf["group"] = np.ones(4, dtype=int)
udf.loc[2:, "group"] = 2

assert udf.mean() == 1
assert all(udf.groupby("group").mean() == 1)
assert all(udf.groupby("group").median() == 1)
assert all(udf.groupby("group").std() == 1)
assert all(udf.groupby("group").var() == 1)
assert all(udf.groupby("group").sem() == 1)
assert all(udf.groupby("group").prod() == 1)
assert all(udf.groupby("group").min() == 1)
assert all(udf.groupby("group").max() == 1)
# print(udf.groupby('group').beans()) # AttributeError
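
A note on the assertions above, given that a later commit message reads "change to np.all": the built-in `all()` applied to a DataFrame iterates over its column labels, not its values, so such an assertion can pass regardless of the data. A minimal sketch with made-up data:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
mask = df == 1  # only one of the four cells is True

# Iterating a DataFrame yields its column labels, so all() only tests
# the truthiness of the strings "a" and "b".
vacuous = all(mask)

# Reducing over the underlying values checks every cell.
strict = bool(np.all(mask.to_numpy()))
```

Here `vacuous` is true even though most cells fail the comparison, while `strict` reflects the actual values.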