diff --git a/doc/source/whatsnew/v0.25.1.rst b/doc/source/whatsnew/v0.25.1.rst index 34b149a6b8261..6bf9a3b705527 100644 --- a/doc/source/whatsnew/v0.25.1.rst +++ b/doc/source/whatsnew/v0.25.1.rst @@ -120,7 +120,6 @@ Groupby/resample/rolling - Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` where applying a timezone conversion lambda function would drop timezone information (:issue:`27496`) - Bug in windowing over read-only arrays (:issue:`27766`) - Fixed segfault in `pandas.core.groupby.DataFrameGroupBy.quantile` when an invalid quantile was passed (:issue:`27470`) -- Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index fc3bb69afd0cb..889fb2fbb75ae 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -100,7 +100,9 @@ def _gotitem(self, key, ndim, subset=None): # cythonized transformations or canned "agg+broadcast", which do not # require postprocessing of the result by transform. -cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"]) +cythonized_kernels = frozenset( + ["cumprod", "cumsum", "shift", "cummin", "cummax", "cumcount"] +) cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"]) @@ -120,7 +122,6 @@ def _gotitem(self, key, ndim, subset=None): "mean", "median", "min", - "ngroup", "nth", "nunique", "prod", @@ -158,6 +159,7 @@ def _gotitem(self, key, ndim, subset=None): "rank", "shift", "tshift", + "ngroup", ] ) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index ea2bd22cccc3d..9f350411084f3 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -583,7 +583,9 @@ def transform(self, func, *args, **kwargs): if not (func in base.transform_kernel_whitelist): msg = "'{func}' is not a valid function name for transform(name)" raise ValueError(msg.format(func=func)) - if func in base.cythonized_kernels: + + # transformation are added as well since they are broadcasted already + if func in base.cythonized_kernels or func in base.transformation_kernels: # cythonized transformation or canned "reduction+broadcast" return getattr(self, func)(*args, **kwargs) else: diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index d3972e6ba9008..63104e978839f 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -20,7 +20,11 @@ ) from pandas.core.groupby.groupby import DataError from pandas.util import testing as tm -from pandas.util.testing import assert_frame_equal, assert_series_equal +from pandas.util.testing import ( + assert_frame_equal, + assert_index_equal, + assert_series_equal, +) def assert_fp_equal(a, b): @@ -1034,8 +1038,6 @@ def test_transform_agg_by_name(reduction_func, obj): func = reduction_func g = obj.groupby(np.repeat([0, 1], 3)) - if func == "ngroup": # GH#27468 - pytest.xfail("TODO: g.transform('ngroup') doesn't work") if func == "size": # GH#27469 pytest.xfail("TODO: g.transform('size') doesn't work") @@ -1074,3 +1076,58 @@ def test_transform_lambda_with_datetimetz(): name="time", ) assert_series_equal(result, expected) + + +def test_transform_cumcount_ngroup(): + df = DataFrame(dict(a=[0, 0, 0, 1, 1, 1], b=range(6))) + g = df.groupby(np.repeat([0, 1], 3)) + + # GH 27472 + result = g.transform("cumcount") + expected = g.cumcount() + assert_series_equal(result, expected) + + # GH 27468 + result = g.transform("ngroup") + expected = g.ngroup() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func", + [ + "backfill", + "bfill", + "cumcount", + "cummax", + "cummin", + "cumprod", + "cumsum", + "diff", + "ffill", + "pad", + "pct_change", + "rank", + "shift", + "ngroup", + pytest.param( + "fillna", + marks=pytest.mark.xfail(reason="GH27905: 'fillna' get empty DataFrame now"), + ), + pytest.param( + "tshift", marks=pytest.mark.xfail(reason="GH27905: Should apply to ts data") + ), + pytest.param( + "corrwith", + marks=pytest.mark.xfail(reason="GH27905: Inapplicable to the data"), + ), + ], +) +def test_transformation_kernels_length(func): + # This test is to evaluate if after transformation, the index + # of transformed data is still the same with original DataFrame + df = DataFrame(dict(a=[0, 0, 0, 1, 1, 1], b=range(6))) + g = df.groupby(np.repeat([0, 1], 3)) + + result = g.transform(func) + assert_index_equal(result.index, df.index)