From 85161418810b9dd956162db76573364d6f99c7ae Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 9 May 2024 04:41:19 +0300 Subject: [PATCH 1/2] Add callable condition type in Series.mask, DataFrame.mask --- pandas-stubs/core/frame.pyi | 8 +++++++- pandas-stubs/core/series.pyi | 8 +++++++- tests/test_frame.py | 9 +++++++++ tests/test_series.py | 6 ++++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 783842476..f2deecaa8 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1833,7 +1833,13 @@ class DataFrame(NDFrame, OpsMixin): def lt(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ... def mask( self, - cond: Series | DataFrame | np.ndarray, + cond: ( + Series + | DataFrame + | np.ndarray + | Callable[[DataFrame], DataFrame] + | Callable[[Any], _bool] + ), other=..., *, inplace: _bool = ..., diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 1d48feba3..b6111db82 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1433,7 +1433,13 @@ class Series(IndexOpsMixin[S1], NDFrame): ) -> Series[S1]: ... def mask( self, - cond: MaskType, + cond: ( + Series[S1] + | Series[_bool] + | np.ndarray + | Callable[[Series[S1]], Series[bool]] + | Callable[[S1], bool] + ), other: Scalar | Series[S1] | DataFrame | Callable | NAType | None = ..., *, inplace: _bool = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index f156a8e37..ed42dad84 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2755,6 +2755,15 @@ def cond2(x: pd.DataFrame) -> pd.DataFrame: check(assert_type(df.where(cond3), pd.DataFrame), pd.DataFrame) +def test_mask() -> None: + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + def cond1(x: int) -> bool: + return x % 2 == 0 + + check(assert_type(df.mask(cond1), pd.DataFrame), pd.DataFrame) + + def test_setitem_loc() -> None: # GH 254 df = pd.DataFrame.from_dict( diff --git a/tests/test_series.py b/tests/test_series.py index 203f81700..d2a8af78a 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -2823,6 +2823,12 @@ def test_types_mask() -> None: # Test case with a boolean condition and a scalar value check(assert_type(s.mask(s > 3, 10), "pd.Series[int]"), pd.Series, np.integer) + def cond(x: int) -> bool: + return x % 2 == 0 + + # Test case with a callable condition and a scalar value + check(assert_type(s.mask(cond, 10), "pd.Series[int]"), pd.Series, np.integer) + # Test case with a boolean condition and a callable def double(x): return x * 2 From 83ee9f40a1fa8ded1f0b678303b2fe35d4971b49 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 9 May 2024 17:41:43 +0300 Subject: [PATCH 2/2] Get rid of ignored rule (pyright) --- pandas-stubs/core/groupby/generic.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 715899fca..37f3c0dcc 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -199,7 +199,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): **kwargs, ) -> DataFrame: ... @overload - def apply( # pyright: ignore[reportOverlappingOverload,reportIncompatibleMethodOverride] + def apply( # pyright: ignore[reportOverlappingOverload] self, func: Callable[[Iterable], float], *args,