Skip to content

Commit 57e1b34

Browse files
jbrockmendeljreback
authored andcommitted
CLN: annotations in core.apply (#29477)
1 parent 808f482 commit 57e1b34

File tree

1 file changed

+56
-21
lines changed

1 file changed

+56
-21
lines changed

pandas/core/apply.py

+56-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import abc
12
import inspect
3+
from typing import TYPE_CHECKING, Iterator, Type
24

35
import numpy as np
46

@@ -13,22 +15,25 @@
1315
)
1416
from pandas.core.dtypes.generic import ABCSeries
1517

18+
if TYPE_CHECKING:
19+
from pandas import DataFrame, Series, Index
20+
1621

1722
def frame_apply(
18-
obj,
23+
obj: "DataFrame",
1924
func,
2025
axis=0,
21-
raw=False,
26+
raw: bool = False,
2227
result_type=None,
23-
ignore_failures=False,
28+
ignore_failures: bool = False,
2429
args=None,
2530
kwds=None,
2631
):
2732
""" construct and return a row or column based frame apply object """
2833

2934
axis = obj._get_axis_number(axis)
3035
if axis == 0:
31-
klass = FrameRowApply
36+
klass = FrameRowApply # type: Type[FrameApply]
3237
elif axis == 1:
3338
klass = FrameColumnApply
3439

@@ -43,8 +48,38 @@ def frame_apply(
4348
)
4449

4550

46-
class FrameApply:
47-
def __init__(self, obj, func, raw, result_type, ignore_failures, args, kwds):
51+
class FrameApply(metaclass=abc.ABCMeta):
52+
53+
# ---------------------------------------------------------------
54+
# Abstract Methods
55+
axis: int
56+
57+
@property
58+
@abc.abstractmethod
59+
def result_index(self) -> "Index":
60+
pass
61+
62+
@property
63+
@abc.abstractmethod
64+
def result_columns(self) -> "Index":
65+
pass
66+
67+
@abc.abstractmethod
68+
def series_generator(self) -> Iterator["Series"]:
69+
pass
70+
71+
# ---------------------------------------------------------------
72+
73+
def __init__(
74+
self,
75+
obj: "DataFrame",
76+
func,
77+
raw: bool,
78+
result_type,
79+
ignore_failures: bool,
80+
args,
81+
kwds,
82+
):
4883
self.obj = obj
4984
self.raw = raw
5085
self.ignore_failures = ignore_failures
@@ -76,23 +111,23 @@ def f(x):
76111
self.res_columns = None
77112

78113
@property
79-
def columns(self):
114+
def columns(self) -> "Index":
80115
return self.obj.columns
81116

82117
@property
83-
def index(self):
118+
def index(self) -> "Index":
84119
return self.obj.index
85120

86121
@cache_readonly
87122
def values(self):
88123
return self.obj.values
89124

90125
@cache_readonly
91-
def dtypes(self):
126+
def dtypes(self) -> "Series":
92127
return self.obj.dtypes
93128

94129
@property
95-
def agg_axis(self):
130+
def agg_axis(self) -> "Index":
96131
return self.obj._get_agg_axis(self.axis)
97132

98133
def get_result(self):
@@ -127,7 +162,7 @@ def get_result(self):
127162

128163
# broadcasting
129164
if self.result_type == "broadcast":
130-
return self.apply_broadcast()
165+
return self.apply_broadcast(self.obj)
131166

132167
# one axis empty
133168
elif not all(self.obj.shape):
@@ -191,7 +226,7 @@ def apply_raw(self):
191226
else:
192227
return self.obj._constructor_sliced(result, index=self.agg_axis)
193228

194-
def apply_broadcast(self, target):
229+
def apply_broadcast(self, target: "DataFrame") -> "DataFrame":
195230
result_values = np.empty_like(target.values)
196231

197232
# axis which we want to compare compliance
@@ -317,19 +352,19 @@ def wrap_results(self):
317352
class FrameRowApply(FrameApply):
318353
axis = 0
319354

320-
def apply_broadcast(self):
321-
return super().apply_broadcast(self.obj)
355+
def apply_broadcast(self, target: "DataFrame") -> "DataFrame":
356+
return super().apply_broadcast(target)
322357

323358
@property
324359
def series_generator(self):
325360
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))
326361

327362
@property
328-
def result_index(self):
363+
def result_index(self) -> "Index":
329364
return self.columns
330365

331366
@property
332-
def result_columns(self):
367+
def result_columns(self) -> "Index":
333368
return self.index
334369

335370
def wrap_results_for_axis(self):
@@ -351,8 +386,8 @@ def wrap_results_for_axis(self):
351386
class FrameColumnApply(FrameApply):
352387
axis = 1
353388

354-
def apply_broadcast(self):
355-
result = super().apply_broadcast(self.obj.T)
389+
def apply_broadcast(self, target: "DataFrame") -> "DataFrame":
390+
result = super().apply_broadcast(target.T)
356391
return result.T
357392

358393
@property
@@ -364,11 +399,11 @@ def series_generator(self):
364399
)
365400

366401
@property
367-
def result_index(self):
402+
def result_index(self) -> "Index":
368403
return self.index
369404

370405
@property
371-
def result_columns(self):
406+
def result_columns(self) -> "Index":
372407
return self.columns
373408

374409
def wrap_results_for_axis(self):
@@ -392,7 +427,7 @@ def wrap_results_for_axis(self):
392427

393428
return result
394429

395-
def infer_to_same_shape(self):
430+
def infer_to_same_shape(self) -> "DataFrame":
396431
""" infer the results to the same shape as the input object """
397432
results = self.results
398433

0 commit comments

Comments
 (0)