From 9514ab0d5cacaae2b96589a39f092503df64ec17 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 12 May 2024 14:37:08 -0400 Subject: [PATCH 1/3] try using Protocol --- pandas-stubs/io/formats/style.pyi | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pandas-stubs/io/formats/style.pyi b/pandas-stubs/io/formats/style.pyi index dc2b53829..85b4a45bc 100644 --- a/pandas-stubs/io/formats/style.pyi +++ b/pandas-stubs/io/formats/style.pyi @@ -5,6 +5,7 @@ from collections.abc import ( from typing import ( Any, Literal, + Protocol, overload, ) @@ -40,6 +41,9 @@ from pandas.io.formats.style_render import ( Subset, ) +class SeriesFunc(Protocol): + def __call__(self, series: Series, *args: Any, **kwargs: Any) -> list | Series: ... + class Styler(StylerRenderer): def __init__( self, @@ -198,7 +202,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: Callable[[Series], list | Series], + func: SeriesFunc, axis: Axis = ..., subset: Subset | None = ..., **kwargs: Any, From 1820f12bf24d82a062b4120ddc88a6a9256a5686 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 12 May 2024 16:34:39 -0400 Subject: [PATCH 2/3] use protocol for styler.apply --- pandas-stubs/io/formats/style.pyi | 13 ++++++++++--- tests/test_styler.py | 8 ++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pandas-stubs/io/formats/style.pyi b/pandas-stubs/io/formats/style.pyi index 85b4a45bc..0d0c0a74e 100644 --- a/pandas-stubs/io/formats/style.pyi +++ b/pandas-stubs/io/formats/style.pyi @@ -42,7 +42,14 @@ from pandas.io.formats.style_render import ( ) class SeriesFunc(Protocol): - def __call__(self, series: Series, *args: Any, **kwargs: Any) -> list | Series: ... + def __call__( + self, series: Series, /, *args: Any, **kwargs: Any + ) -> list | Series: ... + +class DataFrameFunc(Protocol): + def __call__( + self, series: DataFrame, /, *args: Any, **kwargs: Any + ) -> npt.NDArray | DataFrame: ... class Styler(StylerRenderer): def __init__( @@ -202,7 +209,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: SeriesFunc, + func: SeriesFunc | Callable[[Series], list | Series], axis: Axis = ..., subset: Subset | None = ..., **kwargs: Any, @@ -210,7 +217,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: Callable[[DataFrame], npt.NDArray | DataFrame], + func: DataFrameFunc | Callable[[DataFrame], npt.NDArray | DataFrame], axis: None, subset: Subset | None = ..., **kwargs: Any, diff --git a/tests/test_styler.py b/tests/test_styler.py index 7dd4003ec..42f7830d9 100644 --- a/tests/test_styler.py +++ b/tests/test_styler.py @@ -52,6 +52,14 @@ def h(df: DataFrame) -> DataFrame: check(assert_type(DF.style.apply(h, axis=None), Styler), Styler) + # GH 919 + def highlight_max(x: Series[int], /, color: str) -> list[str]: + return [f"color: {color}" if val == x.max() else "" for val in x] + + check( + assert_type(DF.style.apply(highlight_max, color="red", axis=1), Styler), Styler + ) + def test_apply_index() -> None: def f(s: Series) -> npt.NDArray[np.str_]: From 63ea1465f2338af3c32db3b16fd7544b6195a6e9 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 13 May 2024 09:41:13 -0400 Subject: [PATCH 3/3] hide protocol classes --- pandas-stubs/io/formats/style.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas-stubs/io/formats/style.pyi b/pandas-stubs/io/formats/style.pyi index 0d0c0a74e..4724726f8 100644 --- a/pandas-stubs/io/formats/style.pyi +++ b/pandas-stubs/io/formats/style.pyi @@ -41,12 +41,12 @@ from pandas.io.formats.style_render import ( Subset, ) -class SeriesFunc(Protocol): +class _SeriesFunc(Protocol): def __call__( self, series: Series, /, *args: Any, **kwargs: Any ) -> list | Series: ... -class DataFrameFunc(Protocol): +class _DataFrameFunc(Protocol): def __call__( self, series: DataFrame, /, *args: Any, **kwargs: Any ) -> npt.NDArray | DataFrame: ... @@ -209,7 +209,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: SeriesFunc | Callable[[Series], list | Series], + func: _SeriesFunc | Callable[[Series], list | Series], axis: Axis = ..., subset: Subset | None = ..., **kwargs: Any, @@ -217,7 +217,7 @@ class Styler(StylerRenderer): @overload def apply( self, - func: DataFrameFunc | Callable[[DataFrame], npt.NDArray | DataFrame], + func: _DataFrameFunc | Callable[[DataFrame], npt.NDArray | DataFrame], axis: None, subset: Subset | None = ..., **kwargs: Any,