Skip to content

Commit 1820f12

Browse files
committed
use protocol for styler.apply
1 parent 9514ab0 commit 1820f12

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

pandas-stubs/io/formats/style.pyi

+10-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,14 @@ from pandas.io.formats.style_render import (
4242
)
4343

4444
class SeriesFunc(Protocol):
45-
def __call__(self, series: Series, *args: Any, **kwargs: Any) -> list | Series: ...
45+
def __call__(
46+
self, series: Series, /, *args: Any, **kwargs: Any
47+
) -> list | Series: ...
48+
49+
class DataFrameFunc(Protocol):
50+
def __call__(
51+
self, series: DataFrame, /, *args: Any, **kwargs: Any
52+
) -> npt.NDArray | DataFrame: ...
4653

4754
class Styler(StylerRenderer):
4855
def __init__(
@@ -202,15 +209,15 @@ class Styler(StylerRenderer):
202209
@overload
203210
def apply(
204211
self,
205-
func: SeriesFunc,
212+
func: SeriesFunc | Callable[[Series], list | Series],
206213
axis: Axis = ...,
207214
subset: Subset | None = ...,
208215
**kwargs: Any,
209216
) -> Styler: ...
210217
@overload
211218
def apply(
212219
self,
213-
func: Callable[[DataFrame], npt.NDArray | DataFrame],
220+
func: DataFrameFunc | Callable[[DataFrame], npt.NDArray | DataFrame],
214221
axis: None,
215222
subset: Subset | None = ...,
216223
**kwargs: Any,

tests/test_styler.py

+8
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def h(df: DataFrame) -> DataFrame:
5252

5353
check(assert_type(DF.style.apply(h, axis=None), Styler), Styler)
5454

55+
# GH 919
56+
def highlight_max(x: Series[int], /, color: str) -> list[str]:
57+
return [f"color: {color}" if val == x.max() else "" for val in x]
58+
59+
check(
60+
assert_type(DF.style.apply(highlight_max, color="red", axis=1), Styler), Styler
61+
)
62+
5563

5664
def test_apply_index() -> None:
5765
def f(s: Series) -> npt.NDArray[np.str_]:

0 commit comments

Comments
 (0)