Skip to content

Commit 66df0bd

Browse files
phoflsamukweku
andauthored
Backport PR pandas-dev#56059 on branch 2.2.x (ENH: Add case_when method) (pandas-dev#56800)
ENH: Add case_when method (pandas-dev#56059) (cherry picked from commit e3a55a4) Co-authored-by: Samuel Oranyeli <[email protected]>
1 parent 2ddeb45 commit 66df0bd

File tree

4 files changed

+292
-1
lines changed

4 files changed

+292
-1
lines changed

doc/source/reference/series.rst

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ Reindexing / selection / label manipulation
177177
:toctree: api/
178178

179179
Series.align
180+
Series.case_when
180181
Series.drop
181182
Series.droplevel
182183
Series.drop_duplicates

doc/source/whatsnew/v2.2.0.rst

+20
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,26 @@ For a full list of ADBC drivers and their development status, see the `ADBC Driv
188188
Implementation Status <https://arrow.apache.org/adbc/current/driver/status.html>`_
189189
documentation.
190190

191+
.. _whatsnew_220.enhancements.case_when:
192+
193+
Create a pandas Series based on one or more conditions
194+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
195+
196+
The :meth:`Series.case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`)
197+
198+
.. ipython:: python
199+
200+
import pandas as pd
201+
202+
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
203+
default=pd.Series('default', index=df.index)
204+
default.case_when(
205+
caselist=[
206+
(df.a == 1, 'first'), # condition, replacement
207+
(df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement
208+
],
209+
)
210+
191211
.. _whatsnew_220.enhancements.to_numpy_ea:
192212

193213
``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype

pandas/core/series.py

+123-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@
6767
from pandas.core.dtypes.astype import astype_is_view
6868
from pandas.core.dtypes.cast import (
6969
LossySetitemError,
70+
construct_1d_arraylike_from_scalar,
71+
find_common_type,
72+
infer_dtype_from,
7073
maybe_box_native,
7174
maybe_cast_pointwise_result,
7275
)
@@ -84,7 +87,10 @@
8487
CategoricalDtype,
8588
ExtensionDtype,
8689
)
87-
from pandas.core.dtypes.generic import ABCDataFrame
90+
from pandas.core.dtypes.generic import (
91+
ABCDataFrame,
92+
ABCSeries,
93+
)
8894
from pandas.core.dtypes.inference import is_hashable
8995
from pandas.core.dtypes.missing import (
9096
isna,
@@ -113,6 +119,7 @@
113119
from pandas.core.arrays.sparse import SparseAccessor
114120
from pandas.core.arrays.string_ import StringDtype
115121
from pandas.core.construction import (
122+
array as pd_array,
116123
extract_array,
117124
sanitize_array,
118125
)
@@ -5627,6 +5634,121 @@ def between(
56275634

56285635
return lmask & rmask
56295636

5637+
def case_when(
5638+
self,
5639+
caselist: list[
5640+
tuple[
5641+
ArrayLike | Callable[[Series], Series | np.ndarray | Sequence[bool]],
5642+
ArrayLike | Scalar | Callable[[Series], Series | np.ndarray],
5643+
],
5644+
],
5645+
) -> Series:
5646+
"""
5647+
Replace values where the conditions are True.
5648+
5649+
Parameters
5650+
----------
5651+
caselist : A list of tuples of conditions and expected replacements
5652+
Takes the form: ``(condition0, replacement0)``,
5653+
``(condition1, replacement1)``, ... .
5654+
``condition`` should be a 1-D boolean array-like object
5655+
or a callable. If ``condition`` is a callable,
5656+
it is computed on the Series
5657+
and should return a boolean Series or array.
5658+
The callable must not change the input Series
5659+
(though pandas doesn`t check it). ``replacement`` should be a
5660+
1-D array-like object, a scalar or a callable.
5661+
If ``replacement`` is a callable, it is computed on the Series
5662+
and should return a scalar or Series. The callable
5663+
must not change the input Series
5664+
(though pandas doesn`t check it).
5665+
5666+
.. versionadded:: 2.2.0
5667+
5668+
Returns
5669+
-------
5670+
Series
5671+
5672+
See Also
5673+
--------
5674+
Series.mask : Replace values where the condition is True.
5675+
5676+
Examples
5677+
--------
5678+
>>> c = pd.Series([6, 7, 8, 9], name='c')
5679+
>>> a = pd.Series([0, 0, 1, 2])
5680+
>>> b = pd.Series([0, 3, 4, 5])
5681+
5682+
>>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement
5683+
... (b.gt(0), b)])
5684+
0 6
5685+
1 3
5686+
2 1
5687+
3 2
5688+
Name: c, dtype: int64
5689+
"""
5690+
if not isinstance(caselist, list):
5691+
raise TypeError(
5692+
f"The caselist argument should be a list; instead got {type(caselist)}"
5693+
)
5694+
5695+
if not caselist:
5696+
raise ValueError(
5697+
"provide at least one boolean condition, "
5698+
"with a corresponding replacement."
5699+
)
5700+
5701+
for num, entry in enumerate(caselist):
5702+
if not isinstance(entry, tuple):
5703+
raise TypeError(
5704+
f"Argument {num} must be a tuple; instead got {type(entry)}."
5705+
)
5706+
if len(entry) != 2:
5707+
raise ValueError(
5708+
f"Argument {num} must have length 2; "
5709+
"a condition and replacement; "
5710+
f"instead got length {len(entry)}."
5711+
)
5712+
caselist = [
5713+
(
5714+
com.apply_if_callable(condition, self),
5715+
com.apply_if_callable(replacement, self),
5716+
)
5717+
for condition, replacement in caselist
5718+
]
5719+
default = self.copy()
5720+
conditions, replacements = zip(*caselist)
5721+
common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]]
5722+
if len(set(common_dtypes)) > 1:
5723+
common_dtype = find_common_type(common_dtypes)
5724+
updated_replacements = []
5725+
for condition, replacement in zip(conditions, replacements):
5726+
if is_scalar(replacement):
5727+
replacement = construct_1d_arraylike_from_scalar(
5728+
value=replacement, length=len(condition), dtype=common_dtype
5729+
)
5730+
elif isinstance(replacement, ABCSeries):
5731+
replacement = replacement.astype(common_dtype)
5732+
else:
5733+
replacement = pd_array(replacement, dtype=common_dtype)
5734+
updated_replacements.append(replacement)
5735+
replacements = updated_replacements
5736+
default = default.astype(common_dtype)
5737+
5738+
counter = reversed(range(len(conditions)))
5739+
for position, condition, replacement in zip(
5740+
counter, conditions[::-1], replacements[::-1]
5741+
):
5742+
try:
5743+
default = default.mask(
5744+
condition, other=replacement, axis=0, inplace=False, level=None
5745+
)
5746+
except Exception as error:
5747+
raise ValueError(
5748+
f"Failed to apply condition{position} and replacement{position}."
5749+
) from error
5750+
return default
5751+
56305752
# error: Cannot determine type of 'isna'
56315753
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type]
56325754
def isna(self) -> Series:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas import (
5+
DataFrame,
6+
Series,
7+
array as pd_array,
8+
date_range,
9+
)
10+
import pandas._testing as tm
11+
12+
13+
@pytest.fixture
14+
def df():
15+
"""
16+
base dataframe for testing
17+
"""
18+
return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
19+
20+
21+
def test_case_when_caselist_is_not_a_list(df):
22+
"""
23+
Raise ValueError if caselist is not a list.
24+
"""
25+
msg = "The caselist argument should be a list; "
26+
msg += "instead got.+"
27+
with pytest.raises(TypeError, match=msg): # GH39154
28+
df["a"].case_when(caselist=())
29+
30+
31+
def test_case_when_no_caselist(df):
32+
"""
33+
Raise ValueError if no caselist is provided.
34+
"""
35+
msg = "provide at least one boolean condition, "
36+
msg += "with a corresponding replacement."
37+
with pytest.raises(ValueError, match=msg): # GH39154
38+
df["a"].case_when([])
39+
40+
41+
def test_case_when_odd_caselist(df):
42+
"""
43+
Raise ValueError if no of caselist is odd.
44+
"""
45+
msg = "Argument 0 must have length 2; "
46+
msg += "a condition and replacement; instead got length 3."
47+
48+
with pytest.raises(ValueError, match=msg):
49+
df["a"].case_when([(df["a"].eq(1), 1, df.a.gt(1))])
50+
51+
52+
def test_case_when_raise_error_from_mask(df):
53+
"""
54+
Raise Error from within Series.mask
55+
"""
56+
msg = "Failed to apply condition0 and replacement0."
57+
with pytest.raises(ValueError, match=msg):
58+
df["a"].case_when([(df["a"].eq(1), [1, 2])])
59+
60+
61+
def test_case_when_single_condition(df):
62+
"""
63+
Test output on a single condition.
64+
"""
65+
result = Series([np.nan, np.nan, np.nan]).case_when([(df.a.eq(1), 1)])
66+
expected = Series([1, np.nan, np.nan])
67+
tm.assert_series_equal(result, expected)
68+
69+
70+
def test_case_when_multiple_conditions(df):
71+
"""
72+
Test output when booleans are derived from a computation
73+
"""
74+
result = Series([np.nan, np.nan, np.nan]).case_when(
75+
[(df.a.eq(1), 1), (Series([False, True, False]), 2)]
76+
)
77+
expected = Series([1, 2, np.nan])
78+
tm.assert_series_equal(result, expected)
79+
80+
81+
def test_case_when_multiple_conditions_replacement_list(df):
82+
"""
83+
Test output when replacement is a list
84+
"""
85+
result = Series([np.nan, np.nan, np.nan]).case_when(
86+
[([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3])]
87+
)
88+
expected = Series([1, 2, np.nan])
89+
tm.assert_series_equal(result, expected)
90+
91+
92+
def test_case_when_multiple_conditions_replacement_extension_dtype(df):
93+
"""
94+
Test output when replacement has an extension dtype
95+
"""
96+
result = Series([np.nan, np.nan, np.nan]).case_when(
97+
[
98+
([True, False, False], 1),
99+
(df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")),
100+
],
101+
)
102+
expected = Series([1, 2, np.nan], dtype="Float64")
103+
tm.assert_series_equal(result, expected)
104+
105+
106+
def test_case_when_multiple_conditions_replacement_series(df):
107+
"""
108+
Test output when replacement is a Series
109+
"""
110+
result = Series([np.nan, np.nan, np.nan]).case_when(
111+
[
112+
(np.array([True, False, False]), 1),
113+
(df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])),
114+
],
115+
)
116+
expected = Series([1, 2, np.nan])
117+
tm.assert_series_equal(result, expected)
118+
119+
120+
def test_case_when_non_range_index():
121+
"""
122+
Test output if index is not RangeIndex
123+
"""
124+
rng = np.random.default_rng(seed=123)
125+
dates = date_range("1/1/2000", periods=8)
126+
df = DataFrame(
127+
rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"]
128+
)
129+
result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)])
130+
expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5)
131+
tm.assert_series_equal(result, expected)
132+
133+
134+
def test_case_when_callable():
135+
"""
136+
Test output on a callable
137+
"""
138+
# https://numpy.org/doc/stable/reference/generated/numpy.piecewise.html
139+
x = np.linspace(-2.5, 2.5, 6)
140+
ser = Series(x)
141+
result = ser.case_when(
142+
caselist=[
143+
(lambda df: df < 0, lambda df: -df),
144+
(lambda df: df >= 0, lambda df: df),
145+
]
146+
)
147+
expected = np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x])
148+
tm.assert_series_equal(result, Series(expected))

0 commit comments

Comments
 (0)