Skip to content

Commit 07e6b9d

Browse files
jbrockmendel authored and jreback committed
REF: align transform logic flow (#29672)
1 parent 7a331c9 commit 07e6b9d

File tree

1 file changed

+43
-40
lines changed

1 file changed

+43
-40
lines changed

pandas/core/groupby/generic.py

+43 -40
Original file line numberDiff line numberDiff line change
@@ -394,35 +394,39 @@ def _aggregate_named(self, func, *args, **kwargs):
394394
def transform(self, func, *args, **kwargs):
395395
func = self._get_cython_func(func) or func
396396

397-
if isinstance(func, str):
398-
if not (func in base.transform_kernel_whitelist):
399-
msg = "'{func}' is not a valid function name for transform(name)"
400-
raise ValueError(msg.format(func=func))
401-
if func in base.cythonized_kernels:
402-
# cythonized transform or canned "agg+broadcast"
403-
return getattr(self, func)(*args, **kwargs)
404-
else:
405-
# If func is a reduction, we need to broadcast the
406-
# result to the whole group. Compute func result
407-
# and deal with possible broadcasting below.
408-
return self._transform_fast(
409-
lambda: getattr(self, func)(*args, **kwargs), func
410-
)
397+
if not isinstance(func, str):
398+
return self._transform_general(func, *args, **kwargs)
399+
400+
elif func not in base.transform_kernel_whitelist:
401+
msg = f"'{func}' is not a valid function name for transform(name)"
402+
raise ValueError(msg)
403+
elif func in base.cythonized_kernels:
404+
# cythonized transform or canned "agg+broadcast"
405+
return getattr(self, func)(*args, **kwargs)
411406

412-
# reg transform
407+
# If func is a reduction, we need to broadcast the
408+
# result to the whole group. Compute func result
409+
# and deal with possible broadcasting below.
410+
result = getattr(self, func)(*args, **kwargs)
411+
return self._transform_fast(result, func)
412+
413+
def _transform_general(self, func, *args, **kwargs):
414+
"""
415+
Transform with a non-str `func`.
416+
"""
413417
klass = self._selected_obj.__class__
418+
414419
results = []
415-
wrapper = lambda x: func(x, *args, **kwargs)
416420
for name, group in self:
417421
object.__setattr__(group, "name", name)
418-
res = wrapper(group)
422+
res = func(group, *args, **kwargs)
419423

420424
if isinstance(res, (ABCDataFrame, ABCSeries)):
421425
res = res._values
422426

423427
indexer = self._get_index(name)
424-
s = klass(res, indexer)
425-
results.append(s)
428+
ser = klass(res, indexer)
429+
results.append(ser)
426430

427431
# check for empty "results" to avoid concat ValueError
428432
if results:
@@ -433,7 +437,7 @@ def transform(self, func, *args, **kwargs):
433437
result = Series()
434438

435439
# we will only try to coerce the result type if
436-
# we have a numeric dtype, as these are *always* udfs
440+
# we have a numeric dtype, as these are *always* user-defined funcs
437441
# the cython functions take a different path (and casting)
438442
dtype = self._selected_obj.dtype
439443
if is_numeric_dtype(dtype):
@@ -443,17 +447,14 @@ def transform(self, func, *args, **kwargs):
443447
result.index = self._selected_obj.index
444448
return result
445449

446-
def _transform_fast(self, func, func_nm) -> Series:
450+
def _transform_fast(self, result, func_nm: str) -> Series:
447451
"""
448452
fast version of transform, only applicable to
449453
builtin/cythonizable functions
450454
"""
451-
if isinstance(func, str):
452-
func = getattr(self, func)
453-
454455
ids, _, ngroup = self.grouper.group_info
455456
cast = self._transform_should_cast(func_nm)
456-
out = algorithms.take_1d(func()._values, ids)
457+
out = algorithms.take_1d(result._values, ids)
457458
if cast:
458459
out = self._try_cast(out, self.obj)
459460
return Series(out, index=self.obj.index, name=self.obj.name)
@@ -1333,21 +1334,21 @@ def transform(self, func, *args, **kwargs):
13331334
# optimized transforms
13341335
func = self._get_cython_func(func) or func
13351336

1336-
if isinstance(func, str):
1337-
if not (func in base.transform_kernel_whitelist):
1338-
msg = "'{func}' is not a valid function name for transform(name)"
1339-
raise ValueError(msg.format(func=func))
1340-
if func in base.cythonized_kernels:
1341-
# cythonized transformation or canned "reduction+broadcast"
1342-
return getattr(self, func)(*args, **kwargs)
1343-
else:
1344-
# If func is a reduction, we need to broadcast the
1345-
# result to the whole group. Compute func result
1346-
# and deal with possible broadcasting below.
1347-
result = getattr(self, func)(*args, **kwargs)
1348-
else:
1337+
if not isinstance(func, str):
13491338
return self._transform_general(func, *args, **kwargs)
13501339

1340+
elif func not in base.transform_kernel_whitelist:
1341+
msg = f"'{func}' is not a valid function name for transform(name)"
1342+
raise ValueError(msg)
1343+
elif func in base.cythonized_kernels:
1344+
# cythonized transformation or canned "reduction+broadcast"
1345+
return getattr(self, func)(*args, **kwargs)
1346+
1347+
# If func is a reduction, we need to broadcast the
1348+
# result to the whole group. Compute func result
1349+
# and deal with possible broadcasting below.
1350+
result = getattr(self, func)(*args, **kwargs)
1351+
13511352
# a reduction transform
13521353
if not isinstance(result, DataFrame):
13531354
return self._transform_general(func, *args, **kwargs)
@@ -1358,16 +1359,18 @@ def transform(self, func, *args, **kwargs):
13581359
if not result.columns.equals(obj.columns):
13591360
return self._transform_general(func, *args, **kwargs)
13601361

1361-
return self._transform_fast(result, obj, func)
1362+
return self._transform_fast(result, func)
13621363

1363-
def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFrame:
1364+
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
13641365
"""
13651366
Fast transform path for aggregations
13661367
"""
13671368
# if there were groups with no observations (Categorical only?)
13681369
# try casting data to original dtype
13691370
cast = self._transform_should_cast(func_nm)
13701371

1372+
obj = self._obj_with_exclusions
1373+
13711374
# for each col, reshape to size of original frame
13721375
# by take operation
13731376
ids, _, ngroup = self.grouper.group_info

0 commit comments

Comments
 (0)