Skip to content

REF: Move transform into apply #39957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 1 addition & 141 deletions pandas/core/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,19 @@
Sequence,
Tuple,
Union,
cast,
)

from pandas._typing import (
AggFuncType,
AggFuncTypeBase,
AggFuncTypeDict,
Axis,
FrameOrSeries,
FrameOrSeriesUnion,
)

from pandas.core.dtypes.common import (
is_dict_like,
is_list_like,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.algorithms import safe_sort
from pandas.core.base import SpecificationError
import pandas.core.common as com
from pandas.core.indexes.api import Index
Expand Down Expand Up @@ -405,134 +396,3 @@ def validate_func_kwargs(
no_arg_message = "Must provide 'func' or named aggregation **kwargs."
raise TypeError(no_arg_message)
return columns, func


def transform(
    obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs
) -> FrameOrSeriesUnion:
    """
    Transform a DataFrame or Series.

    Parameters
    ----------
    obj : DataFrame or Series
        Object to compute the transform on.
    func : string, function, list, or dictionary
        Function(s) to compute the transform with.
    axis : {0 or 'index', 1 or 'columns'}
        Axis along which the function is applied:

        * 0 or 'index': apply function to each column.
        * 1 or 'columns': apply function to each row.
    *args, **kwargs
        Forwarded to the transform function(s).

    Returns
    -------
    DataFrame or Series
        Result of applying ``func`` along the given axis of the
        Series or DataFrame.

    Raises
    ------
    ValueError
        If the transform function fails or does not transform. When the
        function itself raised, the original exception is attached as
        ``__cause__`` so the root failure stays visible.
    """
    is_series = obj.ndim == 1

    if obj._get_axis_number(axis) == 1:
        # Row-wise transform: transpose, transform column-wise, transpose back.
        assert not is_series
        return transform(obj.T, func, 0, *args, **kwargs).T

    if is_list_like(func) and not is_dict_like(func):
        func = cast(List[AggFuncTypeBase], func)
        # Convert func equivalent dict
        if is_series:
            # One entry per function, keyed by its name (or the item itself
            # when no callable name is available).
            func = {com.get_callable_name(v) or v: v for v in func}
        else:
            # Every column receives the complete list of functions.
            func = {col: func for col in obj}

    if is_dict_like(func):
        func = cast(AggFuncTypeDict, func)
        return transform_dict_like(obj, func, *args, **kwargs)

    # func is either str or callable
    func = cast(AggFuncTypeBase, func)
    try:
        result = transform_str_or_callable(obj, func, *args, **kwargs)
    except Exception as err:
        # Chain explicitly so the underlying failure is reported as the cause.
        raise ValueError("Transform function failed") from err

    is_pandas_obj = isinstance(result, (ABCSeries, ABCDataFrame))

    # Functions that transform may return empty Series/DataFrame
    # when the dtype is not appropriate
    if is_pandas_obj and result.empty and not obj.empty:
        raise ValueError("Transform function failed")
    # A transform must return a Series/DataFrame with the same index as the input.
    if not is_pandas_obj or not result.index.equals(obj.index):
        raise ValueError("Function did not transform")

    return result


def transform_dict_like(
    obj: FrameOrSeries,
    func: AggFuncTypeDict,
    *args,
    **kwargs,
):
    """
    Transform ``obj`` with a dict-like mapping of label -> function(s).

    Each entry is applied to the matching one-dimensional selection of
    ``obj`` and the successful results are concatenated along axis 1.
    Entries whose transform raises are dropped, except for the
    "did not transform" / "no functions" errors, which propagate.

    Raises
    ------
    ValueError
        If ``func`` is empty or every entry fails.
    SpecificationError
        If a key is not a column of a frame, or a value is itself dict-like.
    """
    from pandas.core.reshape.concat import concat

    if len(func) == 0:
        raise ValueError("No transform functions were provided")

    if obj.ndim != 1:
        # On a frame, every key must name an existing column.
        missing = set(func.keys()) - set(obj.columns)
        if len(missing) > 0:
            missing_sorted = list(safe_sort(list(missing)))
            raise SpecificationError(f"Column(s) {missing_sorted} do not exist")

    # Iterate items() rather than values(): func may be a Series.
    for _, spec in func.items():
        if is_dict_like(spec):
            # GH 15931 - deprecation of renaming keys
            raise SpecificationError("nested renamer is not supported")

    # Messages of errors that must reach the caller instead of being dropped.
    propagate = {
        "Function did not transform",
        "No transform functions were provided",
    }

    collected: Dict[Hashable, FrameOrSeriesUnion] = {}
    for label, how in func.items():
        selection = obj._gotitem(label, ndim=1)
        try:
            collected[label] = transform(selection, how, 0, *args, **kwargs)
        except Exception as err:
            if str(err) in propagate:
                raise err

    # Nothing succeeded: report an overall failure.
    if not collected:
        raise ValueError("Transform function failed")
    return concat(collected, axis=1)


def transform_str_or_callable(
    obj: FrameOrSeries, func: AggFuncTypeBase, *args, **kwargs
) -> FrameOrSeriesUnion:
    """
    Transform ``obj`` with a single string or callable.

    A string is dispatched through ``obj._try_aggregate_string_function``.
    A callable called without extra arguments is first looked up via
    ``obj._get_cython_func``; otherwise it is tried through ``obj.apply``
    and, if that raises, called directly on ``obj``.
    """
    if isinstance(func, str):
        return obj._try_aggregate_string_function(func, *args, **kwargs)

    has_extra_arguments = bool(args) or bool(kwargs)
    if not has_extra_arguments:
        cython_name = obj._get_cython_func(func)
        if cython_name:
            return getattr(obj, cython_name)()

    # Two possible ways to use a UDF - apply or call directly
    try:
        outcome = obj.apply(func, args=args, **kwargs)
    except Exception:
        outcome = func(obj, *args, **kwargs)
    return outcome
133 changes: 133 additions & 0 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Any,
Dict,
Hashable,
Iterator,
List,
Optional,
Expand Down Expand Up @@ -151,6 +152,7 @@ def f(x):
else:
f = func

self.orig_f: AggFuncType = func
self.f: AggFuncType = f

@property
Expand Down Expand Up @@ -197,6 +199,131 @@ def agg(self) -> Optional[FrameOrSeriesUnion]:
# caller can react
return None

def transform(self) -> FrameOrSeriesUnion:
    """
    Transform a DataFrame or Series.

    The object, the function as originally passed (``orig_f``), the axis
    and the extra positional/keyword arguments are read from this instance.

    Returns
    -------
    DataFrame or Series
        Result of applying ``func`` along the given axis of the
        Series or DataFrame.

    Raises
    ------
    ValueError
        If the transform function fails or does not transform. When the
        function itself raised, the original exception is attached as
        ``__cause__`` so the root failure stays visible.
    """
    obj = self.obj
    func = self.orig_f
    axis = self.axis
    args = self.args
    kwargs = self.kwargs

    is_series = obj.ndim == 1

    if obj._get_axis_number(axis) == 1:
        # Row-wise transform: transpose, transform column-wise, transpose back.
        assert not is_series
        return obj.T.transform(func, 0, *args, **kwargs).T

    if is_list_like(func) and not is_dict_like(func):
        func = cast(List[AggFuncTypeBase], func)
        # Convert func equivalent dict
        if is_series:
            # One entry per function, keyed by its name (or the item itself
            # when no callable name is available).
            func = {com.get_callable_name(v) or v: v for v in func}
        else:
            # Every column receives the complete list of functions.
            func = {col: func for col in obj}

    if is_dict_like(func):
        func = cast(AggFuncTypeDict, func)
        return self.transform_dict_like(func)

    # func is either str or callable
    func = cast(AggFuncTypeBase, func)
    try:
        result = self.transform_str_or_callable(func)
    except Exception as err:
        # Chain explicitly so the underlying failure is reported as the cause.
        raise ValueError("Transform function failed") from err

    is_pandas_obj = isinstance(result, (ABCSeries, ABCDataFrame))

    # Functions that transform may return empty Series/DataFrame
    # when the dtype is not appropriate
    if is_pandas_obj and result.empty and not obj.empty:
        raise ValueError("Transform function failed")
    # A transform must return a Series/DataFrame with the same index as the input.
    if not is_pandas_obj or not result.index.equals(obj.index):
        raise ValueError("Function did not transform")

    return result

def transform_dict_like(self, func):
    """
    Transform ``self.obj`` with a dict-like mapping of label -> function(s).

    Each entry is applied to the matching one-dimensional selection of the
    object and the successful results are concatenated along axis 1.
    Entries whose transform raises are dropped, except for the
    "did not transform" / "no functions" errors, which propagate.

    Raises
    ------
    ValueError
        If ``func`` is empty or every entry fails.
    SpecificationError
        If a key is not a column of a frame, or a value is itself dict-like.
    """
    from pandas.core.reshape.concat import concat

    target = self.obj
    extra_args = self.args
    extra_kwargs = self.kwargs

    if len(func) == 0:
        raise ValueError("No transform functions were provided")

    if target.ndim != 1:
        # On a frame, every key must name an existing column.
        missing = set(func.keys()) - set(target.columns)
        if len(missing) > 0:
            missing_sorted = list(safe_sort(list(missing)))
            raise SpecificationError(f"Column(s) {missing_sorted} do not exist")

    # Iterate items() rather than values(): func may be a Series.
    for _, spec in func.items():
        if is_dict_like(spec):
            # GH 15931 - deprecation of renaming keys
            raise SpecificationError("nested renamer is not supported")

    # Messages of errors that must reach the caller instead of being dropped.
    propagate = {
        "Function did not transform",
        "No transform functions were provided",
    }

    collected: Dict[Hashable, FrameOrSeriesUnion] = {}
    for label, how in func.items():
        selection = target._gotitem(label, ndim=1)
        try:
            collected[label] = selection.transform(
                how, 0, *extra_args, **extra_kwargs
            )
        except Exception as err:
            if str(err) in propagate:
                raise err

    # Nothing succeeded: report an overall failure.
    if not collected:
        raise ValueError("Transform function failed")
    return concat(collected, axis=1)

def transform_str_or_callable(self, func) -> FrameOrSeriesUnion:
    """
    Transform ``self.obj`` with a single string or callable.

    A string is dispatched through ``_try_aggregate_string_function`` on the
    object. A callable used without extra arguments is first looked up via
    ``_get_cython_func``; otherwise it is tried through ``apply`` on the
    object and, if that raises, called directly on the object.
    """
    target = self.obj
    extra_args = self.args
    extra_kwargs = self.kwargs

    if isinstance(func, str):
        return target._try_aggregate_string_function(
            func, *extra_args, **extra_kwargs
        )

    has_extra_arguments = bool(extra_args) or bool(extra_kwargs)
    if not has_extra_arguments:
        cython_name = target._get_cython_func(func)
        if cython_name:
            return getattr(target, cython_name)()

    # Two possible ways to use a UDF - apply or call directly
    try:
        outcome = target.apply(func, args=extra_args, **extra_kwargs)
    except Exception:
        outcome = func(target, *extra_args, **extra_kwargs)
    return outcome

def agg_list_like(self, _axis: int) -> FrameOrSeriesUnion:
"""
Compute aggregation in the case of a list-like argument.
Expand Down Expand Up @@ -901,6 +1028,9 @@ def __init__(
def apply(self):
    """Not supported here; always raises ``NotImplementedError``."""
    raise NotImplementedError

def transform(self):
    """Not supported here; always raises ``NotImplementedError``."""
    raise NotImplementedError


class ResamplerWindowApply(Apply):
axis = 0
Expand All @@ -924,3 +1054,6 @@ def __init__(

def apply(self):
    """Not supported here; always raises ``NotImplementedError``."""
    raise NotImplementedError

def transform(self):
    """Not supported here; always raises ``NotImplementedError``."""
    raise NotImplementedError
6 changes: 4 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@
from pandas.core.aggregation import (
reconstruct_func,
relabel_result,
transform,
)
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import ExtensionArray
Expand Down Expand Up @@ -7786,7 +7785,10 @@ def _aggregate(self, arg, axis: Axis = 0, *args, **kwargs):
def transform(
    self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
) -> DataFrame:
    # NOTE(review): this span appears to interleave both sides of a diff. The
    # direct ``transform(...)`` call looks like the removed line and the
    # ``frame_apply`` lines its replacement; confirm against the real file.
    result = transform(self, func, axis, *args, **kwargs)
    # Imported inside the method rather than at module level; presumably to
    # avoid a circular import (TODO confirm).
    from pandas.core.apply import frame_apply

    # Build the apply helper with this frame and the caller's arguments, then
    # let it run the transform.
    op = frame_apply(self, func=func, axis=axis, args=args, kwargs=kwargs)
    result = op.transform()
    # The declared return type is DataFrame; fail loudly on anything else.
    assert isinstance(result, DataFrame)
    return result

Expand Down
6 changes: 4 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
ops,
)
from pandas.core.accessor import CachedAccessor
from pandas.core.aggregation import transform
from pandas.core.apply import series_apply
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.categorical import CategoricalAccessor
Expand Down Expand Up @@ -4035,7 +4034,10 @@ def aggregate(self, func=None, axis=0, *args, **kwargs):
def transform(
    self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
) -> FrameOrSeriesUnion:
    # NOTE(review): this span appears to interleave both sides of a diff. The
    # early ``return transform(...)`` looks like the removed line; as rendered,
    # every statement after it is unreachable. Confirm against the real file.
    return transform(self, func, axis, *args, **kwargs)
    # Validate axis argument
    self._get_axis_number(axis)
    # ``axis`` is only validated above; it is not forwarded to ``series_apply``.
    result = series_apply(self, func=func, args=args, kwargs=kwargs).transform()
    return result

def apply(
self,
Expand Down