Skip to content

Commit eeb264c

Browse files
pilkibun, jreback
pilkibun
authored and committed
Groupby transform cleanups (#27467)
1 parent 5b1a870 commit eeb264c

File tree

9 files changed

+234
-18
lines changed

9 files changed

+234
-18
lines changed

doc/source/whatsnew/v1.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Backwards incompatible API changes
3939

4040
.. _whatsnew_1000.api.other:
4141

42-
-
42+
- :meth:`pandas.core.groupby.GroupBy.transform` now raises on invalid operation names (:issue:`27489`).
4343
-
4444

4545
Other API changes

pandas/core/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import builtins
55
from collections import OrderedDict
66
import textwrap
7+
from typing import Optional
78
import warnings
89

910
import numpy as np
@@ -566,7 +567,7 @@ def is_any_frame():
566567
else:
567568
result = None
568569

569-
f = self._is_cython_func(arg)
570+
f = self._get_cython_func(arg)
570571
if f and not args and not kwargs:
571572
return getattr(self, f)(), None
572573

@@ -653,7 +654,7 @@ def _shallow_copy(self, obj=None, obj_type=None, **kwargs):
653654
kwargs[attr] = getattr(self, attr)
654655
return obj_type(obj, **kwargs)
655656

656-
def _is_cython_func(self, arg):
657+
def _get_cython_func(self, arg: str) -> Optional[str]:
657658
"""
658659
if we define an internal function for this argument, return it
659660
"""

pandas/core/groupby/base.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,103 @@ def _gotitem(self, key, ndim, subset=None):
9898

9999
dataframe_apply_whitelist = common_apply_whitelist | frozenset(["dtypes", "corrwith"])
100100

101-
cython_transforms = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
101+
# cythonized transformations or canned "agg+broadcast", which do not
102+
# require postprocessing of the result by transform.
103+
cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
102104

103105
cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
106+
107+
# List of aggregation/reduction functions.
108+
# These map each group to a single numeric value
109+
reduction_kernels = frozenset(
110+
[
111+
"all",
112+
"any",
113+
"count",
114+
"first",
115+
"idxmax",
116+
"idxmin",
117+
"last",
118+
"mad",
119+
"max",
120+
"mean",
121+
"median",
122+
"min",
123+
"ngroup",
124+
"nth",
125+
"nunique",
126+
"prod",
127+
# as long as `quantile`'s signature accepts only
128+
# a single quantile value, it's a reduction.
129+
# GH#27526 might change that.
130+
"quantile",
131+
"sem",
132+
"size",
133+
"skew",
134+
"std",
135+
"sum",
136+
"var",
137+
]
138+
)
139+
140+
# List of transformation functions.
141+
# a transformation is a function that, for each group,
142+
# produces a result that has the same shape as the group.
143+
transformation_kernels = frozenset(
144+
[
145+
"backfill",
146+
"bfill",
147+
"corrwith",
148+
"cumcount",
149+
"cummax",
150+
"cummin",
151+
"cumprod",
152+
"cumsum",
153+
"diff",
154+
"ffill",
155+
"fillna",
156+
"pad",
157+
"pct_change",
158+
"rank",
159+
"shift",
160+
"tshift",
161+
]
162+
)
163+
164+
# these are all the public methods on Grouper which don't belong
165+
# in either of the above lists
166+
groupby_other_methods = frozenset(
167+
[
168+
"agg",
169+
"aggregate",
170+
"apply",
171+
"boxplot",
172+
# corr and cov return ngroups*ncolumns rows, so they
173+
# are neither a transformation nor a reduction
174+
"corr",
175+
"cov",
176+
"describe",
177+
"dtypes",
178+
"expanding",
179+
"filter",
180+
"get_group",
181+
"groups",
182+
"head",
183+
"hist",
184+
"indices",
185+
"ndim",
186+
"ngroups",
187+
"ohlc",
188+
"pipe",
189+
"plot",
190+
"resample",
191+
"rolling",
192+
"tail",
193+
"take",
194+
"transform",
195+
]
196+
)
197+
# Valid values of `name` for `groupby.transform(name)`
198+
# NOTE: do NOT edit this directly. New additions should be inserted
199+
# into the appropriate list above.
200+
transform_kernel_whitelist = reduction_kernels | transformation_kernels

pandas/core/groupby/generic.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -573,13 +573,19 @@ def _transform_general(self, func, *args, **kwargs):
573573
def transform(self, func, *args, **kwargs):
574574

575575
# optimized transforms
576-
func = self._is_cython_func(func) or func
576+
func = self._get_cython_func(func) or func
577+
577578
if isinstance(func, str):
578-
if func in base.cython_transforms:
579-
# cythonized transform
579+
if not (func in base.transform_kernel_whitelist):
580+
msg = "'{func}' is not a valid function name for transform(name)"
581+
raise ValueError(msg.format(func=func))
582+
if func in base.cythonized_kernels:
583+
# cythonized transformation or canned "reduction+broadcast"
580584
return getattr(self, func)(*args, **kwargs)
581585
else:
582-
# cythonized aggregation and merge
586+
# If func is a reduction, we need to broadcast the
587+
# result to the whole group. Compute func result
588+
# and deal with possible broadcasting below.
583589
result = getattr(self, func)(*args, **kwargs)
584590
else:
585591
return self._transform_general(func, *args, **kwargs)
@@ -590,7 +596,7 @@ def transform(self, func, *args, **kwargs):
590596

591597
obj = self._obj_with_exclusions
592598

593-
# nuiscance columns
599+
# nuisance columns
594600
if not result.columns.equals(obj.columns):
595601
return self._transform_general(func, *args, **kwargs)
596602

@@ -853,7 +859,7 @@ def aggregate(self, func_or_funcs=None, *args, **kwargs):
853859
if relabeling:
854860
ret.columns = columns
855861
else:
856-
cyfunc = self._is_cython_func(func_or_funcs)
862+
cyfunc = self._get_cython_func(func_or_funcs)
857863
if cyfunc and not args and not kwargs:
858864
return getattr(self, cyfunc)()
859865

@@ -1005,15 +1011,19 @@ def _aggregate_named(self, func, *args, **kwargs):
10051011
@Substitution(klass="Series", selected="A.")
10061012
@Appender(_transform_template)
10071013
def transform(self, func, *args, **kwargs):
1008-
func = self._is_cython_func(func) or func
1014+
func = self._get_cython_func(func) or func
10091015

1010-
# if string function
10111016
if isinstance(func, str):
1012-
if func in base.cython_transforms:
1013-
# cythonized transform
1017+
if not (func in base.transform_kernel_whitelist):
1018+
msg = "'{func}' is not a valid function name for transform(name)"
1019+
raise ValueError(msg.format(func=func))
1020+
if func in base.cythonized_kernels:
1021+
# cythonized transform or canned "agg+broadcast"
10141022
return getattr(self, func)(*args, **kwargs)
10151023
else:
1016-
# cythonized aggregation and merge
1024+
# If func is a reduction, we need to broadcast the
1025+
# result to the whole group. Compute func result
1026+
# and deal with possible broadcasting below.
10171027
return self._transform_fast(
10181028
lambda: getattr(self, func)(*args, **kwargs), func
10191029
)

pandas/core/groupby/groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class providing the base-class of operations.
261261
262262
* f must return a value that either has the same shape as the input
263263
subframe or can be broadcast to the shape of the input subframe.
264-
For example, f returns a scalar it will be broadcast to have the
264+
For example, if `f` returns a scalar it will be broadcast to have the
265265
same shape as the input subframe.
266266
* if this is a DataFrame, f must support application column-by-column
267267
in the subframe. If f also supports application to the entire subframe,

pandas/core/resample.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ def _downsample(self, how, **kwargs):
10461046
**kwargs : kw args passed to how function
10471047
"""
10481048
self._set_binner()
1049-
how = self._is_cython_func(how) or how
1049+
how = self._get_cython_func(how) or how
10501050
ax = self.ax
10511051
obj = self._selected_obj
10521052

@@ -1194,7 +1194,7 @@ def _downsample(self, how, **kwargs):
11941194
if self.kind == "timestamp":
11951195
return super()._downsample(how, **kwargs)
11961196

1197-
how = self._is_cython_func(how) or how
1197+
how = self._get_cython_func(how) or how
11981198
ax = self.ax
11991199

12001200
if is_subperiod(ax.freq, self.freq):

pandas/tests/groupby/conftest.py

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from pandas import DataFrame, MultiIndex
5+
from pandas.core.groupby.base import reduction_kernels
56
from pandas.util import testing as tm
67

78

@@ -102,3 +103,10 @@ def three_group():
102103
"F": np.random.randn(11),
103104
}
104105
)
106+
107+
108+
@pytest.fixture(params=sorted(reduction_kernels))
109+
def reduction_func(request):
110+
"""yields the string names of all groupby reduction functions, one at a time.
111+
"""
112+
return request.param

pandas/tests/groupby/test_transform.py

+49
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,55 @@ def test_ffill_not_in_axis(func, key, val):
10031003
assert_frame_equal(result, expected)
10041004

10051005

1006+
def test_transform_invalid_name_raises():
1007+
# GH#27486
1008+
df = DataFrame(dict(a=[0, 1, 1, 2]))
1009+
g = df.groupby(["a", "b", "b", "c"])
1010+
with pytest.raises(ValueError, match="not a valid function name"):
1011+
g.transform("some_arbitrary_name")
1012+
1013+
# method exists on the object, but is not a valid transformation/agg
1014+
assert hasattr(g, "aggregate") # make sure the method exists
1015+
with pytest.raises(ValueError, match="not a valid function name"):
1016+
g.transform("aggregate")
1017+
1018+
# Test SeriesGroupBy
1019+
g = df["a"].groupby(["a", "b", "b", "c"])
1020+
with pytest.raises(ValueError, match="not a valid function name"):
1021+
g.transform("some_arbitrary_name")
1022+
1023+
1024+
@pytest.mark.parametrize(
1025+
"obj",
1026+
[
1027+
DataFrame(
1028+
dict(a=[0, 0, 0, 1, 1, 1], b=range(6)), index=["A", "B", "C", "D", "E", "F"]
1029+
),
1030+
Series([0, 0, 0, 1, 1, 1], index=["A", "B", "C", "D", "E", "F"]),
1031+
],
1032+
)
1033+
def test_transform_agg_by_name(reduction_func, obj):
1034+
func = reduction_func
1035+
g = obj.groupby(np.repeat([0, 1], 3))
1036+
1037+
if func == "ngroup": # GH#27468
1038+
pytest.xfail("TODO: g.transform('ngroup') doesn't work")
1039+
if func == "size": # GH#27469
1040+
pytest.xfail("TODO: g.transform('size') doesn't work")
1041+
1042+
args = {"nth": [0], "quantile": [0.5]}.get(func, [])
1043+
1044+
result = g.transform(func, *args)
1045+
1046+
# this is the *definition* of a transformation
1047+
tm.assert_index_equal(result.index, obj.index)
1048+
if hasattr(obj, "columns"):
1049+
tm.assert_index_equal(result.columns, obj.columns)
1050+
1051+
# verify that values were broadcasted across each group
1052+
assert len(set(DataFrame(result).iloc[-3:, -1])) == 1
1053+
1054+
10061055
def test_transform_lambda_with_datetimetz():
10071056
# GH 27496
10081057
df = DataFrame(

pandas/tests/groupby/test_whitelist.py

+51
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
import pytest
1010

1111
from pandas import DataFrame, Index, MultiIndex, Series, date_range
12+
from pandas.core.groupby.base import (
13+
groupby_other_methods,
14+
reduction_kernels,
15+
transformation_kernels,
16+
)
1217
from pandas.util import testing as tm
1318

1419
AGG_FUNCTIONS = [
@@ -376,3 +381,49 @@ def test_groupby_selection_with_methods(df):
376381
tm.assert_frame_equal(
377382
g.filter(lambda x: len(x) == 3), g_exp.filter(lambda x: len(x) == 3)
378383
)
384+
385+
386+
def test_all_methods_categorized(mframe):
387+
grp = mframe.groupby(mframe.iloc[:, 0])
388+
names = {_ for _ in dir(grp) if not _.startswith("_")} - set(mframe.columns)
389+
new_names = set(names)
390+
new_names -= reduction_kernels
391+
new_names -= transformation_kernels
392+
new_names -= groupby_other_methods
393+
394+
assert not (reduction_kernels & transformation_kernels)
395+
assert not (reduction_kernels & groupby_other_methods)
396+
assert not (transformation_kernels & groupby_other_methods)
397+
398+
# new public method?
399+
if new_names:
400+
msg = """
401+
There are uncategorized methods defined on the Grouper class:
402+
{names}.
403+
404+
Was a new method recently added?
405+
406+
Every public method on Grouper must appear in exactly one of the
407+
following three lists defined in pandas.core.groupby.base:
408+
- `reduction_kernels`
409+
- `transformation_kernels`
410+
- `groupby_other_methods`
411+
see the comments in pandas/core/groupby/base.py for guidance on
412+
how to fix this test.
413+
"""
414+
raise AssertionError(msg.format(names=names))
415+
416+
# removed a public method?
417+
all_categorized = reduction_kernels | transformation_kernels | groupby_other_methods
418+
print(names)
419+
print(all_categorized)
420+
if not (names == all_categorized):
421+
msg = """
422+
Some methods which are supposed to be on the Grouper class
423+
are missing:
424+
{names}.
425+
426+
They're still defined in one of the lists that live in pandas/core/groupby/base.py.
427+
If you removed a method, you should update them
428+
"""
429+
raise AssertionError(msg.format(names=all_categorized - names))

0 commit comments

Comments
 (0)