Skip to content

Commit b8b6471

Browse files
authored
ENH: Add numba engine to groupby.transform (#32854)
1 parent 3d4f9dc commit b8b6471

File tree

9 files changed

+360
-18
lines changed

9 files changed

+360
-18
lines changed

asv_bench/benchmarks/groupby.py

+34
Original file line numberDiff line numberDiff line change
@@ -626,4 +626,38 @@ def time_first(self):
626626
self.df_nans.groupby("key").transform("first")
627627

628628

629+
class TransformEngine:
    """
    Benchmarks for ``groupby(...).transform`` with a user defined function,
    run once per engine (``"numba"`` / ``"cython"``) on a selected column
    and on the whole grouped frame.
    """

    def setup(self):
        # 100 distinct string keys in column 0, each repeated N times;
        # column 1 carries the matching integers.
        N = 10 ** 3
        keys = [str(i) for i in range(100)] * N
        numbers = list(range(100)) * N
        frame = DataFrame({0: keys, 1: numbers}, columns=[0, 1])
        self.grouper = frame.groupby(0)

    def time_series_numba(self):
        # The numba engine requires ``values`` and ``index`` as the first
        # two parameter names of the user defined function.
        def times_five(values, index):
            return values * 5

        column = self.grouper[1]
        column.transform(times_five, engine="numba")

    def time_series_cython(self):
        def times_five(group):
            return group * 5

        column = self.grouper[1]
        column.transform(times_five, engine="cython")

    def time_dataframe_numba(self):
        # The numba engine requires ``values`` and ``index`` as the first
        # two parameter names of the user defined function.
        def times_five(values, index):
            return values * 5

        self.grouper.transform(times_five, engine="numba")

    def time_dataframe_cython(self):
        def times_five(group):
            return group * 5

        self.grouper.transform(times_five, engine="cython")
661+
662+
629663
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ Other enhancements
9898
This can be used to set a custom compression level, e.g.,
9999
``df.to_csv(path, compression={'method': 'gzip', 'compresslevel': 1})``
100100
(:issue:`33196`)
101+
- :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that support executing functions with ``Numba`` (:issue:`32854`)
102+
-
101103

102104
.. ---------------------------------------------------------------------------
103105

pandas/core/groupby/generic.py

+61-13
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@
7575
import pandas.core.indexes.base as ibase
7676
from pandas.core.internals import BlockManager, make_block
7777
from pandas.core.series import Series
78+
from pandas.core.util.numba_ import (
79+
check_kwargs_and_nopython,
80+
get_jit_arguments,
81+
jit_user_function,
82+
split_for_numba,
83+
validate_udf,
84+
)
7885

7986
from pandas.plotting import boxplot_frame_groupby
8087

@@ -154,6 +161,8 @@ def pinner(cls):
154161
class SeriesGroupBy(GroupBy[Series]):
155162
_apply_whitelist = base.series_apply_whitelist
156163

164+
_numba_func_cache: Dict[Callable, Callable] = {}
165+
157166
def _iterate_slices(self) -> Iterable[Series]:
158167
yield self._selected_obj
159168

@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):
463472

464473
@Substitution(klass="Series", selected="A.")
465474
@Appender(_transform_template)
466-
def transform(self, func, *args, **kwargs):
475+
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
467476
func = self._get_cython_func(func) or func
468477

469478
if not isinstance(func, str):
470-
return self._transform_general(func, *args, **kwargs)
479+
return self._transform_general(
480+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
481+
)
471482

472483
elif func not in base.transform_kernel_whitelist:
473484
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):
482493
result = getattr(self, func)(*args, **kwargs)
483494
return self._transform_fast(result, func)
484495

485-
def _transform_general(self, func, *args, **kwargs):
496+
def _transform_general(
497+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
498+
):
486499
"""
487500
Transform with a non-str `func`.
488501
"""
502+
503+
if engine == "numba":
504+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
505+
check_kwargs_and_nopython(kwargs, nopython)
506+
validate_udf(func)
507+
numba_func = self._numba_func_cache.get(
508+
func, jit_user_function(func, nopython, nogil, parallel)
509+
)
510+
489511
klass = type(self._selected_obj)
490512

491513
results = []
492514
for name, group in self:
493515
object.__setattr__(group, "name", name)
494-
res = func(group, *args, **kwargs)
516+
if engine == "numba":
517+
values, index = split_for_numba(group)
518+
res = numba_func(values, index, *args)
519+
if func not in self._numba_func_cache:
520+
self._numba_func_cache[func] = numba_func
521+
else:
522+
res = func(group, *args, **kwargs)
495523

496524
if isinstance(res, (ABCDataFrame, ABCSeries)):
497525
res = res._values
@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
819847

820848
_apply_whitelist = base.dataframe_apply_whitelist
821849

850+
_numba_func_cache: Dict[Callable, Callable] = {}
851+
822852
_agg_see_also_doc = dedent(
823853
"""
824854
See Also
@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13551385
# Handle cases like BinGrouper
13561386
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
13571387

1358-
def _transform_general(self, func, *args, **kwargs):
1388+
def _transform_general(
1389+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
1390+
):
13591391
from pandas.core.reshape.concat import concat
13601392

13611393
applied = []
13621394
obj = self._obj_with_exclusions
13631395
gen = self.grouper.get_iterator(obj, axis=self.axis)
1364-
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
1396+
if engine == "numba":
1397+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
1398+
check_kwargs_and_nopython(kwargs, nopython)
1399+
validate_udf(func)
1400+
numba_func = self._numba_func_cache.get(
1401+
func, jit_user_function(func, nopython, nogil, parallel)
1402+
)
1403+
else:
1404+
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
13651405

1366-
path = None
13671406
for name, group in gen:
13681407
object.__setattr__(group, "name", name)
13691408

1370-
if path is None:
1409+
if engine == "numba":
1410+
values, index = split_for_numba(group)
1411+
res = numba_func(values, index, *args)
1412+
if func not in self._numba_func_cache:
1413+
self._numba_func_cache[func] = numba_func
1414+
# Return the result as a DataFrame for concatenation later
1415+
res = DataFrame(res, index=group.index, columns=group.columns)
1416+
else:
13711417
# Try slow path and fast path.
13721418
try:
13731419
path, res = self._choose_path(fast_path, slow_path, group)
@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):
13761422
except ValueError as err:
13771423
msg = "transform must return a scalar value for each group"
13781424
raise ValueError(msg) from err
1379-
else:
1380-
res = path(group)
13811425

13821426
if isinstance(res, Series):
13831427

@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):
14111455

14121456
@Substitution(klass="DataFrame", selected="")
14131457
@Appender(_transform_template)
1414-
def transform(self, func, *args, **kwargs):
1458+
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
14151459

14161460
# optimized transforms
14171461
func = self._get_cython_func(func) or func
14181462

14191463
if not isinstance(func, str):
1420-
return self._transform_general(func, *args, **kwargs)
1464+
return self._transform_general(
1465+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1466+
)
14211467

14221468
elif func not in base.transform_kernel_whitelist:
14231469
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):
14391485
):
14401486
return self._transform_fast(result, func)
14411487

1442-
return self._transform_general(func, *args, **kwargs)
1488+
return self._transform_general(
1489+
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
1490+
)
14431491

14441492
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
14451493
"""

pandas/core/groupby/groupby.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,36 @@ class providing the base-class of operations.
254254
Parameters
255255
----------
256256
f : function
257-
Function to apply to each group
257+
Function to apply to each group.
258+
259+
Can also accept a Numba JIT function with
260+
``engine='numba'`` specified.
261+
262+
If the ``'numba'`` engine is chosen, the function must be
263+
a user defined function with ``values`` and ``index`` as the
264+
first and second arguments respectively in the function signature.
265+
Each group's index will be passed to the user defined function
266+
and optionally available for use.
267+
268+
.. versionchanged:: 1.1.0
269+
*args
270+
Positional arguments to pass to func
271+
engine : str, default 'cython'
272+
* ``'cython'`` : Runs the function through C-extensions from cython.
273+
* ``'numba'`` : Runs the function through JIT compiled code from numba.
274+
275+
.. versionadded:: 1.1.0
276+
engine_kwargs : dict, default None
277+
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
278+
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
279+
and ``parallel`` dictionary keys. The values must either be ``True`` or
280+
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
281+
``{'nopython': True, 'nogil': False, 'parallel': False}`` and will be
282+
applied to the function
283+
284+
.. versionadded:: 1.1.0
285+
**kwargs
286+
Keyword arguments to be passed into func.
258287
259288
Returns
260289
-------

pandas/core/util/numba_.py

+103-4
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,58 @@
11
"""Common utilities for Numba operations"""
2+
import inspect
23
import types
3-
from typing import Callable, Dict, Optional
4+
from typing import Callable, Dict, Optional, Tuple
45

56
import numpy as np
67

8+
from pandas._typing import FrameOrSeries
79
from pandas.compat._optional import import_optional_dependency
810

911

1012
def check_kwargs_and_nopython(
    kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
) -> None:
    """
    Reject the combination of user keyword arguments and ``nopython=True``.

    See https://github.com/numba/numba/issues/2916

    Parameters
    ----------
    kwargs : dict, default None
        user passed keyword arguments to pass into the JITed function
    nopython : bool, default None
        nopython parameter

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``kwargs`` is non-empty while ``nopython`` is truthy.
    """
    # Nothing to complain about unless both are set at the same time.
    if not kwargs or not nopython:
        return
    raise ValueError(
        "numba does not support kwargs with nopython=True: "
        "https://github.com/numba/numba/issues/2916"
    )
1839

1940

20-
def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None):
41+
def get_jit_arguments(
42+
engine_kwargs: Optional[Dict[str, bool]] = None
43+
) -> Tuple[bool, bool, bool]:
2144
"""
2245
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
46+
47+
Parameters
48+
----------
49+
engine_kwargs : dict, default None
50+
user passed keyword arguments for numba.JIT
51+
52+
Returns
53+
-------
54+
(bool, bool, bool)
55+
nopython, nogil, parallel
2356
"""
2457
if engine_kwargs is None:
2558
engine_kwargs = {}
@@ -30,9 +63,28 @@ def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None):
3063
return nopython, nogil, parallel
3164

3265

33-
def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool):
66+
def jit_user_function(
67+
func: Callable, nopython: bool, nogil: bool, parallel: bool
68+
) -> Callable:
3469
"""
3570
JIT the user's function given the configurable arguments.
71+
72+
Parameters
73+
----------
74+
func : function
75+
user defined function
76+
77+
nopython : bool
78+
nopython parameter for numba.JIT
79+
nogil : bool
80+
nogil parameter for numba.JIT
81+
parallel : bool
82+
parallel parameter for numba.JIT
83+
84+
Returns
85+
-------
86+
function
87+
Numba JITed function
3688
"""
3789
numba = import_optional_dependency("numba")
3890

@@ -56,3 +108,50 @@ def impl(data, *_args):
56108
return impl
57109

58110
return numba_func
111+
112+
113+
def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decompose a pandas object into the numpy arrays handed to numba functions.

    Parameters
    ----------
    arg : Series or DataFrame

    Returns
    -------
    (ndarray, ndarray)
        values, index
    """
    values = arg.to_numpy()
    index = arg.index.to_numpy()
    return values, index
127+
128+
129+
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the leading parameter names of ``func`` are not ``values`` and
        ``index``, in that order.
    """
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    udf_signature = list(inspect.signature(func).parameters)
    # Slicing never raises, so a too-short signature simply compares unequal;
    # no separate length check is needed.
    if udf_signature[:min_number_args] != expected_args:
        # Not every callable has ``__name__`` (e.g. functools.partial objects);
        # fall back to repr so the intended ValueError is still raised.
        func_name = getattr(func, "__name__", repr(func))
        raise ValueError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )

0 commit comments

Comments
 (0)