75
75
import pandas .core .indexes .base as ibase
76
76
from pandas .core .internals import BlockManager , make_block
77
77
from pandas .core .series import Series
78
+ from pandas .core .util .numba_ import (
79
+ check_kwargs_and_nopython ,
80
+ get_jit_arguments ,
81
+ jit_user_function ,
82
+ split_for_numba ,
83
+ validate_udf ,
84
+ )
78
85
79
86
from pandas .plotting import boxplot_frame_groupby
80
87
@@ -154,6 +161,8 @@ def pinner(cls):
154
161
class SeriesGroupBy (GroupBy [Series ]):
155
162
_apply_whitelist = base .series_apply_whitelist
156
163
164
+ _numba_func_cache : Dict [Callable , Callable ] = {}
165
+
157
166
def _iterate_slices (self ) -> Iterable [Series ]:
158
167
yield self ._selected_obj
159
168
@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):
463
472
464
473
@Substitution (klass = "Series" , selected = "A." )
465
474
@Appender (_transform_template )
466
- def transform (self , func , * args , ** kwargs ):
475
+ def transform (self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs ):
467
476
func = self ._get_cython_func (func ) or func
468
477
469
478
if not isinstance (func , str ):
470
- return self ._transform_general (func , * args , ** kwargs )
479
+ return self ._transform_general (
480
+ func , * args , engine = engine , engine_kwargs = engine_kwargs , ** kwargs
481
+ )
471
482
472
483
elif func not in base .transform_kernel_whitelist :
473
484
msg = f"'{ func } ' is not a valid function name for transform(name)"
@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):
482
493
result = getattr (self , func )(* args , ** kwargs )
483
494
return self ._transform_fast (result , func )
484
495
485
- def _transform_general (self , func , * args , ** kwargs ):
496
+ def _transform_general (
497
+ self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
498
+ ):
486
499
"""
487
500
Transform with a non-str `func`.
488
501
"""
502
+
503
+ if engine == "numba" :
504
+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
505
+ check_kwargs_and_nopython (kwargs , nopython )
506
+ validate_udf (func )
507
+ numba_func = self ._numba_func_cache .get (
508
+ func , jit_user_function (func , nopython , nogil , parallel )
509
+ )
510
+
489
511
klass = type (self ._selected_obj )
490
512
491
513
results = []
492
514
for name , group in self :
493
515
object .__setattr__ (group , "name" , name )
494
- res = func (group , * args , ** kwargs )
516
+ if engine == "numba" :
517
+ values , index = split_for_numba (group )
518
+ res = numba_func (values , index , * args )
519
+ if func not in self ._numba_func_cache :
520
+ self ._numba_func_cache [func ] = numba_func
521
+ else :
522
+ res = func (group , * args , ** kwargs )
495
523
496
524
if isinstance (res , (ABCDataFrame , ABCSeries )):
497
525
res = res ._values
@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
819
847
820
848
_apply_whitelist = base .dataframe_apply_whitelist
821
849
850
+ _numba_func_cache : Dict [Callable , Callable ] = {}
851
+
822
852
_agg_see_also_doc = dedent (
823
853
"""
824
854
See Also
@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
1355
1385
# Handle cases like BinGrouper
1356
1386
return self ._concat_objects (keys , values , not_indexed_same = not_indexed_same )
1357
1387
1358
- def _transform_general (self , func , * args , ** kwargs ):
1388
+ def _transform_general (
1389
+ self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
1390
+ ):
1359
1391
from pandas .core .reshape .concat import concat
1360
1392
1361
1393
applied = []
1362
1394
obj = self ._obj_with_exclusions
1363
1395
gen = self .grouper .get_iterator (obj , axis = self .axis )
1364
- fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
1396
+ if engine == "numba" :
1397
+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
1398
+ check_kwargs_and_nopython (kwargs , nopython )
1399
+ validate_udf (func )
1400
+ numba_func = self ._numba_func_cache .get (
1401
+ func , jit_user_function (func , nopython , nogil , parallel )
1402
+ )
1403
+ else :
1404
+ fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
1365
1405
1366
- path = None
1367
1406
for name , group in gen :
1368
1407
object .__setattr__ (group , "name" , name )
1369
1408
1370
- if path is None :
1409
+ if engine == "numba" :
1410
+ values , index = split_for_numba (group )
1411
+ res = numba_func (values , index , * args )
1412
+ if func not in self ._numba_func_cache :
1413
+ self ._numba_func_cache [func ] = numba_func
1414
+ # Return the result as a DataFrame for concatenation later
1415
+ res = DataFrame (res , index = group .index , columns = group .columns )
1416
+ else :
1371
1417
# Try slow path and fast path.
1372
1418
try :
1373
1419
path , res = self ._choose_path (fast_path , slow_path , group )
@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):
1376
1422
except ValueError as err :
1377
1423
msg = "transform must return a scalar value for each group"
1378
1424
raise ValueError (msg ) from err
1379
- else :
1380
- res = path (group )
1381
1425
1382
1426
if isinstance (res , Series ):
1383
1427
@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):
1411
1455
1412
1456
@Substitution (klass = "DataFrame" , selected = "" )
1413
1457
@Appender (_transform_template )
1414
- def transform (self , func , * args , ** kwargs ):
1458
+ def transform (self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs ):
1415
1459
1416
1460
# optimized transforms
1417
1461
func = self ._get_cython_func (func ) or func
1418
1462
1419
1463
if not isinstance (func , str ):
1420
- return self ._transform_general (func , * args , ** kwargs )
1464
+ return self ._transform_general (
1465
+ func , * args , engine = engine , engine_kwargs = engine_kwargs , ** kwargs
1466
+ )
1421
1467
1422
1468
elif func not in base .transform_kernel_whitelist :
1423
1469
msg = f"'{ func } ' is not a valid function name for transform(name)"
@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):
1439
1485
):
1440
1486
return self ._transform_fast (result , func )
1441
1487
1442
- return self ._transform_general (func , * args , ** kwargs )
1488
+ return self ._transform_general (
1489
+ func , engine = engine , engine_kwargs = engine_kwargs , * args , ** kwargs
1490
+ )
1443
1491
1444
1492
def _transform_fast (self , result : DataFrame , func_nm : str ) -> DataFrame :
1445
1493
"""
0 commit comments