1
+ import abc
1
2
import inspect
3
+ from typing import TYPE_CHECKING , Iterator , Type
2
4
3
5
import numpy as np
4
6
13
15
)
14
16
from pandas .core .dtypes .generic import ABCSeries
15
17
18
+ if TYPE_CHECKING :
19
+ from pandas import DataFrame , Series , Index
20
+
16
21
17
22
def frame_apply (
18
- obj ,
23
+ obj : "DataFrame" ,
19
24
func ,
20
25
axis = 0 ,
21
- raw = False ,
26
+ raw : bool = False ,
22
27
result_type = None ,
23
- ignore_failures = False ,
28
+ ignore_failures : bool = False ,
24
29
args = None ,
25
30
kwds = None ,
26
31
):
27
32
""" construct and return a row or column based frame apply object """
28
33
29
34
axis = obj ._get_axis_number (axis )
30
35
if axis == 0 :
31
- klass = FrameRowApply
36
+ klass = FrameRowApply # type: Type[FrameApply]
32
37
elif axis == 1 :
33
38
klass = FrameColumnApply
34
39
@@ -43,8 +48,38 @@ def frame_apply(
43
48
)
44
49
45
50
46
- class FrameApply :
47
- def __init__ (self , obj , func , raw , result_type , ignore_failures , args , kwds ):
51
+ class FrameApply (metaclass = abc .ABCMeta ):
52
+
53
+ # ---------------------------------------------------------------
54
+ # Abstract Methods
55
+ axis : int
56
+
57
+ @property
58
+ @abc .abstractmethod
59
+ def result_index (self ) -> "Index" :
60
+ pass
61
+
62
+ @property
63
+ @abc .abstractmethod
64
+ def result_columns (self ) -> "Index" :
65
+ pass
66
+
67
+ @abc .abstractmethod
68
+ def series_generator (self ) -> Iterator ["Series" ]:
69
+ pass
70
+
71
+ # ---------------------------------------------------------------
72
+
73
+ def __init__ (
74
+ self ,
75
+ obj : "DataFrame" ,
76
+ func ,
77
+ raw : bool ,
78
+ result_type ,
79
+ ignore_failures : bool ,
80
+ args ,
81
+ kwds ,
82
+ ):
48
83
self .obj = obj
49
84
self .raw = raw
50
85
self .ignore_failures = ignore_failures
@@ -76,23 +111,23 @@ def f(x):
76
111
self .res_columns = None
77
112
78
113
@property
79
- def columns (self ):
114
+ def columns (self ) -> "Index" :
80
115
return self .obj .columns
81
116
82
117
@property
83
- def index (self ):
118
+ def index (self ) -> "Index" :
84
119
return self .obj .index
85
120
86
121
@cache_readonly
87
122
def values (self ):
88
123
return self .obj .values
89
124
90
125
@cache_readonly
91
- def dtypes (self ):
126
+ def dtypes (self ) -> "Series" :
92
127
return self .obj .dtypes
93
128
94
129
@property
95
- def agg_axis (self ):
130
+ def agg_axis (self ) -> "Index" :
96
131
return self .obj ._get_agg_axis (self .axis )
97
132
98
133
def get_result (self ):
@@ -127,7 +162,7 @@ def get_result(self):
127
162
128
163
# broadcasting
129
164
if self .result_type == "broadcast" :
130
- return self .apply_broadcast ()
165
+ return self .apply_broadcast (self . obj )
131
166
132
167
# one axis empty
133
168
elif not all (self .obj .shape ):
@@ -191,7 +226,7 @@ def apply_raw(self):
191
226
else :
192
227
return self .obj ._constructor_sliced (result , index = self .agg_axis )
193
228
194
- def apply_broadcast (self , target ) :
229
+ def apply_broadcast (self , target : "DataFrame" ) -> "DataFrame" :
195
230
result_values = np .empty_like (target .values )
196
231
197
232
# axis which we want to compare compliance
@@ -317,19 +352,19 @@ def wrap_results(self):
317
352
class FrameRowApply (FrameApply ):
318
353
axis = 0
319
354
320
- def apply_broadcast (self ) :
321
- return super ().apply_broadcast (self . obj )
355
+ def apply_broadcast (self , target : "DataFrame" ) -> "DataFrame" :
356
+ return super ().apply_broadcast (target )
322
357
323
358
@property
324
359
def series_generator (self ):
325
360
return (self .obj ._ixs (i , axis = 1 ) for i in range (len (self .columns )))
326
361
327
362
@property
328
- def result_index (self ):
363
+ def result_index (self ) -> "Index" :
329
364
return self .columns
330
365
331
366
@property
332
- def result_columns (self ):
367
+ def result_columns (self ) -> "Index" :
333
368
return self .index
334
369
335
370
def wrap_results_for_axis (self ):
@@ -351,8 +386,8 @@ def wrap_results_for_axis(self):
351
386
class FrameColumnApply (FrameApply ):
352
387
axis = 1
353
388
354
- def apply_broadcast (self ) :
355
- result = super ().apply_broadcast (self . obj .T )
389
+ def apply_broadcast (self , target : "DataFrame" ) -> "DataFrame" :
390
+ result = super ().apply_broadcast (target .T )
356
391
return result .T
357
392
358
393
@property
@@ -364,11 +399,11 @@ def series_generator(self):
364
399
)
365
400
366
401
@property
367
- def result_index (self ):
402
+ def result_index (self ) -> "Index" :
368
403
return self .index
369
404
370
405
@property
371
- def result_columns (self ):
406
+ def result_columns (self ) -> "Index" :
372
407
return self .columns
373
408
374
409
def wrap_results_for_axis (self ):
@@ -392,7 +427,7 @@ def wrap_results_for_axis(self):
392
427
393
428
return result
394
429
395
- def infer_to_same_shape (self ):
430
+ def infer_to_same_shape (self ) -> "DataFrame" :
396
431
""" infer the results to the same shape as the input object """
397
432
results = self .results
398
433
0 commit comments