@@ -394,35 +394,39 @@ def _aggregate_named(self, func, *args, **kwargs):
394
394
def transform (self , func , * args , ** kwargs ):
395
395
func = self ._get_cython_func (func ) or func
396
396
397
- if isinstance (func , str ):
398
- if not (func in base .transform_kernel_whitelist ):
399
- msg = "'{func}' is not a valid function name for transform(name)"
400
- raise ValueError (msg .format (func = func ))
401
- if func in base .cythonized_kernels :
402
- # cythonized transform or canned "agg+broadcast"
403
- return getattr (self , func )(* args , ** kwargs )
404
- else :
405
- # If func is a reduction, we need to broadcast the
406
- # result to the whole group. Compute func result
407
- # and deal with possible broadcasting below.
408
- return self ._transform_fast (
409
- lambda : getattr (self , func )(* args , ** kwargs ), func
410
- )
397
+ if not isinstance (func , str ):
398
+ return self ._transform_general (func , * args , ** kwargs )
399
+
400
+ elif func not in base .transform_kernel_whitelist :
401
+ msg = f"'{ func } ' is not a valid function name for transform(name)"
402
+ raise ValueError (msg )
403
+ elif func in base .cythonized_kernels :
404
+ # cythonized transform or canned "agg+broadcast"
405
+ return getattr (self , func )(* args , ** kwargs )
411
406
412
- # reg transform
407
+ # If func is a reduction, we need to broadcast the
408
+ # result to the whole group. Compute func result
409
+ # and deal with possible broadcasting below.
410
+ result = getattr (self , func )(* args , ** kwargs )
411
+ return self ._transform_fast (result , func )
412
+
413
+ def _transform_general (self , func , * args , ** kwargs ):
414
+ """
415
+ Transform with a non-str `func`.
416
+ """
413
417
klass = self ._selected_obj .__class__
418
+
414
419
results = []
415
- wrapper = lambda x : func (x , * args , ** kwargs )
416
420
for name , group in self :
417
421
object .__setattr__ (group , "name" , name )
418
- res = wrapper (group )
422
+ res = func (group , * args , ** kwargs )
419
423
420
424
if isinstance (res , (ABCDataFrame , ABCSeries )):
421
425
res = res ._values
422
426
423
427
indexer = self ._get_index (name )
424
- s = klass (res , indexer )
425
- results .append (s )
428
+ ser = klass (res , indexer )
429
+ results .append (ser )
426
430
427
431
# check for empty "results" to avoid concat ValueError
428
432
if results :
@@ -433,7 +437,7 @@ def transform(self, func, *args, **kwargs):
433
437
result = Series ()
434
438
435
439
# we will only try to coerce the result type if
436
- # we have a numeric dtype, as these are *always* udfs
440
+ # we have a numeric dtype, as these are *always* user-defined funcs
437
441
# the cython take a different path (and casting)
438
442
dtype = self ._selected_obj .dtype
439
443
if is_numeric_dtype (dtype ):
@@ -443,17 +447,14 @@ def transform(self, func, *args, **kwargs):
443
447
result .index = self ._selected_obj .index
444
448
return result
445
449
446
- def _transform_fast (self , func , func_nm ) -> Series :
450
+ def _transform_fast (self , result , func_nm : str ) -> Series :
447
451
"""
448
452
fast version of transform, only applicable to
449
453
builtin/cythonizable functions
450
454
"""
451
- if isinstance (func , str ):
452
- func = getattr (self , func )
453
-
454
455
ids , _ , ngroup = self .grouper .group_info
455
456
cast = self ._transform_should_cast (func_nm )
456
- out = algorithms .take_1d (func () ._values , ids )
457
+ out = algorithms .take_1d (result ._values , ids )
457
458
if cast :
458
459
out = self ._try_cast (out , self .obj )
459
460
return Series (out , index = self .obj .index , name = self .obj .name )
@@ -1333,21 +1334,21 @@ def transform(self, func, *args, **kwargs):
1333
1334
# optimized transforms
1334
1335
func = self ._get_cython_func (func ) or func
1335
1336
1336
- if isinstance (func , str ):
1337
- if not (func in base .transform_kernel_whitelist ):
1338
- msg = "'{func}' is not a valid function name for transform(name)"
1339
- raise ValueError (msg .format (func = func ))
1340
- if func in base .cythonized_kernels :
1341
- # cythonized transformation or canned "reduction+broadcast"
1342
- return getattr (self , func )(* args , ** kwargs )
1343
- else :
1344
- # If func is a reduction, we need to broadcast the
1345
- # result to the whole group. Compute func result
1346
- # and deal with possible broadcasting below.
1347
- result = getattr (self , func )(* args , ** kwargs )
1348
- else :
1337
+ if not isinstance (func , str ):
1349
1338
return self ._transform_general (func , * args , ** kwargs )
1350
1339
1340
+ elif func not in base .transform_kernel_whitelist :
1341
+ msg = f"'{ func } ' is not a valid function name for transform(name)"
1342
+ raise ValueError (msg )
1343
+ elif func in base .cythonized_kernels :
1344
+ # cythonized transformation or canned "reduction+broadcast"
1345
+ return getattr (self , func )(* args , ** kwargs )
1346
+
1347
+ # If func is a reduction, we need to broadcast the
1348
+ # result to the whole group. Compute func result
1349
+ # and deal with possible broadcasting below.
1350
+ result = getattr (self , func )(* args , ** kwargs )
1351
+
1351
1352
# a reduction transform
1352
1353
if not isinstance (result , DataFrame ):
1353
1354
return self ._transform_general (func , * args , ** kwargs )
@@ -1358,16 +1359,18 @@ def transform(self, func, *args, **kwargs):
1358
1359
if not result .columns .equals (obj .columns ):
1359
1360
return self ._transform_general (func , * args , ** kwargs )
1360
1361
1361
- return self ._transform_fast (result , obj , func )
1362
+ return self ._transform_fast (result , func )
1362
1363
1363
- def _transform_fast (self , result : DataFrame , obj : DataFrame , func_nm ) -> DataFrame :
1364
+ def _transform_fast (self , result : DataFrame , func_nm : str ) -> DataFrame :
1364
1365
"""
1365
1366
Fast transform path for aggregations
1366
1367
"""
1367
1368
# if there were groups with no observations (Categorical only?)
1368
1369
# try casting data to original dtype
1369
1370
cast = self ._transform_should_cast (func_nm )
1370
1371
1372
+ obj = self ._obj_with_exclusions
1373
+
1371
1374
# for each col, reshape to to size of original frame
1372
1375
# by take operation
1373
1376
ids , _ , ngroup = self .grouper .group_info
0 commit comments