Skip to content

Commit d194237

Browse files
committed
Cover all argument combinations and add a comprehensive set of test
cases for apply()
1 parent 2590294 commit d194237

File tree

3 files changed

+142
-9
lines changed

3 files changed

+142
-9
lines changed

pandas-stubs/_typing.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ num: TypeAlias = complex
152152
SeriesAxisType: TypeAlias = Literal[
153153
"index", 0
154154
] # Restricted subset of _AxisType for series
155-
AxisType: TypeAlias = Literal["columns", "index", 0, 1]
155+
AxisTypeIndex: TypeAlias = Literal["index", 0]
156+
AxisTypeColumn: TypeAlias = Literal["columns", 1]
157+
AxisType: TypeAlias = AxisTypeIndex | AxisTypeColumn
156158
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
157159
KeysArgType: TypeAlias = Any
158160
ListLike = TypeVar("ListLike", Sequence, np.ndarray, "Series", "Index")

pandas-stubs/core/frame.pyi

+113-6
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ from pandas._typing import (
5959
Axes,
6060
Axis,
6161
AxisType,
62+
AxisTypeColumn,
63+
AxisTypeIndex,
6264
CalculationMethod,
6365
ColspaceArgType,
6466
CompressionOptions,
@@ -1087,11 +1089,13 @@ class DataFrame(NDFrame, OpsMixin):
10871089
*args,
10881090
**kwargs,
10891091
) -> DataFrame: ...
1092+
1093+
# First set of apply is with defaults, and also with keyword result_type
10901094
@overload
10911095
def apply(
10921096
self,
10931097
f: Callable[..., ListLikeExceptSeriesAndStr],
1094-
axis: AxisType = ...,
1098+
axis: AxisTypeIndex = ...,
10951099
raw: _bool = ...,
10961100
result_type: Literal[None] = ...,
10971101
args=...,
@@ -1101,7 +1105,7 @@ class DataFrame(NDFrame, OpsMixin):
11011105
def apply(
11021106
self,
11031107
f: Callable[..., Series],
1104-
axis: AxisType = ...,
1108+
axis: AxisTypeIndex = ...,
11051109
raw: _bool = ...,
11061110
result_type: Literal[None] = ...,
11071111
args=...,
@@ -1111,7 +1115,7 @@ class DataFrame(NDFrame, OpsMixin):
11111115
def apply(
11121116
self,
11131117
f: Callable[..., Scalar],
1114-
axis: AxisType = ...,
1118+
axis: AxisTypeIndex = ...,
11151119
raw: _bool = ...,
11161120
result_type: Literal[None] = ...,
11171121
args=...,
@@ -1121,7 +1125,7 @@ class DataFrame(NDFrame, OpsMixin):
11211125
def apply(
11221126
self,
11231127
f: Callable[..., ListLikeExceptSeriesAndStr],
1124-
axis: AxisType = ...,
1128+
axis: AxisTypeIndex = ...,
11251129
raw: _bool = ...,
11261130
args=...,
11271131
*,
@@ -1132,7 +1136,7 @@ class DataFrame(NDFrame, OpsMixin):
11321136
def apply(
11331137
self,
11341138
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1135-
axis: AxisType = ...,
1139+
axis: AxisTypeIndex = ...,
11361140
raw: _bool = ...,
11371141
args=...,
11381142
*,
@@ -1143,10 +1147,113 @@ class DataFrame(NDFrame, OpsMixin):
11431147
def apply(
11441148
self,
11451149
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1146-
axis: AxisType = ...,
1150+
axis: AxisTypeIndex = ...,
1151+
raw: _bool = ...,
1152+
args=...,
1153+
*,
1154+
result_type: Literal["broadcast"],
1155+
**kwargs,
1156+
) -> DataFrame: ...
1157+
1158+
# Second set of apply is with keyword axis=1 only
1159+
@overload
1160+
def apply(
1161+
self,
1162+
f: Callable[..., ListLikeExceptSeriesAndStr],
1163+
raw: _bool = ...,
1164+
result_type: Literal[None] = ...,
1165+
args=...,
1166+
*,
1167+
axis: AxisTypeColumn,
1168+
**kwargs,
1169+
) -> Series: ...
1170+
@overload
1171+
def apply(
1172+
self,
1173+
f: Callable[..., Series],
1174+
raw: _bool = ...,
1175+
result_type: Literal[None] = ...,
1176+
args=...,
1177+
*,
1178+
axis: AxisTypeColumn,
1179+
**kwargs,
1180+
) -> DataFrame: ...
1181+
@overload
1182+
def apply(
1183+
self,
1184+
f: Callable[..., Scalar],
1185+
raw: _bool = ...,
1186+
result_type: Literal[None] = ...,
1187+
args=...,
1188+
*,
1189+
axis: AxisTypeColumn,
1190+
**kwargs,
1191+
) -> Series: ...
1192+
1193+
# Third set of apply is with keyword axis=1 and keyword result_type
1194+
@overload
1195+
def apply(
1196+
self,
1197+
f: Callable[..., ListLikeExceptSeriesAndStr],
1198+
raw: _bool = ...,
1199+
args=...,
1200+
*,
1201+
axis: AxisTypeColumn = ...,
1202+
result_type: Literal[None] = ...,
1203+
**kwargs,
1204+
) -> Series: ...
1205+
@overload
1206+
def apply(
1207+
self,
1208+
f: Callable[..., Series],
1209+
raw: _bool = ...,
1210+
args=...,
1211+
*,
1212+
axis: AxisTypeColumn = ...,
1213+
result_type: Literal[None] = ...,
1214+
**kwargs,
1215+
) -> DataFrame: ...
1216+
@overload
1217+
def apply(
1218+
self,
1219+
f: Callable[..., Scalar],
1220+
raw: _bool = ...,
1221+
args=...,
1222+
*,
1223+
axis: AxisTypeColumn = ...,
1224+
result_type: Literal[None] = ...,
1225+
**kwargs,
1226+
) -> Series: ...
1227+
@overload
1228+
def apply(
1229+
self,
1230+
f: Callable[..., ListLikeExceptSeriesAndStr],
1231+
raw: _bool = ...,
1232+
args=...,
1233+
*,
1234+
axis: AxisTypeColumn = ...,
1235+
result_type: Literal["reduce"],
1236+
**kwargs,
1237+
) -> Series: ...
1238+
@overload
1239+
def apply(
1240+
self,
1241+
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1242+
raw: _bool = ...,
1243+
args=...,
1244+
*,
1245+
axis: AxisTypeColumn = ...,
1246+
result_type: Literal["expand"],
1247+
**kwargs,
1248+
) -> DataFrame: ...
1249+
@overload
1250+
def apply(
1251+
self,
1252+
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
11471253
raw: _bool = ...,
11481254
args=...,
11491255
*,
1256+
axis: AxisTypeColumn = ...,
11501257
result_type: Literal["broadcast"],
11511258
**kwargs,
11521259
) -> DataFrame: ...

tests/test_frame.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,8 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
569569
pd.DataFrame,
570570
)
571571
check(
572-
assert_type(df.apply(returns_listlike_of_3, axis=1), pd.DataFrame),
573-
pd.DataFrame,
572+
assert_type(df.apply(returns_listlike_of_3, axis=1), pd.Series),
573+
pd.Series,
574574
)
575575

576576
# While this call works in reality, it errors in the type checker, because this should never be called
@@ -638,6 +638,30 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
638638
pd.DataFrame,
639639
)
640640

641+
# Test various other argument combinations to ensure all overloads are supported
642+
check(
643+
assert_type(df.apply(returns_scalar, axis=0), pd.Series),
644+
pd.Series,
645+
)
646+
check(
647+
assert_type(df.apply(returns_scalar, axis=0, result_type=None), pd.Series),
648+
pd.Series,
649+
)
650+
check(
651+
assert_type(df.apply(returns_scalar, 0, False, None), pd.Series),
652+
pd.Series,
653+
)
654+
check(
655+
assert_type(df.apply(returns_scalar, 0, False, result_type=None), pd.Series),
656+
pd.Series,
657+
)
658+
check(
659+
assert_type(
660+
df.apply(returns_scalar, 0, raw=False, result_type=None), pd.Series
661+
),
662+
pd.Series,
663+
)
664+
641665

642666
def test_types_applymap() -> None:
643667
df = pd.DataFrame(data={"col1": [2, 1], "col2": [3, 4]})

0 commit comments

Comments
 (0)