diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index baacc8c421414..e346870969741 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -103,6 +103,7 @@ Other enhancements - :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`) - :meth:`DataFrameGroupby.agg` and :meth:`DataFrameGroupby.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`) - :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`) +- :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`) - Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`) - Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`). - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index eef3de3e61f29..2b1ff05f18d5e 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -523,10 +523,16 @@ def _cython_transform( return obj._constructor(result, index=self.obj.index, name=obj.name) - def _transform_general(self, func: Callable, *args, **kwargs) -> Series: + def _transform_general( + self, func: Callable, engine, engine_kwargs, *args, **kwargs + ) -> Series: """ Transform with a callable `func`. """ + if maybe_use_numba(engine): + return self._transform_with_numba( + func, *args, engine_kwargs=engine_kwargs, **kwargs + ) assert callable(func) klass = type(self.obj) @@ -1654,7 +1660,11 @@ def arr_func(bvalues: ArrayLike) -> ArrayLike: res_df = self._maybe_transpose_result(res_df) return res_df - def _transform_general(self, func, *args, **kwargs): + def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs): + if maybe_use_numba(engine): + return self._transform_with_numba( + func, *args, engine_kwargs=engine_kwargs, **kwargs + ) from pandas.core.reshape.concat import concat applied = [] diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index b15b5b11c3d5e..e447377db9e55 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1787,22 +1787,20 @@ def _cython_transform( @final def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): - if maybe_use_numba(engine): - return self._transform_with_numba( - func, *args, engine_kwargs=engine_kwargs, **kwargs - ) - # optimized transforms func = com.get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general(func, engine, engine_kwargs, *args, **kwargs) elif func not in base.transform_kernel_allowlist: msg = f"'{func}' is not a valid function name for transform(name)" raise ValueError(msg) elif func in base.cythonized_kernels or func in base.transformation_kernels: # cythonized transform or canned "agg+broadcast" + if engine is not None: + kwargs["engine"] = engine + kwargs["engine_kwargs"] = engine_kwargs return getattr(self, func)(*args, **kwargs) else: @@ -1817,6 +1815,9 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): with com.temp_setattr(self, "as_index", True): # GH#49834 - result needs groups in the index for # _wrap_transform_fast_result + if engine is not None: + kwargs["engine"] = engine + kwargs["engine_kwargs"] = engine_kwargs result = getattr(self, func)(*args, **kwargs) return self._wrap_transform_fast_result(result) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 00ff391199652..8da6b23f5ac57 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -130,20 +130,25 @@ def func_1(values, index): tm.assert_frame_equal(expected, result) +# TODO: Test more than just reductions (e.g. actually test transformations once we have @td.skip_if_no("numba") @pytest.mark.parametrize( "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}] ) -def test_multifunc_notimplimented(agg_func): +def test_string_cython_vs_numba(agg_func, numba_supported_reductions): + agg_func, kwargs = numba_supported_reductions data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] ) grouped = data.groupby(0) - with pytest.raises(NotImplementedError, match="Numba engine can"): - grouped.transform(agg_func, engine="numba") - with pytest.raises(NotImplementedError, match="Numba engine can"): - grouped[1].transform(agg_func, engine="numba") + result = grouped.transform(agg_func, engine="numba", **kwargs) + expected = grouped.transform(agg_func, engine="cython", **kwargs) + tm.assert_frame_equal(result, expected) + + result = grouped[1].transform(agg_func, engine="numba", **kwargs) + expected = grouped[1].transform(agg_func, engine="cython", **kwargs) + tm.assert_series_equal(result, expected) @td.skip_if_no("numba") @@ -232,9 +237,6 @@ def numba_func(values, index): @td.skip_if_no("numba") -@pytest.mark.xfail( - reason="Groupby transform doesn't support strings as function inputs yet with numba" -) def test_multilabel_numba_vs_cython(numba_supported_reductions): reduction, kwargs = numba_supported_reductions df = DataFrame(