1
1
import abc
2
2
import inspect
3
- from typing import TYPE_CHECKING , Any , Dict , Iterator , Optional , Tuple , Type
3
+ from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional , Tuple , Type , cast
4
4
5
5
import numpy as np
6
6
7
7
from pandas ._config import option_context
8
8
9
- from pandas ._typing import AggFuncType , Axis , FrameOrSeriesUnion
9
+ from pandas ._typing import (
10
+ AggFuncType ,
11
+ AggFuncTypeBase ,
12
+ AggFuncTypeDict ,
13
+ Axis ,
14
+ FrameOrSeriesUnion ,
15
+ )
10
16
from pandas .util ._decorators import cache_readonly
11
17
12
18
from pandas .core .dtypes .common import (
17
23
)
18
24
from pandas .core .dtypes .generic import ABCSeries
19
25
26
+ from pandas .core .aggregation import agg_dict_like , agg_list_like
20
27
from pandas .core .construction import create_series_with_explicit_dtype
21
28
22
29
if TYPE_CHECKING :
27
34
28
35
def frame_apply (
29
36
obj : "DataFrame" ,
37
+ how : str ,
30
38
func : AggFuncType ,
31
39
axis : Axis = 0 ,
32
40
raw : bool = False ,
@@ -44,6 +52,7 @@ def frame_apply(
44
52
45
53
return klass (
46
54
obj ,
55
+ how ,
47
56
func ,
48
57
raw = raw ,
49
58
result_type = result_type ,
@@ -84,13 +93,16 @@ def wrap_results_for_axis(
84
93
def __init__ (
85
94
self ,
86
95
obj : "DataFrame" ,
96
+ how : str ,
87
97
func ,
88
98
raw : bool ,
89
99
result_type : Optional [str ],
90
100
args ,
91
101
kwds ,
92
102
):
103
+ assert how in ("apply" , "agg" )
93
104
self .obj = obj
105
+ self .how = how
94
106
self .raw = raw
95
107
self .args = args or ()
96
108
self .kwds = kwds or {}
@@ -104,15 +116,19 @@ def __init__(
104
116
self .result_type = result_type
105
117
106
118
# curry if needed
107
- if (kwds or args ) and not isinstance (func , (np .ufunc , str )):
119
+ if (
120
+ (kwds or args )
121
+ and not isinstance (func , (np .ufunc , str ))
122
+ and not is_list_like (func )
123
+ ):
108
124
109
125
def f (x ):
110
126
return func (x , * args , ** kwds )
111
127
112
128
else :
113
129
f = func
114
130
115
- self .f = f
131
+ self .f : AggFuncType = f
116
132
117
133
@property
118
134
def res_columns (self ) -> "Index" :
@@ -139,6 +155,54 @@ def agg_axis(self) -> "Index":
139
155
return self .obj ._get_agg_axis (self .axis )
140
156
141
157
def get_result(self):
    """
    Compute the result for the mode this object was created with.

    Returns
    -------
    The value of ``self.apply()`` when ``self.how`` is "apply";
    otherwise the value of ``self.agg()``.
    """
    # Anything other than "apply" is routed to agg, exactly as the
    # if/else form does.
    if self.how != "apply":
        return self.agg()
    return self.apply()
162
+
163
def agg(self) -> Tuple[Optional[FrameOrSeriesUnion], Optional[bool]]:
    """
    Provide an implementation for the aggregators.

    Dispatches on the type of ``self.f``:

    * str: forwarded to ``obj._try_aggregate_string_function`` together
      with ``self.args`` and the keyword arguments (minus ``_axis``).
    * dict-like: handled by ``agg_dict_like``.
    * list-like: handled by ``agg_list_like``.
    * callable: if ``obj._get_cython_func`` returns a method name and no
      extra positional or keyword arguments were given, that method is
      called on ``obj`` with no arguments.

    Returns
    -------
    tuple of (result, how)
        ``result`` is the aggregation result, or None when none of the
        cases above produced one.

    Notes
    -----
    ``how`` is True for the dict-like case and for the fall-through case
    (where ``result`` is None, so the caller can react); it is None in
    every other case.
    """
    obj = self.obj
    arg = self.f
    args = self.args
    # Work on a shallow copy: popping "_axis" straight from self.kwds
    # would mutate instance state (and the dict handed to __init__),
    # making a second call to agg() behave differently from the first.
    kwargs = dict(self.kwds)

    _axis = kwargs.pop("_axis", None)
    if _axis is None:
        _axis = getattr(obj, "axis", 0)

    if isinstance(arg, str):
        return obj._try_aggregate_string_function(arg, *args, **kwargs), None
    if is_dict_like(arg):
        arg = cast(AggFuncTypeDict, arg)
        return agg_dict_like(obj, arg, _axis), True
    if is_list_like(arg):
        # we require a list, but not a 'str'
        arg = cast(List[AggFuncTypeBase], arg)
        return agg_list_like(obj, arg, _axis=_axis), None

    if callable(arg):
        f = obj._get_cython_func(arg)
        # the shortcut is only taken when there is nothing extra to pass
        if f and not args and not kwargs:
            return getattr(obj, f)(), None

    # caller can react
    return None, True
204
+
205
+ def apply (self ) -> FrameOrSeriesUnion :
142
206
""" compute the results """
143
207
# dispatch to agg
144
208
if is_list_like (self .f ) or is_dict_like (self .f ):
@@ -191,6 +255,8 @@ def apply_empty_result(self):
191
255
we will try to apply the function to an empty
192
256
series in order to see if this is a reduction function
193
257
"""
258
+ assert callable (self .f )
259
+
194
260
# we are not asked to reduce or infer reduction
195
261
# so just return a copy of the existing object
196
262
if self .result_type not in ["reduce" , None ]:
@@ -246,6 +312,8 @@ def wrapper(*args, **kwargs):
246
312
return self .obj ._constructor_sliced (result , index = self .agg_axis )
247
313
248
314
def apply_broadcast (self , target : "DataFrame" ) -> "DataFrame" :
315
+ assert callable (self .f )
316
+
249
317
result_values = np .empty_like (target .values )
250
318
251
319
# axis which we want to compare compliance
@@ -279,6 +347,8 @@ def apply_standard(self):
279
347
return self .wrap_results (results , res_index )
280
348
281
349
def apply_series_generator (self ) -> Tuple [ResType , "Index" ]:
350
+ assert callable (self .f )
351
+
282
352
series_gen = self .series_generator
283
353
res_index = self .result_index
284
354
0 commit comments