Skip to content

Commit ef57808

Browse files
mroeschkeKevin D Smith
authored and
Kevin D Smith
committed
PERF: Allow groupby transform with numba engine to be fully parallelizable (pandas-dev#36240)
1 parent a21211e commit ef57808

File tree

6 files changed

+187
-62
lines changed

6 files changed

+187
-62
lines changed

asv_bench/benchmarks/groupby.py

+32-14
Original file line numberDiff line numberDiff line change
@@ -627,49 +627,63 @@ def time_first(self):
627627

628628

629629
class TransformEngine:
630-
def setup(self):
630+
631+
param_names = ["parallel"]
632+
params = [[True, False]]
633+
634+
def setup(self, parallel):
631635
N = 10 ** 3
632636
data = DataFrame(
633637
{0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
634638
columns=[0, 1],
635639
)
640+
self.parallel = parallel
636641
self.grouper = data.groupby(0)
637642

638-
def time_series_numba(self):
643+
def time_series_numba(self, parallel):
639644
def function(values, index):
640645
return values * 5
641646

642-
self.grouper[1].transform(function, engine="numba")
647+
self.grouper[1].transform(
648+
function, engine="numba", engine_kwargs={"parallel": self.parallel}
649+
)
643650

644-
def time_series_cython(self):
651+
def time_series_cython(self, parallel):
645652
def function(values):
646653
return values * 5
647654

648655
self.grouper[1].transform(function, engine="cython")
649656

650-
def time_dataframe_numba(self):
657+
def time_dataframe_numba(self, parallel):
651658
def function(values, index):
652659
return values * 5
653660

654-
self.grouper.transform(function, engine="numba")
661+
self.grouper.transform(
662+
function, engine="numba", engine_kwargs={"parallel": self.parallel}
663+
)
655664

656-
def time_dataframe_cython(self):
665+
def time_dataframe_cython(self, parallel):
657666
def function(values):
658667
return values * 5
659668

660669
self.grouper.transform(function, engine="cython")
661670

662671

663672
class AggEngine:
664-
def setup(self):
673+
674+
param_names = ["parallel"]
675+
params = [[True, False]]
676+
677+
def setup(self, parallel):
665678
N = 10 ** 3
666679
data = DataFrame(
667680
{0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
668681
columns=[0, 1],
669682
)
683+
self.parallel = parallel
670684
self.grouper = data.groupby(0)
671685

672-
def time_series_numba(self):
686+
def time_series_numba(self, parallel):
673687
def function(values, index):
674688
total = 0
675689
for i, value in enumerate(values):
@@ -679,9 +693,11 @@ def function(values, index):
679693
total += value * 2
680694
return total
681695

682-
self.grouper[1].agg(function, engine="numba")
696+
self.grouper[1].agg(
697+
function, engine="numba", engine_kwargs={"parallel": self.parallel}
698+
)
683699

684-
def time_series_cython(self):
700+
def time_series_cython(self, parallel):
685701
def function(values):
686702
total = 0
687703
for i, value in enumerate(values):
@@ -693,7 +709,7 @@ def function(values):
693709

694710
self.grouper[1].agg(function, engine="cython")
695711

696-
def time_dataframe_numba(self):
712+
def time_dataframe_numba(self, parallel):
697713
def function(values, index):
698714
total = 0
699715
for i, value in enumerate(values):
@@ -703,9 +719,11 @@ def function(values, index):
703719
total += value * 2
704720
return total
705721

706-
self.grouper.agg(function, engine="numba")
722+
self.grouper.agg(
723+
function, engine="numba", engine_kwargs={"parallel": self.parallel}
724+
)
707725

708-
def time_dataframe_cython(self):
726+
def time_dataframe_cython(self, parallel):
709727
def function(values):
710728
total = 0
711729
for i, value in enumerate(values):

doc/source/whatsnew/v1.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ Performance improvements
207207

208208
- Performance improvements when creating Series with dtype `str` or :class:`StringDtype` from array with many string elements (:issue:`36304`, :issue:`36317`)
209209
- Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`)
210-
-
210+
- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`)
211211

212212
.. ---------------------------------------------------------------------------
213213

pandas/core/groupby/generic.py

+32-44
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,6 @@ def apply(self, func, *args, **kwargs):
226226
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
227227

228228
if maybe_use_numba(engine):
229-
if not callable(func):
230-
raise NotImplementedError(
231-
"Numba engine can only be used with a single function."
232-
)
233229
with group_selection_context(self):
234230
data = self._selected_obj
235231
result, index = self._aggregate_with_numba(
@@ -489,12 +485,21 @@ def _aggregate_named(self, func, *args, **kwargs):
489485
@Substitution(klass="Series")
490486
@Appender(_transform_template)
491487
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
488+
489+
if maybe_use_numba(engine):
490+
with group_selection_context(self):
491+
data = self._selected_obj
492+
result = self._transform_with_numba(
493+
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
494+
)
495+
return self.obj._constructor(
496+
result.ravel(), index=data.index, name=data.name
497+
)
498+
492499
func = self._get_cython_func(func) or func
493500

494501
if not isinstance(func, str):
495-
return self._transform_general(
496-
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
497-
)
502+
return self._transform_general(func, *args, **kwargs)
498503

499504
elif func not in base.transform_kernel_allowlist:
500505
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -938,10 +943,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
938943
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
939944

940945
if maybe_use_numba(engine):
941-
if not callable(func):
942-
raise NotImplementedError(
943-
"Numba engine can only be used with a single function."
944-
)
945946
with group_selection_context(self):
946947
data = self._selected_obj
947948
result, index = self._aggregate_with_numba(
@@ -1290,42 +1291,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
12901291

12911292
return self._reindex_output(result)
12921293

1293-
def _transform_general(
1294-
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
1295-
):
1294+
def _transform_general(self, func, *args, **kwargs):
12961295
from pandas.core.reshape.concat import concat
12971296

12981297
applied = []
12991298
obj = self._obj_with_exclusions
13001299
gen = self.grouper.get_iterator(obj, axis=self.axis)
1301-
if maybe_use_numba(engine):
1302-
numba_func, cache_key = generate_numba_func(
1303-
func, engine_kwargs, kwargs, "groupby_transform"
1304-
)
1305-
else:
1306-
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
1300+
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
13071301

13081302
for name, group in gen:
13091303
object.__setattr__(group, "name", name)
13101304

1311-
if maybe_use_numba(engine):
1312-
values, index = split_for_numba(group)
1313-
res = numba_func(values, index, *args)
1314-
if cache_key not in NUMBA_FUNC_CACHE:
1315-
NUMBA_FUNC_CACHE[cache_key] = numba_func
1316-
# Return the result as a DataFrame for concatenation later
1317-
res = self.obj._constructor(
1318-
res, index=group.index, columns=group.columns
1319-
)
1320-
else:
1321-
# Try slow path and fast path.
1322-
try:
1323-
path, res = self._choose_path(fast_path, slow_path, group)
1324-
except TypeError:
1325-
return self._transform_item_by_item(obj, fast_path)
1326-
except ValueError as err:
1327-
msg = "transform must return a scalar value for each group"
1328-
raise ValueError(msg) from err
1305+
# Try slow path and fast path.
1306+
try:
1307+
path, res = self._choose_path(fast_path, slow_path, group)
1308+
except TypeError:
1309+
return self._transform_item_by_item(obj, fast_path)
1310+
except ValueError as err:
1311+
msg = "transform must return a scalar value for each group"
1312+
raise ValueError(msg) from err
13291313

13301314
if isinstance(res, Series):
13311315

@@ -1361,13 +1345,19 @@ def _transform_general(
13611345
@Appender(_transform_template)
13621346
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
13631347

1348+
if maybe_use_numba(engine):
1349+
with group_selection_context(self):
1350+
data = self._selected_obj
1351+
result = self._transform_with_numba(
1352+
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
1353+
)
1354+
return self.obj._constructor(result, index=data.index, columns=data.columns)
1355+
13641356
# optimized transforms
13651357
func = self._get_cython_func(func) or func
13661358

13671359
if not isinstance(func, str):
1368-
return self._transform_general(
1369-
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1370-
)
1360+
return self._transform_general(func, *args, **kwargs)
13711361

13721362
elif func not in base.transform_kernel_allowlist:
13731363
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -1393,9 +1383,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
13931383
):
13941384
return self._transform_fast(result)
13951385

1396-
return self._transform_general(
1397-
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
1398-
)
1386+
return self._transform_general(func, *args, **kwargs)
13991387

14001388
def _transform_fast(self, result: DataFrame) -> DataFrame:
14011389
"""

pandas/core/groupby/groupby.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,41 @@ def _cython_agg_general(
10561056

10571057
return self._wrap_aggregated_output(output, index=self.grouper.result_index)
10581058

1059+
def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
1060+
"""
1061+
Perform groupby transform routine with the numba engine.
1062+
1063+
This routine mimics the data splitting routine of the DataSplitter class
1064+
to generate the indices of each group in the sorted data and then passes the
1065+
data and indices into a Numba jitted function.
1066+
"""
1067+
if not callable(func):
1068+
raise NotImplementedError(
1069+
"Numba engine can only be used with a single function."
1070+
)
1071+
group_keys = self.grouper._get_group_keys()
1072+
labels, _, n_groups = self.grouper.group_info
1073+
sorted_index = get_group_index_sorter(labels, n_groups)
1074+
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
1075+
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
1076+
starts, ends = lib.generate_slices(sorted_labels, n_groups)
1077+
cache_key = (func, "groupby_transform")
1078+
if cache_key in NUMBA_FUNC_CACHE:
1079+
numba_transform_func = NUMBA_FUNC_CACHE[cache_key]
1080+
else:
1081+
numba_transform_func = numba_.generate_numba_transform_func(
1082+
tuple(args), kwargs, func, engine_kwargs
1083+
)
1084+
result = numba_transform_func(
1085+
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
1086+
)
1087+
if cache_key not in NUMBA_FUNC_CACHE:
1088+
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func
1089+
1090+
# result values needs to be resorted to their original positions since we
1091+
# evaluated the data sorted by group
1092+
return result.take(np.argsort(sorted_index), axis=0)
1093+
10591094
def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
10601095
"""
10611096
Perform groupby aggregation routine with the numba engine.
@@ -1064,6 +1099,10 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
10641099
to generate the indices of each group in the sorted data and then passes the
10651100
data and indices into a Numba jitted function.
10661101
"""
1102+
if not callable(func):
1103+
raise NotImplementedError(
1104+
"Numba engine can only be used with a single function."
1105+
)
10671106
group_keys = self.grouper._get_group_keys()
10681107
labels, _, n_groups = self.grouper.group_info
10691108
sorted_index = get_group_index_sorter(labels, n_groups)
@@ -1072,7 +1111,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
10721111
starts, ends = lib.generate_slices(sorted_labels, n_groups)
10731112
cache_key = (func, "groupby_agg")
10741113
if cache_key in NUMBA_FUNC_CACHE:
1075-
# Return an already compiled version of roll_apply if available
10761114
numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
10771115
else:
10781116
numba_agg_func = numba_.generate_numba_agg_func(

pandas/core/groupby/numba_.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def generate_numba_agg_func(
153153
loop_range = range
154154

155155
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
156-
def group_apply(
156+
def group_agg(
157157
values: np.ndarray,
158158
index: np.ndarray,
159159
begin: np.ndarray,
@@ -169,4 +169,69 @@ def group_apply(
169169
result[i, j] = numba_func(group, group_index, *args)
170170
return result
171171

172-
return group_apply
172+
return group_agg
173+
174+
175+
def generate_numba_transform_func(
176+
args: Tuple,
177+
kwargs: Dict[str, Any],
178+
func: Callable[..., Scalar],
179+
engine_kwargs: Optional[Dict[str, bool]],
180+
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
181+
"""
182+
Generate a numba jitted transform function specified by values from engine_kwargs.
183+
184+
1. jit the user's function
185+
2. Return a groupby agg function with the jitted function inline
186+
187+
Configurations specified in engine_kwargs apply to both the user's
188+
function _AND_ the rolling apply function.
189+
190+
Parameters
191+
----------
192+
args : tuple
193+
*args to be passed into the function
194+
kwargs : dict
195+
**kwargs to be passed into the function
196+
func : function
197+
function to be applied to each window and will be JITed
198+
engine_kwargs : dict
199+
dictionary of arguments to be passed into numba.jit
200+
201+
Returns
202+
-------
203+
Numba function
204+
"""
205+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
206+
207+
check_kwargs_and_nopython(kwargs, nopython)
208+
209+
validate_udf(func)
210+
211+
numba_func = jit_user_function(func, nopython, nogil, parallel)
212+
213+
numba = import_optional_dependency("numba")
214+
215+
if parallel:
216+
loop_range = numba.prange
217+
else:
218+
loop_range = range
219+
220+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
221+
def group_transform(
222+
values: np.ndarray,
223+
index: np.ndarray,
224+
begin: np.ndarray,
225+
end: np.ndarray,
226+
num_groups: int,
227+
num_columns: int,
228+
) -> np.ndarray:
229+
result = np.empty((len(values), num_columns))
230+
for i in loop_range(num_groups):
231+
group_index = index[begin[i] : end[i]]
232+
for j in loop_range(num_columns):
233+
group = values[begin[i] : end[i], j]
234+
result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
235+
return result
236+
237+
return group_transform

0 commit comments

Comments
 (0)