Skip to content

Commit d89767d

Browse files
authored
Fix annotations for DataFrame.apply() including additional overloads (#448)
* Fix annotations for DataFrame.apply() including additional overloads * Expand test suite and eliminate edge cases * Cover all argument combinations and add a comprehensive set of test cases for apply() * Try to remove unnecessary overloads * Refine the full set of overloads to cover value type of Series
1 parent bc61543 commit d89767d

File tree

3 files changed

+320
-17
lines changed

3 files changed

+320
-17
lines changed

pandas-stubs/_typing.pyi

+7-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import (
88
Iterator,
99
Literal,
1010
Mapping,
11+
MutableSequence,
1112
Optional,
1213
Protocol,
1314
Sequence,
@@ -151,10 +152,15 @@ num: TypeAlias = complex
151152
SeriesAxisType: TypeAlias = Literal[
152153
"index", 0
153154
] # Restricted subset of _AxisType for series
154-
AxisType: TypeAlias = Literal["columns", "index", 0, 1]
155+
AxisTypeIndex: TypeAlias = Literal["index", 0]
156+
AxisTypeColumn: TypeAlias = Literal["columns", 1]
157+
AxisType: TypeAlias = AxisTypeIndex | AxisTypeColumn
155158
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
156159
KeysArgType: TypeAlias = Any
157160
ListLike = TypeVar("ListLike", Sequence, np.ndarray, "Series", "Index")
161+
ListLikeExceptSeriesAndStr = TypeVar(
162+
"ListLikeExceptSeriesAndStr", MutableSequence, np.ndarray, tuple, "Index"
163+
)
158164
ListLikeU: TypeAlias = Union[Sequence, np.ndarray, Series, Index]
159165
StrLike: TypeAlias = Union[str, np.str_]
160166
Scalar: TypeAlias = Union[

pandas-stubs/core/frame.pyi

+110-6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ from pandas._typing import (
6060
Axes,
6161
Axis,
6262
AxisType,
63+
AxisTypeColumn,
64+
AxisTypeIndex,
6365
CalculationMethod,
6466
ColspaceArgType,
6567
CompressionOptions,
@@ -83,6 +85,7 @@ from pandas._typing import (
8385
Label,
8486
Level,
8587
ListLike,
88+
ListLikeExceptSeriesAndStr,
8689
ListLikeU,
8790
MaskType,
8891
MergeHow,
@@ -1085,36 +1088,137 @@ class DataFrame(NDFrame, OpsMixin):
10851088
*args,
10861089
**kwargs,
10871090
) -> DataFrame: ...
1091+
1092+
# apply() overloads with default result_type of None, and is indifferent to axis
10881093
@overload
10891094
def apply(
10901095
self,
1091-
f: Callable[..., Series],
1096+
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1097+
axis: AxisTypeIndex = ...,
1098+
raw: _bool = ...,
1099+
result_type: Literal[None] = ...,
1100+
args=...,
1101+
**kwargs,
1102+
) -> DataFrame: ...
1103+
@overload
1104+
def apply(
1105+
self,
1106+
f: Callable[..., S1],
1107+
axis: AxisTypeIndex = ...,
1108+
raw: _bool = ...,
1109+
result_type: Literal[None] = ...,
1110+
args=...,
1111+
**kwargs,
1112+
) -> Series[S1]: ...
1113+
1114+
# apply() overloads with keyword result_type, and axis does not matter
1115+
@overload
1116+
def apply(
1117+
self,
1118+
f: Callable[..., S1],
1119+
axis: AxisType = ...,
1120+
raw: _bool = ...,
1121+
args=...,
1122+
*,
1123+
result_type: Literal["expand", "reduce"],
1124+
**kwargs,
1125+
) -> Series[S1]: ...
1126+
@overload
1127+
def apply(
1128+
self,
1129+
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
10921130
axis: AxisType = ...,
10931131
raw: _bool = ...,
1094-
result_type: Literal["expand", "reduce", "broadcast"] | None = ...,
10951132
args=...,
1133+
*,
1134+
result_type: Literal["expand"],
10961135
**kwargs,
10971136
) -> DataFrame: ...
10981137
@overload
10991138
def apply(
11001139
self,
1101-
f: Callable[..., Scalar],
1140+
f: Callable[..., ListLikeExceptSeriesAndStr],
11021141
axis: AxisType = ...,
11031142
raw: _bool = ...,
1104-
result_type: Literal["expand", "reduce"] | None = ...,
11051143
args=...,
1144+
*,
1145+
result_type: Literal["reduce"],
11061146
**kwargs,
11071147
) -> Series: ...
11081148
@overload
11091149
def apply(
11101150
self,
1111-
f: Callable[..., Scalar],
1112-
result_type: Literal["broadcast"],
1151+
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Scalar],
11131152
axis: AxisType = ...,
11141153
raw: _bool = ...,
11151154
args=...,
1155+
*,
1156+
result_type: Literal["broadcast"],
11161157
**kwargs,
11171158
) -> DataFrame: ...
1159+
1160+
# apply() overloads with keyword result_type, and axis does matter
1161+
@overload
1162+
def apply(
1163+
self,
1164+
f: Callable[..., Series],
1165+
axis: AxisTypeIndex = ...,
1166+
raw: _bool = ...,
1167+
args=...,
1168+
*,
1169+
result_type: Literal["reduce"],
1170+
**kwargs,
1171+
) -> Series: ...
1172+
1173+
# apply() overloads with default result_type of None, and keyword axis=1 matters
1174+
@overload
1175+
def apply(
1176+
self,
1177+
f: Callable[..., S1],
1178+
raw: _bool = ...,
1179+
result_type: Literal[None] = ...,
1180+
args=...,
1181+
*,
1182+
axis: AxisTypeColumn,
1183+
**kwargs,
1184+
) -> Series[S1]: ...
1185+
@overload
1186+
def apply(
1187+
self,
1188+
f: Callable[..., ListLikeExceptSeriesAndStr],
1189+
raw: _bool = ...,
1190+
result_type: Literal[None] = ...,
1191+
args=...,
1192+
*,
1193+
axis: AxisTypeColumn,
1194+
**kwargs,
1195+
) -> Series: ...
1196+
@overload
1197+
def apply(
1198+
self,
1199+
f: Callable[..., Series],
1200+
raw: _bool = ...,
1201+
result_type: Literal[None] = ...,
1202+
args=...,
1203+
*,
1204+
axis: AxisTypeColumn,
1205+
**kwargs,
1206+
) -> DataFrame: ...
1207+
1208+
# apply() overloads with keyword axis=1 and keyword result_type
1209+
@overload
1210+
def apply(
1211+
self,
1212+
f: Callable[..., Series],
1213+
raw: _bool = ...,
1214+
args=...,
1215+
*,
1216+
axis: AxisTypeColumn,
1217+
result_type: Literal["reduce"],
1218+
**kwargs,
1219+
) -> DataFrame: ...
1220+
1221+
# Add spacing between apply() overloads and remaining annotations
11181222
def applymap(
11191223
self, func: Callable, na_action: Literal["ignore"] | None = ..., **kwargs
11201224
) -> DataFrame: ...

0 commit comments

Comments
 (0)