-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG: DataFrameGroupBy.transform and ngroup do not work with cumcount #27858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7e461a1
1314059
8bcb313
1bcf325
5c96549
42fdb0b
9b1af14
d520bcc
c298eaf
e2c51aa
c66ec83
4dc07eb
9d60bbb
0378a74
2a8e1ed
9920344
9012e53
8d360a2
8a371c6
e54f024
5a33608
45c5339
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,7 +100,9 @@ def _gotitem(self, key, ndim, subset=None): | |
|
||
# cythonized transformations or canned "agg+broadcast", which do not | ||
# require postprocessing of the result by transform. | ||
cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"]) | ||
cythonized_kernels = frozenset( | ||
["cumprod", "cumsum", "shift", "cummin", "cummax", "cumcount"] | ||
) | ||
|
||
cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"]) | ||
|
||
|
@@ -120,7 +122,6 @@ def _gotitem(self, key, ndim, subset=None): | |
"mean", | ||
"median", | ||
"min", | ||
"ngroup", | ||
"nth", | ||
"nunique", | ||
"prod", | ||
|
@@ -158,6 +159,7 @@ def _gotitem(self, key, ndim, subset=None): | |
"rank", | ||
"shift", | ||
"tshift", | ||
"ngroup", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you put it in the alphabetical order |
||
] | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -583,7 +583,9 @@ def transform(self, func, *args, **kwargs): | |
if not (func in base.transform_kernel_whitelist): | ||
msg = "'{func}' is not a valid function name for transform(name)" | ||
raise ValueError(msg.format(func=func)) | ||
if func in base.cythonized_kernels: | ||
|
||
# transformation are added as well since they are broadcasted already | ||
if func in base.cythonized_kernels or func in base.transformation_kernels: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of this can't you just add cumcount to the transformation list? This somewhat blurs the line between |
||
# cythonized transformation or canned "reduction+broadcast" | ||
TomAugspurger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return getattr(self, func)(*args, **kwargs) | ||
else: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,11 @@ | |
) | ||
from pandas.core.groupby.groupby import DataError | ||
from pandas.util import testing as tm | ||
from pandas.util.testing import assert_frame_equal, assert_series_equal | ||
from pandas.util.testing import ( | ||
assert_frame_equal, | ||
assert_index_equal, | ||
assert_series_equal, | ||
) | ||
|
||
|
||
def assert_fp_equal(a, b): | ||
|
@@ -1034,8 +1038,6 @@ def test_transform_agg_by_name(reduction_func, obj): | |
func = reduction_func | ||
g = obj.groupby(np.repeat([0, 1], 3)) | ||
|
||
if func == "ngroup": # GH#27468 | ||
pytest.xfail("TODO: g.transform('ngroup') doesn't work") | ||
if func == "size": # GH#27469 | ||
pytest.xfail("TODO: g.transform('size') doesn't work") | ||
|
||
|
@@ -1074,3 +1076,58 @@ def test_transform_lambda_with_datetimetz(): | |
name="time", | ||
) | ||
assert_series_equal(result, expected) | ||
|
||
|
||
def test_transform_cumcount_ngroup(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this not already covered by test below? Ideally we can rely on fixtures rather than one-off tests like this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's a bit different than below... the test below is just to ensure after But you are right, this should not be in one-off tests... will look for if there is fixture for this already @WillAyd |
||
df = DataFrame(dict(a=[0, 0, 0, 1, 1, 1], b=range(6))) | ||
g = df.groupby(np.repeat([0, 1], 3)) | ||
|
||
# GH 27472 | ||
result = g.transform("cumcount") | ||
expected = g.cumcount() | ||
assert_series_equal(result, expected) | ||
|
||
# GH 27468 | ||
result = g.transform("ngroup") | ||
expected = g.ngroup() | ||
assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"func", | ||
[ | ||
"backfill", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there not already a fixture for these we can leverage? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, will change! thanks for review! |
||
"bfill", | ||
"cumcount", | ||
"cummax", | ||
"cummin", | ||
"cumprod", | ||
"cumsum", | ||
"diff", | ||
"ffill", | ||
"pad", | ||
"pct_change", | ||
"rank", | ||
"shift", | ||
"ngroup", | ||
pytest.param( | ||
"fillna", | ||
marks=pytest.mark.xfail(reason="GH27905: 'fillna' get empty DataFrame now"), | ||
), | ||
pytest.param( | ||
"tshift", marks=pytest.mark.xfail(reason="GH27905: Should apply to ts data") | ||
), | ||
pytest.param( | ||
"corrwith", | ||
marks=pytest.mark.xfail(reason="GH27905: Inapplicable to the data"), | ||
), | ||
], | ||
) | ||
def test_transformation_kernels_length(func): | ||
# This test is to evaluate if after transformation, the index | ||
# of transformed data is still the same with original DataFrame | ||
df = DataFrame(dict(a=[0, 0, 0, 1, 1, 1], b=range(6))) | ||
g = df.groupby(np.repeat([0, 1], 3)) | ||
|
||
result = g.transform(func) | ||
assert_index_equal(result.index, df.index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leave this one