diff --git a/pandas/core/aggregation.py b/pandas/core/aggregation.py index 744a1ffa5fea1..0a4e03fa97402 100644 --- a/pandas/core/aggregation.py +++ b/pandas/core/aggregation.py @@ -20,28 +20,19 @@ Sequence, Tuple, Union, - cast, ) from pandas._typing import ( AggFuncType, - AggFuncTypeBase, - AggFuncTypeDict, - Axis, FrameOrSeries, - FrameOrSeriesUnion, ) from pandas.core.dtypes.common import ( is_dict_like, is_list_like, ) -from pandas.core.dtypes.generic import ( - ABCDataFrame, - ABCSeries, -) +from pandas.core.dtypes.generic import ABCSeries -from pandas.core.algorithms import safe_sort from pandas.core.base import SpecificationError import pandas.core.common as com from pandas.core.indexes.api import Index @@ -405,134 +396,3 @@ def validate_func_kwargs( no_arg_message = "Must provide 'func' or named aggregation **kwargs." raise TypeError(no_arg_message) return columns, func - - -def transform( - obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs -) -> FrameOrSeriesUnion: - """ - Transform a DataFrame or Series - - Parameters - ---------- - obj : DataFrame or Series - Object to compute the transform on. - func : string, function, list, or dictionary - Function(s) to compute the transform with. - axis : {0 or 'index', 1 or 'columns'} - Axis along which the function is applied: - - * 0 or 'index': apply function to each column. - * 1 or 'columns': apply function to each row. - - Returns - ------- - DataFrame or Series - Result of applying ``func`` along the given axis of the - Series or DataFrame. - - Raises - ------ - ValueError - If the transform function fails or does not transform. - """ - is_series = obj.ndim == 1 - - if obj._get_axis_number(axis) == 1: - assert not is_series - return transform(obj.T, func, 0, *args, **kwargs).T - - if is_list_like(func) and not is_dict_like(func): - func = cast(List[AggFuncTypeBase], func) - # Convert func equivalent dict - if is_series: - func = {com.get_callable_name(v) or v: v for v in func} - else: - func = {col: func for col in obj} - - if is_dict_like(func): - func = cast(AggFuncTypeDict, func) - return transform_dict_like(obj, func, *args, **kwargs) - - # func is either str or callable - func = cast(AggFuncTypeBase, func) - try: - result = transform_str_or_callable(obj, func, *args, **kwargs) - except Exception: - raise ValueError("Transform function failed") - - # Functions that transform may return empty Series/DataFrame - # when the dtype is not appropriate - if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty and not obj.empty: - raise ValueError("Transform function failed") - if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals( - obj.index - ): - raise ValueError("Function did not transform") - - return result - - -def transform_dict_like( - obj: FrameOrSeries, - func: AggFuncTypeDict, - *args, - **kwargs, -): - """ - Compute transform in the case of a dict-like func - """ - from pandas.core.reshape.concat import concat - - if len(func) == 0: - raise ValueError("No transform functions were provided") - - if obj.ndim != 1: - # Check for missing columns on a frame - cols = set(func.keys()) - set(obj.columns) - if len(cols) > 0: - cols_sorted = list(safe_sort(list(cols))) - raise SpecificationError(f"Column(s) {cols_sorted} do not exist") - - # Can't use func.values(); wouldn't work for a Series - if any(is_dict_like(v) for _, v in func.items()): - # GH 15931 - deprecation of renaming keys - raise SpecificationError("nested renamer is not supported") - - results: Dict[Hashable, FrameOrSeriesUnion] = {} - for name, how in func.items(): - colg = obj._gotitem(name, ndim=1) - try: - results[name] = transform(colg, how, 0, *args, **kwargs) - except Exception as err: - if str(err) in { - "Function did not transform", - "No transform functions were provided", - }: - raise err - - # combine results - if not results: - raise ValueError("Transform function failed") - return concat(results, axis=1) - - -def transform_str_or_callable( - obj: FrameOrSeries, func: AggFuncTypeBase, *args, **kwargs -) -> FrameOrSeriesUnion: - """ - Compute transform in the case of a string or callable func - """ - if isinstance(func, str): - return obj._try_aggregate_string_function(func, *args, **kwargs) - - if not args and not kwargs: - f = obj._get_cython_func(func) - if f: - return getattr(obj, f)() - - # Two possible ways to use a UDF - apply or call directly - try: - return obj.apply(func, args=args, **kwargs) - except Exception: - return func(obj, *args, **kwargs) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index b41c432dff172..e43e9dadda033 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -6,6 +6,7 @@ TYPE_CHECKING, Any, Dict, + Hashable, Iterator, List, Optional, @@ -151,6 +152,7 @@ def f(x): else: f = func + self.orig_f: AggFuncType = func self.f: AggFuncType = f @property @@ -197,6 +199,131 @@ def agg(self) -> Optional[FrameOrSeriesUnion]: # caller can react return None + def transform(self) -> FrameOrSeriesUnion: + """ + Transform a DataFrame or Series. + + Returns + ------- + DataFrame or Series + Result of applying ``func`` along the given axis of the + Series or DataFrame. + + Raises + ------ + ValueError + If the transform function fails or does not transform. + """ + obj = self.obj + func = self.orig_f + axis = self.axis + args = self.args + kwargs = self.kwargs + + is_series = obj.ndim == 1 + + if obj._get_axis_number(axis) == 1: + assert not is_series + return obj.T.transform(func, 0, *args, **kwargs).T + + if is_list_like(func) and not is_dict_like(func): + func = cast(List[AggFuncTypeBase], func) + # Convert func equivalent dict + if is_series: + func = {com.get_callable_name(v) or v: v for v in func} + else: + func = {col: func for col in obj} + + if is_dict_like(func): + func = cast(AggFuncTypeDict, func) + return self.transform_dict_like(func) + + # func is either str or callable + func = cast(AggFuncTypeBase, func) + try: + result = self.transform_str_or_callable(func) + except Exception: + raise ValueError("Transform function failed") + + # Functions that transform may return empty Series/DataFrame + # when the dtype is not appropriate + if ( + isinstance(result, (ABCSeries, ABCDataFrame)) + and result.empty + and not obj.empty + ): + raise ValueError("Transform function failed") + if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals( + obj.index + ): + raise ValueError("Function did not transform") + + return result + + def transform_dict_like(self, func): + """ + Compute transform in the case of a dict-like func + """ + from pandas.core.reshape.concat import concat + + obj = self.obj + args = self.args + kwargs = self.kwargs + + if len(func) == 0: + raise ValueError("No transform functions were provided") + + if obj.ndim != 1: + # Check for missing columns on a frame + cols = set(func.keys()) - set(obj.columns) + if len(cols) > 0: + cols_sorted = list(safe_sort(list(cols))) + raise SpecificationError(f"Column(s) {cols_sorted} do not exist") + + # Can't use func.values(); wouldn't work for a Series + if any(is_dict_like(v) for _, v in func.items()): + # GH 15931 - deprecation of renaming keys + raise SpecificationError("nested renamer is not supported") + + results: Dict[Hashable, FrameOrSeriesUnion] = {} + for name, how in func.items(): + colg = obj._gotitem(name, ndim=1) + try: + results[name] = colg.transform(how, 0, *args, **kwargs) + except Exception as err: + if str(err) in { + "Function did not transform", + "No transform functions were provided", + }: + raise err + + # combine results + if not results: + raise ValueError("Transform function failed") + return concat(results, axis=1) + + def transform_str_or_callable(self, func) -> FrameOrSeriesUnion: + """ + Compute transform in the case of a string or callable func + """ + obj = self.obj + args = self.args + kwargs = self.kwargs + + if isinstance(func, str): + return obj._try_aggregate_string_function(func, *args, **kwargs) + + if not args and not kwargs: + f = obj._get_cython_func(func) + if f: + return getattr(obj, f)() + + # Two possible ways to use a UDF - apply or call directly + try: + return obj.apply(func, args=args, **kwargs) + except Exception: + return func(obj, *args, **kwargs) + def agg_list_like(self, _axis: int) -> FrameOrSeriesUnion: """ Compute aggregation in the case of a list-like argument. @@ -901,6 +1028,9 @@ def __init__( def apply(self): raise NotImplementedError + def transform(self): + raise NotImplementedError + class ResamplerWindowApply(Apply): axis = 0 @@ -924,3 +1054,6 @@ def __init__( def apply(self): raise NotImplementedError + + def transform(self): + raise NotImplementedError diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 3fe330f659513..070f3dae7ae1c 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -139,7 +139,6 @@ from pandas.core.aggregation import ( reconstruct_func, relabel_result, - transform, ) from pandas.core.arraylike import OpsMixin from pandas.core.arrays import ExtensionArray @@ -7786,7 +7785,10 @@ def _aggregate(self, arg, axis: Axis = 0, *args, **kwargs): def transform( self, func: AggFuncType, axis: Axis = 0, *args, **kwargs ) -> DataFrame: - result = transform(self, func, axis, *args, **kwargs) + from pandas.core.apply import frame_apply + + op = frame_apply(self, func=func, axis=axis, args=args, kwargs=kwargs) + result = op.transform() assert isinstance(result, DataFrame) return result diff --git a/pandas/core/series.py b/pandas/core/series.py index cbb66918a661b..c2c0f0384ed71 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -94,7 +94,6 @@ ops, ) from pandas.core.accessor import CachedAccessor -from pandas.core.aggregation import transform from pandas.core.apply import series_apply from pandas.core.arrays import ExtensionArray from pandas.core.arrays.categorical import CategoricalAccessor @@ -4035,7 +4034,10 @@ def aggregate(self, func=None, axis=0, *args, **kwargs): def transform( self, func: AggFuncType, axis: Axis = 0, *args, **kwargs ) -> FrameOrSeriesUnion: - return transform(self, func, axis, *args, **kwargs) + # Validate axis argument + self._get_axis_number(axis) + result = series_apply(self, func=func, args=args, kwargs=kwargs).transform() + return result def apply( self,