Skip to content

Commit 55a1a0c

Browse files
authored
Support Callable conditions for DataFrame.where() and Series.where() (#310)
1 parent 1f597d3 commit 55a1a0c

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed

pandas-stubs/core/frame.pyi

+5-1
Original file line numberDiff line numberDiff line change
@@ -2017,7 +2017,11 @@ class DataFrame(NDFrame, OpsMixin):
20172017
) -> Series: ...
20182018
def where(
20192019
self,
2020-
cond: Series | DataFrame | np.ndarray,
2020+
cond: Series
2021+
| DataFrame
2022+
| np.ndarray
2023+
| Callable[[DataFrame], DataFrame]
2024+
| Callable[[Any], _bool],
20212025
other=...,
20222026
inplace: _bool = ...,
20232027
axis: AxisType | None = ...,

pandas-stubs/core/series.pyi

+5-1
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,11 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
10871087
) -> Series: ...
10881088
def where(
10891089
self,
1090-
cond: Series[S1] | Series[_bool] | np.ndarray,
1090+
cond: Series[S1]
1091+
| Series[_bool]
1092+
| np.ndarray
1093+
| Callable[[Series[S1]], Series[bool]]
1094+
| Callable[[S1], bool],
10911095
other=...,
10921096
inplace: _bool = ...,
10931097
axis: SeriesAxisType | None = ...,

tests/test_frame.py

+17
Original file line numberDiff line numberDiff line change
@@ -1764,3 +1764,20 @@ def test_loc_slice() -> None:
17641764
index=pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=["num", "let"]),
17651765
)
17661766
check(assert_type(df1.loc[1, :], pd.DataFrame), pd.DataFrame)
1767+
1768+
1769+
def test_where() -> None:
1770+
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
1771+
1772+
def cond1(x: int) -> bool:
1773+
return x % 2 == 0
1774+
1775+
check(assert_type(df.where(cond1), pd.DataFrame), pd.DataFrame)
1776+
1777+
def cond2(x: pd.DataFrame) -> pd.DataFrame:
1778+
return x > 1
1779+
1780+
check(assert_type(df.where(cond2), pd.DataFrame), pd.DataFrame)
1781+
1782+
cond3 = pd.DataFrame({"a": [True, True, False], "b": [False, False, False]})
1783+
check(assert_type(df.where(cond3), pd.DataFrame), pd.DataFrame)

tests/test_series.py

+17
Original file line numberDiff line numberDiff line change
@@ -1192,3 +1192,20 @@ def test_types_to_numpy() -> None:
11921192
check(assert_type(s.to_numpy(), np.ndarray), np.ndarray)
11931193
check(assert_type(s.to_numpy(dtype="str", copy=True), np.ndarray), np.ndarray)
11941194
check(assert_type(s.to_numpy(na_value=0), np.ndarray), np.ndarray)
1195+
1196+
1197+
def test_where() -> None:
1198+
s = pd.Series([1, 2, 3], dtype=int)
1199+
1200+
def cond1(x: int) -> bool:
1201+
return x % 2 == 0
1202+
1203+
check(assert_type(s.where(cond1, other=0), "pd.Series[int]"), pd.Series, int)
1204+
1205+
def cond2(x: pd.Series[int]) -> pd.Series[bool]:
1206+
return x > 1
1207+
1208+
check(assert_type(s.where(cond2, other=0), "pd.Series[int]"), pd.Series, int)
1209+
1210+
cond3 = pd.Series([False, True, True])
1211+
check(assert_type(s.where(cond3, other=0), "pd.Series[int]"), pd.Series, int)

0 commit comments

Comments
 (0)