Skip to content

Commit 9dd17f2

Browse files
authored
REF: Move aggregation into apply (#38867)
1 parent 8d923c9 commit 9dd17f2

File tree

2 files changed

+90
-13
lines changed

2 files changed

+90
-13
lines changed

pandas/core/apply.py

+74-4
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,18 @@
11
import abc
22
import inspect
3-
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Type
3+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Type, cast
44

55
import numpy as np
66

77
from pandas._config import option_context
88

9-
from pandas._typing import AggFuncType, Axis, FrameOrSeriesUnion
9+
from pandas._typing import (
10+
AggFuncType,
11+
AggFuncTypeBase,
12+
AggFuncTypeDict,
13+
Axis,
14+
FrameOrSeriesUnion,
15+
)
1016
from pandas.util._decorators import cache_readonly
1117

1218
from pandas.core.dtypes.common import (
@@ -17,6 +23,7 @@
1723
)
1824
from pandas.core.dtypes.generic import ABCSeries
1925

26+
from pandas.core.aggregation import agg_dict_like, agg_list_like
2027
from pandas.core.construction import create_series_with_explicit_dtype
2128

2229
if TYPE_CHECKING:
@@ -27,6 +34,7 @@
2734

2835
def frame_apply(
2936
obj: "DataFrame",
37+
how: str,
3038
func: AggFuncType,
3139
axis: Axis = 0,
3240
raw: bool = False,
@@ -44,6 +52,7 @@ def frame_apply(
4452

4553
return klass(
4654
obj,
55+
how,
4756
func,
4857
raw=raw,
4958
result_type=result_type,
@@ -84,13 +93,16 @@ def wrap_results_for_axis(
8493
def __init__(
8594
self,
8695
obj: "DataFrame",
96+
how: str,
8797
func,
8898
raw: bool,
8999
result_type: Optional[str],
90100
args,
91101
kwds,
92102
):
103+
assert how in ("apply", "agg")
93104
self.obj = obj
105+
self.how = how
94106
self.raw = raw
95107
self.args = args or ()
96108
self.kwds = kwds or {}
@@ -104,15 +116,19 @@ def __init__(
104116
self.result_type = result_type
105117

106118
# curry if needed
107-
if (kwds or args) and not isinstance(func, (np.ufunc, str)):
119+
if (
120+
(kwds or args)
121+
and not isinstance(func, (np.ufunc, str))
122+
and not is_list_like(func)
123+
):
108124

109125
def f(x):
110126
return func(x, *args, **kwds)
111127

112128
else:
113129
f = func
114130

115-
self.f = f
131+
self.f: AggFuncType = f
116132

117133
@property
118134
def res_columns(self) -> "Index":
@@ -139,6 +155,54 @@ def agg_axis(self) -> "Index":
139155
return self.obj._get_agg_axis(self.axis)
140156

141157
def get_result(self):
158+
if self.how == "apply":
159+
return self.apply()
160+
else:
161+
return self.agg()
162+
163+
def agg(self) -> Tuple[Optional[FrameOrSeriesUnion], Optional[bool]]:
164+
"""
165+
Provide an implementation for the aggregators.
166+
167+
Returns
168+
-------
169+
tuple of result, how.
170+
171+
Notes
172+
-----
173+
how is True when the caller must post-process the result, or
174+
None if not required.
175+
"""
176+
obj = self.obj
177+
arg = self.f
178+
args = self.args
179+
kwargs = self.kwds
180+
181+
_axis = kwargs.pop("_axis", None)
182+
if _axis is None:
183+
_axis = getattr(obj, "axis", 0)
184+
185+
if isinstance(arg, str):
186+
return obj._try_aggregate_string_function(arg, *args, **kwargs), None
187+
elif is_dict_like(arg):
188+
arg = cast(AggFuncTypeDict, arg)
189+
return agg_dict_like(obj, arg, _axis), True
190+
elif is_list_like(arg):
191+
# we require a list, but not a 'str'
192+
arg = cast(List[AggFuncTypeBase], arg)
193+
return agg_list_like(obj, arg, _axis=_axis), None
194+
else:
195+
result = None
196+
197+
if callable(arg):
198+
f = obj._get_cython_func(arg)
199+
if f and not args and not kwargs:
200+
return getattr(obj, f)(), None
201+
202+
# caller can react
203+
return result, True
204+
205+
def apply(self) -> FrameOrSeriesUnion:
142206
""" compute the results """
143207
# dispatch to agg
144208
if is_list_like(self.f) or is_dict_like(self.f):
@@ -191,6 +255,8 @@ def apply_empty_result(self):
191255
we will try to apply the function to an empty
192256
series in order to see if this is a reduction function
193257
"""
258+
assert callable(self.f)
259+
194260
# we are not asked to reduce or infer reduction
195261
# so just return a copy of the existing object
196262
if self.result_type not in ["reduce", None]:
@@ -246,6 +312,8 @@ def wrapper(*args, **kwargs):
246312
return self.obj._constructor_sliced(result, index=self.agg_axis)
247313

248314
def apply_broadcast(self, target: "DataFrame") -> "DataFrame":
315+
assert callable(self.f)
316+
249317
result_values = np.empty_like(target.values)
250318

251319
# axis which we want to compare compliance
@@ -279,6 +347,8 @@ def apply_standard(self):
279347
return self.wrap_results(results, res_index)
280348

281349
def apply_series_generator(self) -> Tuple[ResType, "Index"]:
350+
assert callable(self.f)
351+
282352
series_gen = self.series_generator
283353
res_index = self.result_index
284354

pandas/core/frame.py

+16-9
Original file line number | Diff line number | Diff line change
@@ -121,12 +121,7 @@
121121

122122
from pandas.core import algorithms, common as com, generic, nanops, ops
123123
from pandas.core.accessor import CachedAccessor
124-
from pandas.core.aggregation import (
125-
aggregate,
126-
reconstruct_func,
127-
relabel_result,
128-
transform,
129-
)
124+
from pandas.core.aggregation import reconstruct_func, relabel_result, transform
130125
from pandas.core.arraylike import OpsMixin
131126
from pandas.core.arrays import ExtensionArray
132127
from pandas.core.arrays.sparse import SparseFrameAccessor
@@ -7623,13 +7618,24 @@ def aggregate(self, func=None, axis: Axis = 0, *args, **kwargs):
76237618
return result
76247619

76257620
def _aggregate(self, arg, axis: Axis = 0, *args, **kwargs):
7621+
from pandas.core.apply import frame_apply
7622+
7623+
op = frame_apply(
7624+
self if axis == 0 else self.T,
7625+
how="agg",
7626+
func=arg,
7627+
axis=0,
7628+
args=args,
7629+
kwds=kwargs,
7630+
)
7631+
result, how = op.get_result()
7632+
76267633
if axis == 1:
76277634
# NDFrame.aggregate returns a tuple, and we need to transpose
76287635
# only result
7629-
result, how = aggregate(self.T, arg, *args, **kwargs)
76307636
result = result.T if result is not None else result
7631-
return result, how
7632-
return aggregate(self, arg, *args, **kwargs)
7637+
7638+
return result, how
76337639

76347640
agg = aggregate
76357641

@@ -7789,6 +7795,7 @@ def apply(
77897795

77907796
op = frame_apply(
77917797
self,
7798+
how="apply",
77927799
func=func,
77937800
axis=axis,
77947801
raw=raw,

0 commit comments

Comments
 (0)