diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index fcd6b996a..a65e6e1d9 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1088,14 +1088,32 @@ class DataFrame(NDFrame, OpsMixin): **kwargs, ) -> DataFrame: ... @overload - def apply(self, f: Callable) -> Series: ... + def apply( + self, + f: Callable[..., Series], + axis: AxisType = ..., + raw: _bool = ..., + result_type: Literal["expand", "reduce", "broadcast"] | None = ..., + args=..., + **kwargs, + ) -> DataFrame: ... + @overload + def apply( + self, + f: Callable[..., Scalar], + axis: AxisType = ..., + raw: _bool = ..., + result_type: Literal["expand", "reduce"] | None = ..., + args=..., + **kwargs, + ) -> Series: ... @overload def apply( self, - f: Callable, - axis: AxisType, + f: Callable[..., Scalar], + result_type: Literal["broadcast"], + axis: AxisType = ..., raw: _bool = ..., - result_type: _str | None = ..., args=..., **kwargs, ) -> DataFrame: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 189054673..6697f7924 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -685,7 +685,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): @overload def apply( self, - func: Callable[..., Hashable], + func: Callable[..., Scalar | Sequence | Mapping], convertDType: _bool = ..., args: tuple = ..., **kwds, @@ -1193,7 +1193,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): ascending: _bool = ..., bins: int | None = ..., dropna: _bool = ..., - ) -> Series[S1]: ... + ) -> Series[int]: ... def transpose(self, *args, **kwargs) -> Series[S1]: ... @property def T(self) -> Series[S1]: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 3f0a4ad52..977dd7f2c 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -460,9 +460,28 @@ def test_types_unique() -> None: def test_types_apply() -> None: df = pd.DataFrame(data={"col1": [2, 1], "col2": [3, 4]}) - df.apply(lambda x: x**2) - df.apply(np.exp) - df.apply(str) + + def returns_series(x: pd.Series) -> pd.Series: + return x**2 + + check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame) + + def returns_scalar(x: pd.Series) -> float: + return 2 + + check(assert_type(df.apply(returns_scalar), pd.Series), pd.Series) + check( + assert_type(df.apply(returns_scalar, result_type="broadcast"), pd.DataFrame), + pd.DataFrame, + ) + check(assert_type(df.apply(np.exp), pd.DataFrame), pd.DataFrame) + check(assert_type(df.apply(str), pd.Series), pd.Series) + + # GH 393 + def gethead(s: pd.Series, y: int) -> pd.Series: + return s.head(y) + + check(assert_type(df.apply(gethead, args=(4,)), pd.DataFrame), pd.DataFrame) def test_types_applymap() -> None: diff --git a/tests/test_series.py b/tests/test_series.py index 67f768d81..e6136f503 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -367,8 +367,8 @@ def test_types_idxmax() -> None: def test_types_value_counts() -> None: - s = pd.Series([1, 2]) - s.value_counts() + s = pd.Series(["a", "b"]) + check(assert_type(s.value_counts(), "pd.Series[int]"), pd.Series, int) def test_types_unique() -> None: @@ -398,6 +398,11 @@ def retseries(x: float) -> float: check(assert_type(s.apply(retseries).tolist(), list), list) + def retlist(x: float) -> list: + return [x] + + check(assert_type(s.apply(retlist), pd.Series), pd.Series, list) + def get_depth(url: str) -> int: return len(url)