Skip to content

Commit 67dfcaa

Browse files
committed
improve typing and add test for callable
1 parent f6cf725 commit 67dfcaa

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

pandas/core/series.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5638,7 +5638,12 @@ def between(
56385638

56395639
def case_when(
56405640
self,
5641-
caselist: list[tuple[ArrayLike, ArrayLike | Scalar]],
5641+
caselist: list[
5642+
tuple[
5643+
ArrayLike | Callable[[Series | np.ndarray | Sequence[bool]]],
5644+
ArrayLike | Callable[[Series | np.ndarray]] | Scalar,
5645+
]
5646+
],
56425647
) -> Series:
56435648
"""
56445649
Replace values where the conditions are True.

pandas/tests/series/methods/test_case_when.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,20 @@ def test_case_when_non_range_index():
129129
result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)])
130130
expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5)
131131
tm.assert_series_equal(result, expected)
132+
133+
134+
def test_case_when_callable():
135+
"""
136+
Test output on a callable
137+
"""
138+
# https://numpy.org/doc/stable/reference/generated/numpy.piecewise.html
139+
x = np.linspace(-2.5, 2.5, 6)
140+
ser = Series(x)
141+
result = ser.case_when(
142+
caselist=[
143+
(lambda df: df < 0, lambda df: -df),
144+
(lambda df: df >= 0, lambda df: df),
145+
]
146+
)
147+
expected = np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x])
148+
tm.assert_series_equal(result, Series(expected))

0 commit comments

Comments
 (0)