Skip to content

Commit ad4066a

Browse files
authored
Fix Styler.apply() to accept keyword arguments (#921)
* try using Protocol * use protocol for styler.apply * hide protocol classes
1 parent 6dfa03e commit ad4066a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

pandas-stubs/io/formats/style.pyi

+13-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from collections.abc import (
55
from typing import (
66
Any,
77
Literal,
8+
Protocol,
89
overload,
910
)
1011

@@ -40,6 +41,16 @@ from pandas.io.formats.style_render import (
4041
Subset,
4142
)
4243

44+
class _SeriesFunc(Protocol):
45+
def __call__(
46+
self, series: Series, /, *args: Any, **kwargs: Any
47+
) -> list | Series: ...
48+
49+
class _DataFrameFunc(Protocol):
50+
def __call__(
51+
self, series: DataFrame, /, *args: Any, **kwargs: Any
52+
) -> npt.NDArray | DataFrame: ...
53+
4354
class Styler(StylerRenderer):
4455
def __init__(
4556
self,
@@ -198,15 +209,15 @@ class Styler(StylerRenderer):
198209
@overload
199210
def apply(
200211
self,
201-
func: Callable[[Series], list | Series],
212+
func: _SeriesFunc | Callable[[Series], list | Series],
202213
axis: Axis = ...,
203214
subset: Subset | None = ...,
204215
**kwargs: Any,
205216
) -> Styler: ...
206217
@overload
207218
def apply(
208219
self,
209-
func: Callable[[DataFrame], npt.NDArray | DataFrame],
220+
func: _DataFrameFunc | Callable[[DataFrame], npt.NDArray | DataFrame],
210221
axis: None,
211222
subset: Subset | None = ...,
212223
**kwargs: Any,

tests/test_styler.py

+8
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def h(df: DataFrame) -> DataFrame:
5252

5353
check(assert_type(DF.style.apply(h, axis=None), Styler), Styler)
5454

55+
# GH 919
56+
def highlight_max(x: Series[int], /, color: str) -> list[str]:
57+
return [f"color: {color}" if val == x.max() else "" for val in x]
58+
59+
check(
60+
assert_type(DF.style.apply(highlight_max, color="red", axis=1), Styler), Styler
61+
)
62+
5563

5664
def test_apply_index() -> None:
5765
def f(s: Series) -> npt.NDArray[np.str_]:

0 commit comments

Comments
 (0)