Skip to content

Commit 40ab8f5

Browse files
committed
Cover all argument combinations and add a comprehensive set of test
cases for apply()
1 parent 2590294 commit 40ab8f5

File tree

3 files changed

+146
-9
lines changed

3 files changed

+146
-9
lines changed

pandas-stubs/_typing.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ num: TypeAlias = complex
152152
SeriesAxisType: TypeAlias = Literal[
153153
"index", 0
154154
] # Restricted subset of _AxisType for series
155-
AxisType: TypeAlias = Literal["columns", "index", 0, 1]
155+
AxisTypeIndex: TypeAlias = Literal["index", 0]
156+
AxisTypeColumn: TypeAlias = Literal["columns", 1]
157+
AxisType: TypeAlias = AxisTypeIndex | AxisTypeColumn
156158
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
157159
KeysArgType: TypeAlias = Any
158160
ListLike = TypeVar("ListLike", Sequence, np.ndarray, "Series", "Index")

pandas-stubs/core/frame.pyi

+117-6
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ from pandas._typing import (
5959
Axes,
6060
Axis,
6161
AxisType,
62+
AxisTypeColumn,
63+
AxisTypeIndex,
6264
CalculationMethod,
6365
ColspaceArgType,
6466
CompressionOptions,
@@ -1087,11 +1089,13 @@ class DataFrame(NDFrame, OpsMixin):
10871089
*args,
10881090
**kwargs,
10891091
) -> DataFrame: ...
1092+
1093+
# First set of apply() overloads is with defaults
10901094
@overload
10911095
def apply(
10921096
self,
10931097
f: Callable[..., ListLikeExceptSeriesAndStr],
1094-
axis: AxisType = ...,
1098+
axis: AxisTypeIndex = ...,
10951099
raw: _bool = ...,
10961100
result_type: Literal[None] = ...,
10971101
args=...,
@@ -1101,7 +1105,7 @@ class DataFrame(NDFrame, OpsMixin):
11011105
def apply(
11021106
self,
11031107
f: Callable[..., Series],
1104-
axis: AxisType = ...,
1108+
axis: AxisTypeIndex = ...,
11051109
raw: _bool = ...,
11061110
result_type: Literal[None] = ...,
11071111
args=...,
@@ -1111,17 +1115,19 @@ class DataFrame(NDFrame, OpsMixin):
11111115
def apply(
11121116
self,
11131117
f: Callable[..., Scalar],
1114-
axis: AxisType = ...,
1118+
axis: AxisTypeIndex = ...,
11151119
raw: _bool = ...,
11161120
result_type: Literal[None] = ...,
11171121
args=...,
11181122
**kwargs,
11191123
) -> Series: ...
1124+
1125+
# Second set of apply() overloads is with keyword result_type
11201126
@overload
11211127
def apply(
11221128
self,
11231129
f: Callable[..., ListLikeExceptSeriesAndStr],
1124-
axis: AxisType = ...,
1130+
axis: AxisTypeIndex = ...,
11251131
raw: _bool = ...,
11261132
args=...,
11271133
*,
@@ -1132,7 +1138,7 @@ class DataFrame(NDFrame, OpsMixin):
11321138
def apply(
11331139
self,
11341140
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1135-
axis: AxisType = ...,
1141+
axis: AxisTypeIndex = ...,
11361142
raw: _bool = ...,
11371143
args=...,
11381144
*,
@@ -1143,13 +1149,118 @@ class DataFrame(NDFrame, OpsMixin):
11431149
def apply(
11441150
self,
11451151
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1146-
axis: AxisType = ...,
1152+
axis: AxisTypeIndex = ...,
11471153
raw: _bool = ...,
11481154
args=...,
11491155
*,
11501156
result_type: Literal["broadcast"],
11511157
**kwargs,
11521158
) -> DataFrame: ...
1159+
1160+
# Third set of apply() overloads is with keyword axis=1 only
1161+
@overload
1162+
def apply(
1163+
self,
1164+
f: Callable[..., ListLikeExceptSeriesAndStr],
1165+
raw: _bool = ...,
1166+
result_type: Literal[None] = ...,
1167+
args=...,
1168+
*,
1169+
axis: AxisTypeColumn,
1170+
**kwargs,
1171+
) -> Series: ...
1172+
@overload
1173+
def apply(
1174+
self,
1175+
f: Callable[..., Series],
1176+
raw: _bool = ...,
1177+
result_type: Literal[None] = ...,
1178+
args=...,
1179+
*,
1180+
axis: AxisTypeColumn,
1181+
**kwargs,
1182+
) -> DataFrame: ...
1183+
@overload
1184+
def apply(
1185+
self,
1186+
f: Callable[..., Scalar],
1187+
raw: _bool = ...,
1188+
result_type: Literal[None] = ...,
1189+
args=...,
1190+
*,
1191+
axis: AxisTypeColumn,
1192+
**kwargs,
1193+
) -> Series: ...
1194+
1195+
# Fourth set of apply() overloads is with keyword axis=1 and keyword result_type
1196+
@overload
1197+
def apply(
1198+
self,
1199+
f: Callable[..., ListLikeExceptSeriesAndStr],
1200+
raw: _bool = ...,
1201+
args=...,
1202+
*,
1203+
axis: AxisTypeColumn = ...,
1204+
result_type: Literal[None] = ...,
1205+
**kwargs,
1206+
) -> Series: ...
1207+
@overload
1208+
def apply(
1209+
self,
1210+
f: Callable[..., Series],
1211+
raw: _bool = ...,
1212+
args=...,
1213+
*,
1214+
axis: AxisTypeColumn = ...,
1215+
result_type: Literal[None] = ...,
1216+
**kwargs,
1217+
) -> DataFrame: ...
1218+
@overload
1219+
def apply(
1220+
self,
1221+
f: Callable[..., Scalar],
1222+
raw: _bool = ...,
1223+
args=...,
1224+
*,
1225+
axis: AxisTypeColumn = ...,
1226+
result_type: Literal[None] = ...,
1227+
**kwargs,
1228+
) -> Series: ...
1229+
@overload
1230+
def apply(
1231+
self,
1232+
f: Callable[..., ListLikeExceptSeriesAndStr],
1233+
raw: _bool = ...,
1234+
args=...,
1235+
*,
1236+
axis: AxisTypeColumn = ...,
1237+
result_type: Literal["reduce"],
1238+
**kwargs,
1239+
) -> Series: ...
1240+
@overload
1241+
def apply(
1242+
self,
1243+
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1244+
raw: _bool = ...,
1245+
args=...,
1246+
*,
1247+
axis: AxisTypeColumn = ...,
1248+
result_type: Literal["expand"],
1249+
**kwargs,
1250+
) -> DataFrame: ...
1251+
@overload
1252+
def apply(
1253+
self,
1254+
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1255+
raw: _bool = ...,
1256+
args=...,
1257+
*,
1258+
axis: AxisTypeColumn = ...,
1259+
result_type: Literal["broadcast"],
1260+
**kwargs,
1261+
) -> DataFrame: ...
1262+
1263+
# Add spacing between apply() overloads and remaining annotations
11531264
def applymap(
11541265
self, func: Callable, na_action: Literal["ignore"] | None = ..., **kwargs
11551266
) -> DataFrame: ...

tests/test_frame.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,8 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
569569
pd.DataFrame,
570570
)
571571
check(
572-
assert_type(df.apply(returns_listlike_of_3, axis=1), pd.DataFrame),
573-
pd.DataFrame,
572+
assert_type(df.apply(returns_listlike_of_3, axis=1), pd.Series),
573+
pd.Series,
574574
)
575575

576576
# While this call works in reality, it errors in the type checker, because this should never be called
@@ -638,6 +638,30 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
638638
pd.DataFrame,
639639
)
640640

641+
# Test various other argument combinations to ensure all overloads are supported
642+
check(
643+
assert_type(df.apply(returns_scalar, axis=0), pd.Series),
644+
pd.Series,
645+
)
646+
check(
647+
assert_type(df.apply(returns_scalar, axis=0, result_type=None), pd.Series),
648+
pd.Series,
649+
)
650+
check(
651+
assert_type(df.apply(returns_scalar, 0, False, None), pd.Series),
652+
pd.Series,
653+
)
654+
check(
655+
assert_type(df.apply(returns_scalar, 0, False, result_type=None), pd.Series),
656+
pd.Series,
657+
)
658+
check(
659+
assert_type(
660+
df.apply(returns_scalar, 0, raw=False, result_type=None), pd.Series
661+
),
662+
pd.Series,
663+
)
664+
641665

642666
def test_types_applymap() -> None:
643667
df = pd.DataFrame(data={"col1": [2, 1], "col2": [3, 4]})

0 commit comments

Comments
 (0)