Skip to content

Commit 3572e06

Browse files
committed
ENH: Groupby.transform support string input with engine=numba
1 parent 1cae7a3 commit 3572e06

File tree

4 files changed

+34
-12
lines changed

4 files changed

+34
-12
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Other enhancements
100100
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
101101
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
102102
- :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
103+
- :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`)
103104
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
104105
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
105106
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)

pandas/core/groupby/generic.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,16 @@ def _cython_transform(
523523

524524
return obj._constructor(result, index=self.obj.index, name=obj.name)
525525

526-
def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
526+
def _transform_general(
527+
self, func: Callable, engine, engine_kwargs, *args, **kwargs
528+
) -> Series:
527529
"""
528530
Transform with a callable `func`.
529531
"""
532+
if maybe_use_numba(engine):
533+
return self._transform_with_numba(
534+
func, *args, engine_kwargs=engine_kwargs, **kwargs
535+
)
530536
assert callable(func)
531537
klass = type(self.obj)
532538

@@ -1677,7 +1683,11 @@ def arr_func(bvalues: ArrayLike) -> ArrayLike:
16771683
res_df = self._maybe_transpose_result(res_df)
16781684
return res_df
16791685

1680-
def _transform_general(self, func, *args, **kwargs):
1686+
def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs):
1687+
if maybe_use_numba(engine):
1688+
return self._transform_with_numba(
1689+
func, *args, engine_kwargs=engine_kwargs, **kwargs
1690+
)
16811691
from pandas.core.reshape.concat import concat
16821692

16831693
applied = []

pandas/core/groupby/groupby.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -1782,22 +1782,25 @@ def _cython_transform(
17821782

17831783
@final
17841784
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1785-
if maybe_use_numba(engine):
1786-
return self._transform_with_numba(
1787-
func, *args, engine_kwargs=engine_kwargs, **kwargs
1788-
)
1785+
# if maybe_use_numba(engine):
1786+
# return self._transform_with_numba(
1787+
# func, *args, engine_kwargs=engine_kwargs, **kwargs
1788+
# )
17891789

17901790
# optimized transforms
17911791
func = com.get_cython_func(func) or func
17921792

17931793
if not isinstance(func, str):
1794-
return self._transform_general(func, *args, **kwargs)
1794+
return self._transform_general(func, engine, engine_kwargs, *args, **kwargs)
17951795

17961796
elif func not in base.transform_kernel_allowlist:
17971797
msg = f"'{func}' is not a valid function name for transform(name)"
17981798
raise ValueError(msg)
17991799
elif func in base.cythonized_kernels or func in base.transformation_kernels:
18001800
# cythonized transform or canned "agg+broadcast"
1801+
if engine is not None:
1802+
kwargs["engine"] = engine
1803+
kwargs["engine_kwargs"] = engine_kwargs
18011804
return getattr(self, func)(*args, **kwargs)
18021805

18031806
else:
@@ -1812,6 +1815,9 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
18121815
with com.temp_setattr(self, "as_index", True):
18131816
# GH#49834 - result needs groups in the index for
18141817
# _wrap_transform_fast_result
1818+
if engine is not None:
1819+
kwargs["engine"] = engine
1820+
kwargs["engine_kwargs"] = engine_kwargs
18151821
result = getattr(self, func)(*args, **kwargs)
18161822

18171823
return self._wrap_transform_fast_result(result)

pandas/tests/groupby/transform/test_numba.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -129,20 +129,25 @@ def func_1(values, index):
129129
tm.assert_frame_equal(expected, result)
130130

131131

132+
# TODO: Test more than just reductions (e.g. actually test transformations once we have
132133
@td.skip_if_no("numba")
133134
@pytest.mark.parametrize(
134135
"agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}]
135136
)
136-
def test_multifunc_notimplimented(agg_func):
137+
def test_string_cython_vs_numba(agg_func, numba_supported_reductions):
138+
agg_func, kwargs = numba_supported_reductions
137139
data = DataFrame(
138140
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
139141
)
140142
grouped = data.groupby(0)
141-
with pytest.raises(NotImplementedError, match="Numba engine can"):
142-
grouped.transform(agg_func, engine="numba")
143143

144-
with pytest.raises(NotImplementedError, match="Numba engine can"):
145-
grouped[1].transform(agg_func, engine="numba")
144+
result = grouped.transform(agg_func, engine="numba", **kwargs)
145+
expected = grouped.transform(agg_func, engine="cython", **kwargs)
146+
tm.assert_frame_equal(result, expected)
147+
148+
result = grouped[1].transform(agg_func, engine="numba", **kwargs)
149+
expected = grouped[1].transform(agg_func, engine="cython", **kwargs)
150+
tm.assert_series_equal(result, expected)
146151

147152

148153
@td.skip_if_no("numba")

0 commit comments

Comments
 (0)