Skip to content

Commit acbd4a9

Browse files
committed
More detailed typing for DataFrame.apply.
Whether it returns a Series or a DataFrame depends on the return type of the callable. In the case of the callable returning a scalar, the result is a Series unless the result_type is "broadcast".
1 parent 78d8215 commit acbd4a9

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

pandas-stubs/core/frame.pyi

+22-4
Original file line numberDiff line numberDiff line change
@@ -1088,14 +1088,32 @@ class DataFrame(NDFrame, OpsMixin):
10881088
**kwargs,
10891089
) -> DataFrame: ...
10901090
@overload
1091-
def apply(self, f: Callable) -> Series: ...
1091+
def apply(
1092+
self,
1093+
f: Callable[..., Series],
1094+
axis: AxisType = ...,
1095+
raw: _bool = ...,
1096+
result_type: Literal["expand", "reduce", "broadcast"] | None = ...,
1097+
args=...,
1098+
**kwargs,
1099+
) -> DataFrame: ...
1100+
@overload
1101+
def apply(
1102+
self,
1103+
f: Callable[..., Scalar],
1104+
axis: AxisType = ...,
1105+
raw: _bool = ...,
1106+
result_type: Literal["expand", "reduce"] | None = ...,
1107+
args=...,
1108+
**kwargs,
1109+
) -> Series: ...
10921110
@overload
10931111
def apply(
10941112
self,
1095-
f: Callable,
1096-
axis: AxisType,
1113+
f: Callable[..., Scalar],
1114+
result_type: Literal["broadcast"],
1115+
axis: AxisType = ...,
10971116
raw: _bool = ...,
1098-
result_type: _str | None = ...,
10991117
args=...,
11001118
**kwargs,
11011119
) -> DataFrame: ...

tests/test_frame.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,22 @@ def test_types_unique() -> None:
460460

461461
def test_types_apply() -> None:
462462
df = pd.DataFrame(data={"col1": [2, 1], "col2": [3, 4]})
463-
df.apply(lambda x: x**2)
464-
df.apply(np.exp)
465-
df.apply(str)
463+
464+
def returns_series(x: pd.Series) -> pd.Series:
465+
return x**2
466+
467+
check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame)
468+
469+
def returns_scalar(x: pd.Series) -> float:
470+
return 2
471+
472+
check(assert_type(df.apply(returns_scalar), pd.Series), pd.Series)
473+
check(
474+
assert_type(df.apply(returns_scalar, result_type="broadcast"), pd.DataFrame),
475+
pd.DataFrame,
476+
)
477+
check(assert_type(df.apply(np.exp), pd.DataFrame), pd.DataFrame)
478+
check(assert_type(df.apply(str), pd.Series), pd.Series)
466479

467480

468481
def test_types_applymap() -> None:

0 commit comments

Comments
 (0)