@@ -226,10 +226,6 @@ def apply(self, func, *args, **kwargs):
226
226
def aggregate (self , func = None , * args , engine = None , engine_kwargs = None , ** kwargs ):
227
227
228
228
if maybe_use_numba (engine ):
229
- if not callable (func ):
230
- raise NotImplementedError (
231
- "Numba engine can only be used with a single function."
232
- )
233
229
with group_selection_context (self ):
234
230
data = self ._selected_obj
235
231
result , index = self ._aggregate_with_numba (
@@ -489,12 +485,21 @@ def _aggregate_named(self, func, *args, **kwargs):
489
485
@Substitution (klass = "Series" )
490
486
@Appender (_transform_template )
491
487
def transform (self , func , * args , engine = None , engine_kwargs = None , ** kwargs ):
488
+
489
+ if maybe_use_numba (engine ):
490
+ with group_selection_context (self ):
491
+ data = self ._selected_obj
492
+ result = self ._transform_with_numba (
493
+ data .to_frame (), func , * args , engine_kwargs = engine_kwargs , ** kwargs
494
+ )
495
+ return self .obj ._constructor (
496
+ result .ravel (), index = data .index , name = data .name
497
+ )
498
+
492
499
func = self ._get_cython_func (func ) or func
493
500
494
501
if not isinstance (func , str ):
495
- return self ._transform_general (
496
- func , * args , engine = engine , engine_kwargs = engine_kwargs , ** kwargs
497
- )
502
+ return self ._transform_general (func , * args , ** kwargs )
498
503
499
504
elif func not in base .transform_kernel_allowlist :
500
505
msg = f"'{ func } ' is not a valid function name for transform(name)"
@@ -938,10 +943,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
938
943
def aggregate (self , func = None , * args , engine = None , engine_kwargs = None , ** kwargs ):
939
944
940
945
if maybe_use_numba (engine ):
941
- if not callable (func ):
942
- raise NotImplementedError (
943
- "Numba engine can only be used with a single function."
944
- )
945
946
with group_selection_context (self ):
946
947
data = self ._selected_obj
947
948
result , index = self ._aggregate_with_numba (
@@ -1290,42 +1291,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
1290
1291
1291
1292
return self ._reindex_output (result )
1292
1293
1293
- def _transform_general (
1294
- self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
1295
- ):
1294
+ def _transform_general (self , func , * args , ** kwargs ):
1296
1295
from pandas .core .reshape .concat import concat
1297
1296
1298
1297
applied = []
1299
1298
obj = self ._obj_with_exclusions
1300
1299
gen = self .grouper .get_iterator (obj , axis = self .axis )
1301
- if maybe_use_numba (engine ):
1302
- numba_func , cache_key = generate_numba_func (
1303
- func , engine_kwargs , kwargs , "groupby_transform"
1304
- )
1305
- else :
1306
- fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
1300
+ fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
1307
1301
1308
1302
for name , group in gen :
1309
1303
object .__setattr__ (group , "name" , name )
1310
1304
1311
- if maybe_use_numba (engine ):
1312
- values , index = split_for_numba (group )
1313
- res = numba_func (values , index , * args )
1314
- if cache_key not in NUMBA_FUNC_CACHE :
1315
- NUMBA_FUNC_CACHE [cache_key ] = numba_func
1316
- # Return the result as a DataFrame for concatenation later
1317
- res = self .obj ._constructor (
1318
- res , index = group .index , columns = group .columns
1319
- )
1320
- else :
1321
- # Try slow path and fast path.
1322
- try :
1323
- path , res = self ._choose_path (fast_path , slow_path , group )
1324
- except TypeError :
1325
- return self ._transform_item_by_item (obj , fast_path )
1326
- except ValueError as err :
1327
- msg = "transform must return a scalar value for each group"
1328
- raise ValueError (msg ) from err
1305
+ # Try slow path and fast path.
1306
+ try :
1307
+ path , res = self ._choose_path (fast_path , slow_path , group )
1308
+ except TypeError :
1309
+ return self ._transform_item_by_item (obj , fast_path )
1310
+ except ValueError as err :
1311
+ msg = "transform must return a scalar value for each group"
1312
+ raise ValueError (msg ) from err
1329
1313
1330
1314
if isinstance (res , Series ):
1331
1315
@@ -1361,13 +1345,19 @@ def _transform_general(
1361
1345
@Appender (_transform_template )
1362
1346
def transform (self , func , * args , engine = None , engine_kwargs = None , ** kwargs ):
1363
1347
1348
+ if maybe_use_numba (engine ):
1349
+ with group_selection_context (self ):
1350
+ data = self ._selected_obj
1351
+ result = self ._transform_with_numba (
1352
+ data , func , * args , engine_kwargs = engine_kwargs , ** kwargs
1353
+ )
1354
+ return self .obj ._constructor (result , index = data .index , columns = data .columns )
1355
+
1364
1356
# optimized transforms
1365
1357
func = self ._get_cython_func (func ) or func
1366
1358
1367
1359
if not isinstance (func , str ):
1368
- return self ._transform_general (
1369
- func , * args , engine = engine , engine_kwargs = engine_kwargs , ** kwargs
1370
- )
1360
+ return self ._transform_general (func , * args , ** kwargs )
1371
1361
1372
1362
elif func not in base .transform_kernel_allowlist :
1373
1363
msg = f"'{ func } ' is not a valid function name for transform(name)"
@@ -1393,9 +1383,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1393
1383
):
1394
1384
return self ._transform_fast (result )
1395
1385
1396
- return self ._transform_general (
1397
- func , engine = engine , engine_kwargs = engine_kwargs , * args , ** kwargs
1398
- )
1386
+ return self ._transform_general (func , * args , ** kwargs )
1399
1387
1400
1388
def _transform_fast (self , result : DataFrame ) -> DataFrame :
1401
1389
"""
0 commit comments