diff --git a/pandas-stubs/io/formats/style.pyi b/pandas-stubs/io/formats/style.pyi index dc2b53829..4724726f8 100644 --- a/pandas-stubs/io/formats/style.pyi +++ b/pandas-stubs/io/formats/style.pyi @@ -5,6 +5,7 @@ from collections.abc import ( from typing import ( Any, Literal, + Protocol, overload, ) @@ -40,6 +41,16 @@ from pandas.io.formats.style_render import ( Subset, ) +class _SeriesFunc(Protocol): + def __call__( + self, series: Series, /, *args: Any, **kwargs: Any + ) -> list | Series: ... + +class _DataFrameFunc(Protocol): + def __call__( + self, series: DataFrame, /, *args: Any, **kwargs: Any + ) -> npt.NDArray | DataFrame: ... + class Styler(StylerRenderer): def __init__( self, @@ -198,7 +209,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: Callable[[Series], list | Series], + func: _SeriesFunc | Callable[[Series], list | Series], axis: Axis = ..., subset: Subset | None = ..., **kwargs: Any, @@ -206,7 +217,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: Callable[[DataFrame], npt.NDArray | DataFrame], + func: _DataFrameFunc | Callable[[DataFrame], npt.NDArray | DataFrame], axis: None, subset: Subset | None = ..., **kwargs: Any, diff --git a/tests/test_styler.py b/tests/test_styler.py index 7dd4003ec..42f7830d9 100644 --- a/tests/test_styler.py +++ b/tests/test_styler.py @@ -52,6 +52,14 @@ def h(df: DataFrame) -> DataFrame: check(assert_type(DF.style.apply(h, axis=None), Styler), Styler) + # GH 919 + def highlight_max(x: Series[int], /, color: str) -> list[str]: + return [f"color: {color}" if val == x.max() else "" for val in x] + + check( + assert_type(DF.style.apply(highlight_max, color="red", axis=1), Styler), Styler + ) + def test_apply_index() -> None: def f(s: Series) -> npt.NDArray[np.str_]: