From 45a2721bb2a5ada60b4c701e21084372ab0be189 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Fri, 9 Jun 2023 14:00:04 -0700 Subject: [PATCH 1/2] ENH: Groupby.transform support string input with engine=numba --- doc/source/whatsnew/v2.1.0.rst | 1 + pandas/core/groupby/generic.py | 14 ++++++++++++-- pandas/core/groupby/groupby.py | 13 +++++++------ pandas/tests/groupby/transform/test_numba.py | 15 ++++++++++----- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index b9ad494172bdf..4dffb085e4a89 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -100,6 +100,7 @@ Other enhancements - :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`) - :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`) - :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`) +- :meth:`SeriesGroupBy.transform` and :meth:`DataFrameGroupBy.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`) - Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`) - Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`). 
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 1d5fb0fb3873d..2359f34ec4018 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -523,10 +523,16 @@ def _cython_transform( return obj._constructor(result, index=self.obj.index, name=obj.name) - def _transform_general(self, func: Callable, *args, **kwargs) -> Series: + def _transform_general( + self, func: Callable, engine, engine_kwargs, *args, **kwargs + ) -> Series: """ Transform with a callable `func`. """ + if maybe_use_numba(engine): + return self._transform_with_numba( + func, *args, engine_kwargs=engine_kwargs, **kwargs + ) assert callable(func) klass = type(self.obj) @@ -1677,7 +1683,11 @@ def arr_func(bvalues: ArrayLike) -> ArrayLike: res_df = self._maybe_transpose_result(res_df) return res_df - def _transform_general(self, func, *args, **kwargs): + def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs): + if maybe_use_numba(engine): + return self._transform_with_numba( + func, *args, engine_kwargs=engine_kwargs, **kwargs + ) from pandas.core.reshape.concat import concat applied = [] diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 5d15be19f34f7..9501758bbbc8d 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1782,22 +1782,20 @@ def _cython_transform( @final def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): - if maybe_use_numba(engine): - return self._transform_with_numba( - func, *args, engine_kwargs=engine_kwargs, **kwargs - ) - # optimized transforms func = com.get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general(func, engine, engine_kwargs, *args, **kwargs) elif func not 
in base.transform_kernel_allowlist: msg = f"'{func}' is not a valid function name for transform(name)" raise ValueError(msg) elif func in base.cythonized_kernels or func in base.transformation_kernels: # cythonized transform or canned "agg+broadcast" + if engine is not None: + kwargs["engine"] = engine + kwargs["engine_kwargs"] = engine_kwargs return getattr(self, func)(*args, **kwargs) else: @@ -1812,6 +1810,9 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): with com.temp_setattr(self, "as_index", True): # GH#49834 - result needs groups in the index for # _wrap_transform_fast_result + if engine is not None: + kwargs["engine"] = engine + kwargs["engine_kwargs"] = engine_kwargs result = getattr(self, func)(*args, **kwargs) return self._wrap_transform_fast_result(result) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 0264d2a09778f..6fdbf18db9e81 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -129,20 +129,25 @@ def func_1(values, index): tm.assert_frame_equal(expected, result) +# TODO: Test more than just reductions (e.g. 
actually test transformations once we have @td.skip_if_no("numba") @pytest.mark.parametrize( "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}] ) -def test_multifunc_notimplimented(agg_func): +def test_string_cython_vs_numba(agg_func, numba_supported_reductions): + agg_func, kwargs = numba_supported_reductions data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] ) grouped = data.groupby(0) - with pytest.raises(NotImplementedError, match="Numba engine can"): - grouped.transform(agg_func, engine="numba") - with pytest.raises(NotImplementedError, match="Numba engine can"): - grouped[1].transform(agg_func, engine="numba") + result = grouped.transform(agg_func, engine="numba", **kwargs) + expected = grouped.transform(agg_func, engine="cython", **kwargs) + tm.assert_frame_equal(result, expected) + + result = grouped[1].transform(agg_func, engine="numba", **kwargs) + expected = grouped[1].transform(agg_func, engine="cython", **kwargs) + tm.assert_series_equal(result, expected) @td.skip_if_no("numba") From ec1da0cbc79c75861c5e0b54016b592ed8094d23 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Fri, 9 Jun 2023 20:09:04 -0700 Subject: [PATCH 2/2] remove xfail --- pandas/tests/groupby/transform/test_numba.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 0a66781b23d26..8da6b23f5ac57 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -237,9 +237,6 @@ def numba_func(values, index): @td.skip_if_no("numba") -@pytest.mark.xfail( - reason="Groupby transform doesn't support strings as function inputs yet with numba" -) def test_multilabel_numba_vs_cython(numba_supported_reductions): reduction, kwargs = numba_supported_reductions df = DataFrame(