Skip to content

Commit adc5f8b

Browse files
authored
ENH: Groupby.transform support string input with engine=numba (pandas-dev#53579)
* ENH: Groupby.transform support string input with engine=numba
* remove xfail
1 parent 7e66493 commit adc5f8b

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ Other enhancements
103103
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
104104
- :meth:`.DataFrameGroupBy.agg` and :meth:`.DataFrameGroupBy.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`)
105105
- :meth:`.SeriesGroupBy.agg` and :meth:`.DataFrameGroupBy.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
106+
- :meth:`.SeriesGroupBy.transform` and :meth:`.DataFrameGroupBy.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`)
106107
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
107108
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
108109
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)

pandas/core/groupby/generic.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,16 @@ def _cython_transform(
523523

524524
return obj._constructor(result, index=self.obj.index, name=obj.name)
525525

526-
def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
526+
def _transform_general(
527+
self, func: Callable, engine, engine_kwargs, *args, **kwargs
528+
) -> Series:
527529
"""
528530
Transform with a callable `func`.
529531
"""
532+
if maybe_use_numba(engine):
533+
return self._transform_with_numba(
534+
func, *args, engine_kwargs=engine_kwargs, **kwargs
535+
)
530536
assert callable(func)
531537
klass = type(self.obj)
532538

@@ -1654,7 +1660,11 @@ def arr_func(bvalues: ArrayLike) -> ArrayLike:
16541660
res_df = self._maybe_transpose_result(res_df)
16551661
return res_df
16561662

1657-
def _transform_general(self, func, *args, **kwargs):
1663+
def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs):
1664+
if maybe_use_numba(engine):
1665+
return self._transform_with_numba(
1666+
func, *args, engine_kwargs=engine_kwargs, **kwargs
1667+
)
16581668
from pandas.core.reshape.concat import concat
16591669

16601670
applied = []

pandas/core/groupby/groupby.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1787,22 +1787,20 @@ def _cython_transform(
17871787

17881788
@final
17891789
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1790-
if maybe_use_numba(engine):
1791-
return self._transform_with_numba(
1792-
func, *args, engine_kwargs=engine_kwargs, **kwargs
1793-
)
1794-
17951790
# optimized transforms
17961791
func = com.get_cython_func(func) or func
17971792

17981793
if not isinstance(func, str):
1799-
return self._transform_general(func, *args, **kwargs)
1794+
return self._transform_general(func, engine, engine_kwargs, *args, **kwargs)
18001795

18011796
elif func not in base.transform_kernel_allowlist:
18021797
msg = f"'{func}' is not a valid function name for transform(name)"
18031798
raise ValueError(msg)
18041799
elif func in base.cythonized_kernels or func in base.transformation_kernels:
18051800
# cythonized transform or canned "agg+broadcast"
1801+
if engine is not None:
1802+
kwargs["engine"] = engine
1803+
kwargs["engine_kwargs"] = engine_kwargs
18061804
return getattr(self, func)(*args, **kwargs)
18071805

18081806
else:
@@ -1817,6 +1815,9 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
18171815
with com.temp_setattr(self, "as_index", True):
18181816
# GH#49834 - result needs groups in the index for
18191817
# _wrap_transform_fast_result
1818+
if engine is not None:
1819+
kwargs["engine"] = engine
1820+
kwargs["engine_kwargs"] = engine_kwargs
18201821
result = getattr(self, func)(*args, **kwargs)
18211822

18221823
return self._wrap_transform_fast_result(result)

pandas/tests/groupby/transform/test_numba.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,25 @@ def func_1(values, index):
132132
tm.assert_frame_equal(expected, result)
133133

134134

135+
# TODO: Test more than just reductions (e.g. actually test transformations once the
# numba engine supports them)
135136
@td.skip_if_no("numba")
136137
@pytest.mark.parametrize(
137138
"agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}]
138139
)
139-
def test_multifunc_notimplimented(agg_func):
140+
def test_string_cython_vs_numba(agg_func, numba_supported_reductions):
141+
agg_func, kwargs = numba_supported_reductions
140142
data = DataFrame(
141143
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
142144
)
143145
grouped = data.groupby(0)
144-
with pytest.raises(NotImplementedError, match="Numba engine can"):
145-
grouped.transform(agg_func, engine="numba")
146146

147-
with pytest.raises(NotImplementedError, match="Numba engine can"):
148-
grouped[1].transform(agg_func, engine="numba")
147+
result = grouped.transform(agg_func, engine="numba", **kwargs)
148+
expected = grouped.transform(agg_func, engine="cython", **kwargs)
149+
tm.assert_frame_equal(result, expected)
150+
151+
result = grouped[1].transform(agg_func, engine="numba", **kwargs)
152+
expected = grouped[1].transform(agg_func, engine="cython", **kwargs)
153+
tm.assert_series_equal(result, expected)
149154

150155

151156
@td.skip_if_no("numba")
@@ -234,9 +239,6 @@ def numba_func(values, index):
234239

235240

236241
@td.skip_if_no("numba")
237-
@pytest.mark.xfail(
238-
reason="Groupby transform doesn't support strings as function inputs yet with numba"
239-
)
240242
def test_multilabel_numba_vs_cython(numba_supported_reductions):
241243
reduction, kwargs = numba_supported_reductions
242244
df = DataFrame(

0 commit comments

Comments
 (0)