Skip to content

Commit 6f8ae81

Browse files
authored
REF: Move transform into apply (#39957)
1 parent 1b6b581 commit 6f8ae81

File tree

4 files changed

+141
-145
lines changed

4 files changed

+141
-145
lines changed

pandas/core/aggregation.py

+1-141
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,19 @@
2020
Sequence,
2121
Tuple,
2222
Union,
23-
cast,
2423
)
2524

2625
from pandas._typing import (
2726
AggFuncType,
28-
AggFuncTypeBase,
29-
AggFuncTypeDict,
30-
Axis,
3127
FrameOrSeries,
32-
FrameOrSeriesUnion,
3328
)
3429

3530
from pandas.core.dtypes.common import (
3631
is_dict_like,
3732
is_list_like,
3833
)
39-
from pandas.core.dtypes.generic import (
40-
ABCDataFrame,
41-
ABCSeries,
42-
)
34+
from pandas.core.dtypes.generic import ABCSeries
4335

44-
from pandas.core.algorithms import safe_sort
4536
from pandas.core.base import SpecificationError
4637
import pandas.core.common as com
4738
from pandas.core.indexes.api import Index
@@ -405,134 +396,3 @@ def validate_func_kwargs(
405396
no_arg_message = "Must provide 'func' or named aggregation **kwargs."
406397
raise TypeError(no_arg_message)
407398
return columns, func
408-
409-
410-
def transform(
411-
obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs
412-
) -> FrameOrSeriesUnion:
413-
"""
414-
Transform a DataFrame or Series
415-
416-
Parameters
417-
----------
418-
obj : DataFrame or Series
419-
Object to compute the transform on.
420-
func : string, function, list, or dictionary
421-
Function(s) to compute the transform with.
422-
axis : {0 or 'index', 1 or 'columns'}
423-
Axis along which the function is applied:
424-
425-
* 0 or 'index': apply function to each column.
426-
* 1 or 'columns': apply function to each row.
427-
428-
Returns
429-
-------
430-
DataFrame or Series
431-
Result of applying ``func`` along the given axis of the
432-
Series or DataFrame.
433-
434-
Raises
435-
------
436-
ValueError
437-
If the transform function fails or does not transform.
438-
"""
439-
is_series = obj.ndim == 1
440-
441-
if obj._get_axis_number(axis) == 1:
442-
assert not is_series
443-
return transform(obj.T, func, 0, *args, **kwargs).T
444-
445-
if is_list_like(func) and not is_dict_like(func):
446-
func = cast(List[AggFuncTypeBase], func)
447-
# Convert func equivalent dict
448-
if is_series:
449-
func = {com.get_callable_name(v) or v: v for v in func}
450-
else:
451-
func = {col: func for col in obj}
452-
453-
if is_dict_like(func):
454-
func = cast(AggFuncTypeDict, func)
455-
return transform_dict_like(obj, func, *args, **kwargs)
456-
457-
# func is either str or callable
458-
func = cast(AggFuncTypeBase, func)
459-
try:
460-
result = transform_str_or_callable(obj, func, *args, **kwargs)
461-
except Exception:
462-
raise ValueError("Transform function failed")
463-
464-
# Functions that transform may return empty Series/DataFrame
465-
# when the dtype is not appropriate
466-
if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty and not obj.empty:
467-
raise ValueError("Transform function failed")
468-
if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals(
469-
obj.index
470-
):
471-
raise ValueError("Function did not transform")
472-
473-
return result
474-
475-
476-
def transform_dict_like(
477-
obj: FrameOrSeries,
478-
func: AggFuncTypeDict,
479-
*args,
480-
**kwargs,
481-
):
482-
"""
483-
Compute transform in the case of a dict-like func
484-
"""
485-
from pandas.core.reshape.concat import concat
486-
487-
if len(func) == 0:
488-
raise ValueError("No transform functions were provided")
489-
490-
if obj.ndim != 1:
491-
# Check for missing columns on a frame
492-
cols = set(func.keys()) - set(obj.columns)
493-
if len(cols) > 0:
494-
cols_sorted = list(safe_sort(list(cols)))
495-
raise SpecificationError(f"Column(s) {cols_sorted} do not exist")
496-
497-
# Can't use func.values(); wouldn't work for a Series
498-
if any(is_dict_like(v) for _, v in func.items()):
499-
# GH 15931 - deprecation of renaming keys
500-
raise SpecificationError("nested renamer is not supported")
501-
502-
results: Dict[Hashable, FrameOrSeriesUnion] = {}
503-
for name, how in func.items():
504-
colg = obj._gotitem(name, ndim=1)
505-
try:
506-
results[name] = transform(colg, how, 0, *args, **kwargs)
507-
except Exception as err:
508-
if str(err) in {
509-
"Function did not transform",
510-
"No transform functions were provided",
511-
}:
512-
raise err
513-
514-
# combine results
515-
if not results:
516-
raise ValueError("Transform function failed")
517-
return concat(results, axis=1)
518-
519-
520-
def transform_str_or_callable(
521-
obj: FrameOrSeries, func: AggFuncTypeBase, *args, **kwargs
522-
) -> FrameOrSeriesUnion:
523-
"""
524-
Compute transform in the case of a string or callable func
525-
"""
526-
if isinstance(func, str):
527-
return obj._try_aggregate_string_function(func, *args, **kwargs)
528-
529-
if not args and not kwargs:
530-
f = obj._get_cython_func(func)
531-
if f:
532-
return getattr(obj, f)()
533-
534-
# Two possible ways to use a UDF - apply or call directly
535-
try:
536-
return obj.apply(func, args=args, **kwargs)
537-
except Exception:
538-
return func(obj, *args, **kwargs)

pandas/core/apply.py

+132
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
TYPE_CHECKING,
77
Any,
88
Dict,
9+
Hashable,
910
Iterator,
1011
List,
1112
Optional,
@@ -198,6 +199,131 @@ def agg(self) -> Optional[FrameOrSeriesUnion]:
198199
# caller can react
199200
return None
200201

202+
def transform(self) -> FrameOrSeriesUnion:
203+
"""
204+
Transform a DataFrame or Series.
205+
206+
Returns
207+
-------
208+
DataFrame or Series
209+
Result of applying ``func`` along the given axis of the
210+
Series or DataFrame.
211+
212+
Raises
213+
------
214+
ValueError
215+
If the transform function fails or does not transform.
216+
"""
217+
obj = self.obj
218+
func = self.orig_f
219+
axis = self.axis
220+
args = self.args
221+
kwargs = self.kwargs
222+
223+
is_series = obj.ndim == 1
224+
225+
if obj._get_axis_number(axis) == 1:
226+
assert not is_series
227+
return obj.T.transform(func, 0, *args, **kwargs).T
228+
229+
if is_list_like(func) and not is_dict_like(func):
230+
func = cast(List[AggFuncTypeBase], func)
231+
# Convert func equivalent dict
232+
if is_series:
233+
func = {com.get_callable_name(v) or v: v for v in func}
234+
else:
235+
func = {col: func for col in obj}
236+
237+
if is_dict_like(func):
238+
func = cast(AggFuncTypeDict, func)
239+
return self.transform_dict_like(func)
240+
241+
# func is either str or callable
242+
func = cast(AggFuncTypeBase, func)
243+
try:
244+
result = self.transform_str_or_callable(func)
245+
except Exception:
246+
raise ValueError("Transform function failed")
247+
248+
# Functions that transform may return empty Series/DataFrame
249+
# when the dtype is not appropriate
250+
if (
251+
isinstance(result, (ABCSeries, ABCDataFrame))
252+
and result.empty
253+
and not obj.empty
254+
):
255+
raise ValueError("Transform function failed")
256+
if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals(
257+
obj.index
258+
):
259+
raise ValueError("Function did not transform")
260+
261+
return result
262+
263+
def transform_dict_like(self, func):
264+
"""
265+
Compute transform in the case of a dict-like func
266+
"""
267+
from pandas.core.reshape.concat import concat
268+
269+
obj = self.obj
270+
args = self.args
271+
kwargs = self.kwargs
272+
273+
if len(func) == 0:
274+
raise ValueError("No transform functions were provided")
275+
276+
if obj.ndim != 1:
277+
# Check for missing columns on a frame
278+
cols = set(func.keys()) - set(obj.columns)
279+
if len(cols) > 0:
280+
cols_sorted = list(safe_sort(list(cols)))
281+
raise SpecificationError(f"Column(s) {cols_sorted} do not exist")
282+
283+
# Can't use func.values(); wouldn't work for a Series
284+
if any(is_dict_like(v) for _, v in func.items()):
285+
# GH 15931 - deprecation of renaming keys
286+
raise SpecificationError("nested renamer is not supported")
287+
288+
results: Dict[Hashable, FrameOrSeriesUnion] = {}
289+
for name, how in func.items():
290+
colg = obj._gotitem(name, ndim=1)
291+
try:
292+
results[name] = colg.transform(how, 0, *args, **kwargs)
293+
except Exception as err:
294+
if str(err) in {
295+
"Function did not transform",
296+
"No transform functions were provided",
297+
}:
298+
raise err
299+
300+
# combine results
301+
if not results:
302+
raise ValueError("Transform function failed")
303+
return concat(results, axis=1)
304+
305+
def transform_str_or_callable(self, func) -> FrameOrSeriesUnion:
306+
"""
307+
Compute transform in the case of a string or callable func
308+
"""
309+
obj = self.obj
310+
args = self.args
311+
kwargs = self.kwargs
312+
313+
if isinstance(func, str):
314+
return obj._try_aggregate_string_function(func, *args, **kwargs)
315+
316+
if not args and not kwargs:
317+
f = obj._get_cython_func(func)
318+
if f:
319+
return getattr(obj, f)()
320+
321+
# Two possible ways to use a UDF - apply or call directly
322+
try:
323+
return obj.apply(func, args=args, **kwargs)
324+
except Exception:
325+
return func(obj, *args, **kwargs)
326+
201327
def agg_list_like(self, _axis: int) -> FrameOrSeriesUnion:
202328
"""
203329
Compute aggregation in the case of a list-like argument.
@@ -961,6 +1087,9 @@ def __init__(
9611087
def apply(self):
9621088
raise NotImplementedError
9631089

1090+
def transform(self):
1091+
raise NotImplementedError
1092+
9641093

9651094
class ResamplerWindowApply(Apply):
9661095
axis = 0
@@ -984,3 +1113,6 @@ def __init__(
9841113

9851114
def apply(self):
9861115
raise NotImplementedError
1116+
1117+
def transform(self):
1118+
raise NotImplementedError

pandas/core/frame.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@
139139
from pandas.core.aggregation import (
140140
reconstruct_func,
141141
relabel_result,
142-
transform,
143142
)
144143
from pandas.core.arraylike import OpsMixin
145144
from pandas.core.arrays import ExtensionArray
@@ -7761,7 +7760,10 @@ def aggregate(self, func=None, axis: Axis = 0, *args, **kwargs):
77617760
def transform(
77627761
self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
77637762
) -> DataFrame:
7764-
result = transform(self, func, axis, *args, **kwargs)
7763+
from pandas.core.apply import frame_apply
7764+
7765+
op = frame_apply(self, func=func, axis=axis, args=args, kwargs=kwargs)
7766+
result = op.transform()
77657767
assert isinstance(result, DataFrame)
77667768
return result
77677769

pandas/core/series.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@
9494
ops,
9595
)
9696
from pandas.core.accessor import CachedAccessor
97-
from pandas.core.aggregation import transform
9897
from pandas.core.apply import series_apply
9998
from pandas.core.arrays import ExtensionArray
10099
from pandas.core.arrays.categorical import CategoricalAccessor
@@ -4015,7 +4014,10 @@ def aggregate(self, func=None, axis=0, *args, **kwargs):
40154014
def transform(
40164015
self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
40174016
) -> FrameOrSeriesUnion:
4018-
return transform(self, func, axis, *args, **kwargs)
4017+
# Validate axis argument
4018+
self._get_axis_number(axis)
4019+
result = series_apply(self, func=func, args=args, kwargs=kwargs).transform()
4020+
return result
40194021

40204022
def apply(
40214023
self,

0 commit comments

Comments
 (0)