Skip to content

Commit d89824b

Browse files
authored
Add callable condition type in Series.mask, DataFrame.mask (#918)
* Add callable condition type in Series.mask, DataFrame.mask * Get rid of ignored rule (pyright)
1 parent e63d1cb commit d89824b

File tree

5 files changed

+30
-3
lines changed

5 files changed

+30
-3
lines changed

pandas-stubs/core/frame.pyi

+7-1
Original file line numberDiff line numberDiff line change
@@ -1833,7 +1833,13 @@ class DataFrame(NDFrame, OpsMixin):
18331833
def lt(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
18341834
def mask(
18351835
self,
1836-
cond: Series | DataFrame | np.ndarray,
1836+
cond: (
1837+
Series
1838+
| DataFrame
1839+
| np.ndarray
1840+
| Callable[[DataFrame], DataFrame]
1841+
| Callable[[Any], _bool]
1842+
),
18371843
other=...,
18381844
*,
18391845
inplace: _bool = ...,

pandas-stubs/core/groupby/generic.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
199199
**kwargs,
200200
) -> DataFrame: ...
201201
@overload
202-
def apply( # pyright: ignore[reportOverlappingOverload,reportIncompatibleMethodOverride]
202+
def apply( # pyright: ignore[reportOverlappingOverload]
203203
self,
204204
func: Callable[[Iterable], float],
205205
*args,

pandas-stubs/core/series.pyi

+7-1
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,13 @@ class Series(IndexOpsMixin[S1], NDFrame):
14331433
) -> Series[S1]: ...
14341434
def mask(
14351435
self,
1436-
cond: MaskType,
1436+
cond: (
1437+
Series[S1]
1438+
| Series[_bool]
1439+
| np.ndarray
1440+
| Callable[[Series[S1]], Series[bool]]
1441+
| Callable[[S1], bool]
1442+
),
14371443
other: Scalar | Series[S1] | DataFrame | Callable | NAType | None = ...,
14381444
*,
14391445
inplace: _bool = ...,

tests/test_frame.py

+9
Original file line numberDiff line numberDiff line change
@@ -2755,6 +2755,15 @@ def cond2(x: pd.DataFrame) -> pd.DataFrame:
27552755
check(assert_type(df.where(cond3), pd.DataFrame), pd.DataFrame)
27562756

27572757

2758+
def test_mask() -> None:
2759+
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
2760+
2761+
def cond1(x: int) -> bool:
2762+
return x % 2 == 0
2763+
2764+
check(assert_type(df.mask(cond1), pd.DataFrame), pd.DataFrame)
2765+
2766+
27582767
def test_setitem_loc() -> None:
27592768
# GH 254
27602769
df = pd.DataFrame.from_dict(

tests/test_series.py

+6
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,12 @@ def test_types_mask() -> None:
28232823
# Test case with a boolean condition and a scalar value
28242824
check(assert_type(s.mask(s > 3, 10), "pd.Series[int]"), pd.Series, np.integer)
28252825

2826+
def cond(x: int) -> bool:
2827+
return x % 2 == 0
2828+
2829+
# Test case with a callable condition and a scalar value
2830+
check(assert_type(s.mask(cond, 10), "pd.Series[int]"), pd.Series, np.integer)
2831+
28262832
# Test case with a boolean condition and a callable
28272833
def double(x):
28282834
return x * 2

0 commit comments

Comments
 (0)