Skip to content

Groupby transform cleanups #27467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 47 commits into from Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
071c9fa
Rename _is_cython_func to _get_cython_func
pilkibun Jul 19, 2019
0118cc8
typo
pilkibun Jul 19, 2019
2f29e4e
CLN: whitelist func in transform(func)
pilkibun Jul 19, 2019
45e9f70
comments
pilkibun Jul 19, 2019
0a1a0fb
tests
pilkibun Jul 19, 2019
b845034
black
pilkibun Jul 19, 2019
f45acd5
Deterministic test order for pytest collection
pilkibun Jul 19, 2019
514528b
use format()
pilkibun Jul 19, 2019
d662c9b
delete print func, now parameterized
pilkibun Jul 19, 2019
3c56a62
Make ita fixture
pilkibun Jul 19, 2019
52d4a61
typing
pilkibun Jul 19, 2019
a6a681a
black
pilkibun Jul 19, 2019
8b572ec
Merge remote-tracking branch 'origin/master' into groupby_transform_c…
pilkibun Jul 19, 2019
4e3ba63
Link issue
pilkibun Jul 19, 2019
0ad156b
sort names
pilkibun Jul 21, 2019
b1cb4b0
docstring
pilkibun Jul 21, 2019
d445de4
rename to use 'kernel'
pilkibun Jul 21, 2019
de2c59c
corr doesn't belong on the list
pilkibun Jul 21, 2019
bd73cff
comment
pilkibun Jul 21, 2019
f979ef8
Merge remote-tracking branch 'origin/master' into groupby_transform_c…
pilkibun Jul 21, 2019
7d7eecc
whatsnew
pilkibun Jul 21, 2019
d343686
Fix name in conftest.py
pilkibun Jul 21, 2019
6d0301a
test SeriesGroupby
pilkibun Jul 21, 2019
b6e924e
black
pilkibun Jul 21, 2019
399804c
remove cov
pilkibun Jul 22, 2019
1d158de
bfill is a reduction
pilkibun Jul 22, 2019
04dadd9
bfill is not a reduction I mean
pilkibun Jul 22, 2019
146e402
sort
pilkibun Jul 22, 2019
932d224
comment
pilkibun Jul 22, 2019
897d2fb
comment
pilkibun Jul 22, 2019
6a80828
comment
pilkibun Jul 22, 2019
2555193
remove stale test cases
pilkibun Jul 22, 2019
3d7bc64
remove duplicate test
pilkibun Jul 22, 2019
2b61329
Merge remote-tracking branch 'origin/master' into groupby_transform_c…
pilkibun Jul 22, 2019
053320b
CI
pilkibun Jul 23, 2019
994ab5f
Merge remote-tracking branch 'origin/master' into groupby_transform_c…
pilkibun Jul 23, 2019
6b90418
CI
pilkibun Jul 23, 2019
c932afd
TST: make sure every public method on Grouper is accounted for
pilkibun Jul 24, 2019
5f10720
comment
pilkibun Jul 24, 2019
4b186a4
comment
pilkibun Jul 24, 2019
de0c11c
comment
pilkibun Jul 24, 2019
cba02d6
fix review
pilkibun Jul 24, 2019
b19a46f
message
pilkibun Jul 24, 2019
3511480
Update test
pilkibun Jul 24, 2019
b5d3931
changes
pilkibun Jul 25, 2019
c97e2f2
Merge remote-tracking branch 'origin/master' into groupby_transform_c…
pilkibun Jul 25, 2019
e9343e1
reference issue
pilkibun Jul 25, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def is_any_frame():
else:
result = None

f = self._is_cython_func(arg)
f = self._get_cython_func(arg)
if f and not args and not kwargs:
return getattr(self, f)(), None

Expand Down Expand Up @@ -652,7 +652,7 @@ def _shallow_copy(self, obj=None, obj_type=None, **kwargs):
kwargs[attr] = getattr(self, attr)
return obj_type(obj, **kwargs)

def _is_cython_func(self, arg):
def _get_cython_func(self, arg):
"""
if we define an internal function for this argument, return it
"""
Expand Down
65 changes: 64 additions & 1 deletion pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,69 @@ def _gotitem(self, key, ndim, subset=None):

dataframe_apply_whitelist = common_apply_whitelist | frozenset(["dtypes", "corrwith"])

cython_transforms = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
# cythonized transformations or canned "agg+broadcast", which do not
# require postprocessing of the result by transform.
cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])

cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

# List of aggregation/reduction functions.
# These map each series/column to a single value
reduction_functions = frozenset(
[
"sum",
"first",
"mean",
"all",
"any",
"bfill",
"corr",
"count",
"cov",
"idxmax",
"idxmin",
"last",
"mad",
"max",
"median",
"min",
"ngroup",
"nth",
"nunique",
"prod",
"quantile",
"sem",
"size",
"skew",
"std",
"var",
]
)

# List of transformation functions.
# These map each object to a like-indexed result object
transformation_functions = frozenset(
[
"backfill",
"corrwith",
"cumcount",
"cummax",
"cummin",
"cumprod",
"cumsum",
"diff",
"ffill",
"fillna",
"rank",
"pad",
"pct_change",
"shift",
"tshift",
]
)

# Valid values of `name` for `groupby.transform(name)`
transform_recognized_functions = reduction_functions | transformation_functions
transform_recognized_functions -= {
"corr"
} # returns multindex, exclude from transform(name) for now.
32 changes: 21 additions & 11 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,19 @@ def _transform_general(self, func, *args, **kwargs):
def transform(self, func, *args, **kwargs):

# optimized transforms
func = self._is_cython_func(func) or func
func = self._get_cython_func(func) or func

if isinstance(func, str):
if func in base.cython_transforms:
# cythonized transform
if not (func in base.transform_recognized_functions):
msg = "'%s' is not a valid function name for transform(name)"
raise ValueError(msg % func)
if func in base.cythonized_kernels:
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)
else:
# cythonized aggregation and merge
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)
else:
return self._transform_general(func, *args, **kwargs)
Expand All @@ -589,7 +595,7 @@ def transform(self, func, *args, **kwargs):

obj = self._obj_with_exclusions

# nuiscance columns
# nuisance columns
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)

Expand Down Expand Up @@ -852,7 +858,7 @@ def aggregate(self, func_or_funcs=None, *args, **kwargs):
if relabeling:
ret.columns = columns
else:
cyfunc = self._is_cython_func(func_or_funcs)
cyfunc = self._get_cython_func(func_or_funcs)
if cyfunc and not args and not kwargs:
return getattr(self, cyfunc)()

Expand Down Expand Up @@ -1004,15 +1010,19 @@ def _aggregate_named(self, func, *args, **kwargs):
@Substitution(klass="Series", selected="A.")
@Appender(_transform_template)
def transform(self, func, *args, **kwargs):
func = self._is_cython_func(func) or func
func = self._get_cython_func(func) or func

# if string function
if isinstance(func, str):
if func in base.cython_transforms:
# cythonized transform
if not (func in base.transform_recognized_functions):
msg = "'%s' is not a valid function name for transform(name)"
raise ValueError(msg % func)
if func in base.cythonized_kernels:
# cythonized transform or canned "agg+broadcast"
return getattr(self, func)(*args, **kwargs)
else:
# cythonized aggregation and merge
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
return self._transform_fast(
lambda: getattr(self, func)(*args, **kwargs), func
)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ def _downsample(self, how, **kwargs):
**kwargs : kw args passed to how function
"""
self._set_binner()
how = self._is_cython_func(how) or how
how = self._get_cython_func(how) or how
ax = self.ax
obj = self._selected_obj

Expand Down Expand Up @@ -1197,7 +1197,7 @@ def _downsample(self, how, **kwargs):
if self.kind == "timestamp":
return super()._downsample(how, **kwargs)

how = self._is_cython_func(how) or how
how = self._get_cython_func(how) or how
ax = self.ax

if is_subperiod(ax.freq, self.freq):
Expand Down
45 changes: 45 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
concat,
date_range,
)
from pandas.core.groupby.base import reduction_functions
from pandas.core.groupby.groupby import DataError
from pandas.util import testing as tm
from pandas.util.testing import assert_frame_equal, assert_series_equal
Expand Down Expand Up @@ -1001,3 +1002,47 @@ def test_ffill_not_in_axis(func, key, val):
expected = df

assert_frame_equal(result, expected)


def test_transform_invalid_name_raises():
df = DataFrame(dict(a=[0, 1, 1, 2]))
g = df.groupby(["a", "b", "b", "c"])
with pytest.raises(ValueError, match="not a valid function name"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you parametrize these instead?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's needed. Parameterize with one case is not useful, and as a smoke test there doesn't need to be more than one. Also lots of similar tests in this file.

g.transform("some_arbitrary_name")

# method exists on the object, but is not a valid transformation/agg
assert hasattr(g, "aggregate") # make sure the method exists
with pytest.raises(ValueError, match="not a valid function name"):
g.transform("aggregate")

# Test SeriesGroupBy
g = df["a"].groupby(["a", "b", "b", "c"])
with pytest.raises(ValueError, match="not a valid function name"):
g.transform("some_arbitrary_name")

# method exists on the object, but is not a valid transformation/agg
with pytest.raises(ValueError, match="not a valid function name"):
g.transform("aggregate")


@pytest.mark.parametrize("func", sorted(reduction_functions))
def test_transform_agg_by_name(func):

df = DataFrame(dict(a=[0, 0, 0, 1, 1, 1], b=range(6)))
g = df.groupby(np.repeat([0, 1], 3))

if func == "ngroup": # GH#27468
pytest.xfail("TODO: g.transform('ngroup') doesn't work")
if func == "size": # GH#27469
pytest.xfail("TODO: g.transform('size') doesn't work")
if func == "corr":
pytest.xfail("corr returns multiindex, excluded from transform for now")

args = {"nth": [0], "quantile": [0.5]}.get(func, [])

print(func)
result = g.transform(func, *args)
tm.assert_index_equal(result.index, df.index)

# check values replicated broadcasted across group
assert len(set(result.iloc[-3:, 1])) == 1