Skip to content

Backport PR #56059 on branch 2.2.x (ENH: Add case_when method) #56800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Reindexing / selection / label manipulation
:toctree: api/

Series.align
Series.case_when
Series.drop
Series.droplevel
Series.drop_duplicates
Expand Down
20 changes: 20 additions & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,26 @@ For a full list of ADBC drivers and their development status, see the `ADBC Driv
Implementation Status <https://arrow.apache.org/adbc/current/driver/status.html>`_
documentation.

.. _whatsnew_220.enhancements.case_when:

Create a pandas Series based on one or more conditions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :meth:`Series.case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`)

.. ipython:: python

import pandas as pd

df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
default=pd.Series('default', index=df.index)
default.case_when(
caselist=[
(df.a == 1, 'first'), # condition, replacement
(df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement
],
)

.. _whatsnew_220.enhancements.to_numpy_ea:

``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype
Expand Down
124 changes: 123 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
from pandas.core.dtypes.astype import astype_is_view
from pandas.core.dtypes.cast import (
LossySetitemError,
construct_1d_arraylike_from_scalar,
find_common_type,
infer_dtype_from,
maybe_box_native,
maybe_cast_pointwise_result,
)
Expand All @@ -84,7 +87,10 @@
CategoricalDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
isna,
Expand Down Expand Up @@ -113,6 +119,7 @@
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.construction import (
array as pd_array,
extract_array,
sanitize_array,
)
Expand Down Expand Up @@ -5627,6 +5634,121 @@ def between(

return lmask & rmask

def case_when(
self,
caselist: list[
tuple[
ArrayLike | Callable[[Series], Series | np.ndarray | Sequence[bool]],
ArrayLike | Scalar | Callable[[Series], Series | np.ndarray],
],
],
) -> Series:
"""
Replace values where the conditions are True.

Parameters
----------
caselist : A list of tuples of conditions and expected replacements
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array-like object
or a callable. If ``condition`` is a callable,
it is computed on the Series
and should return a boolean Series or array.
The callable must not change the input Series
(though pandas doesn`t check it). ``replacement`` should be a
1-D array-like object, a scalar or a callable.
If ``replacement`` is a callable, it is computed on the Series
and should return a scalar or Series. The callable
must not change the input Series
(though pandas doesn`t check it).

.. versionadded:: 2.2.0

Returns
-------
Series

See Also
--------
Series.mask : Replace values where the condition is True.

Examples
--------
>>> c = pd.Series([6, 7, 8, 9], name='c')
>>> a = pd.Series([0, 0, 1, 2])
>>> b = pd.Series([0, 3, 4, 5])

>>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement
... (b.gt(0), b)])
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""
if not isinstance(caselist, list):
raise TypeError(
f"The caselist argument should be a list; instead got {type(caselist)}"
)

if not caselist:
raise ValueError(
"provide at least one boolean condition, "
"with a corresponding replacement."
)

for num, entry in enumerate(caselist):
if not isinstance(entry, tuple):
raise TypeError(
f"Argument {num} must be a tuple; instead got {type(entry)}."
)
if len(entry) != 2:
raise ValueError(
f"Argument {num} must have length 2; "
"a condition and replacement; "
f"instead got length {len(entry)}."
)
caselist = [
(
com.apply_if_callable(condition, self),
com.apply_if_callable(replacement, self),
)
for condition, replacement in caselist
]
default = self.copy()
conditions, replacements = zip(*caselist)
common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]]
if len(set(common_dtypes)) > 1:
common_dtype = find_common_type(common_dtypes)
updated_replacements = []
for condition, replacement in zip(conditions, replacements):
if is_scalar(replacement):
replacement = construct_1d_arraylike_from_scalar(
value=replacement, length=len(condition), dtype=common_dtype
)
elif isinstance(replacement, ABCSeries):
replacement = replacement.astype(common_dtype)
else:
replacement = pd_array(replacement, dtype=common_dtype)
updated_replacements.append(replacement)
replacements = updated_replacements
default = default.astype(common_dtype)

counter = reversed(range(len(conditions)))
for position, condition, replacement in zip(
counter, conditions[::-1], replacements[::-1]
):
try:
default = default.mask(
condition, other=replacement, axis=0, inplace=False, level=None
)
except Exception as error:
raise ValueError(
f"Failed to apply condition{position} and replacement{position}."
) from error
return default

# error: Cannot determine type of 'isna'
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type]
def isna(self) -> Series:
Expand Down
148 changes: 148 additions & 0 deletions pandas/tests/series/methods/test_case_when.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
import pytest

from pandas import (
DataFrame,
Series,
array as pd_array,
date_range,
)
import pandas._testing as tm


@pytest.fixture
def df():
"""
base dataframe for testing
"""
return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


def test_case_when_caselist_is_not_a_list(df):
"""
Raise ValueError if caselist is not a list.
"""
msg = "The caselist argument should be a list; "
msg += "instead got.+"
with pytest.raises(TypeError, match=msg): # GH39154
df["a"].case_when(caselist=())


def test_case_when_no_caselist(df):
"""
Raise ValueError if no caselist is provided.
"""
msg = "provide at least one boolean condition, "
msg += "with a corresponding replacement."
with pytest.raises(ValueError, match=msg): # GH39154
df["a"].case_when([])


def test_case_when_odd_caselist(df):
"""
Raise ValueError if no of caselist is odd.
"""
msg = "Argument 0 must have length 2; "
msg += "a condition and replacement; instead got length 3."

with pytest.raises(ValueError, match=msg):
df["a"].case_when([(df["a"].eq(1), 1, df.a.gt(1))])


def test_case_when_raise_error_from_mask(df):
"""
Raise Error from within Series.mask
"""
msg = "Failed to apply condition0 and replacement0."
with pytest.raises(ValueError, match=msg):
df["a"].case_when([(df["a"].eq(1), [1, 2])])


def test_case_when_single_condition(df):
"""
Test output on a single condition.
"""
result = Series([np.nan, np.nan, np.nan]).case_when([(df.a.eq(1), 1)])
expected = Series([1, np.nan, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions(df):
"""
Test output when booleans are derived from a computation
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[(df.a.eq(1), 1), (Series([False, True, False]), 2)]
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions_replacement_list(df):
"""
Test output when replacement is a list
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3])]
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions_replacement_extension_dtype(df):
"""
Test output when replacement has an extension dtype
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[
([True, False, False], 1),
(df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")),
],
)
expected = Series([1, 2, np.nan], dtype="Float64")
tm.assert_series_equal(result, expected)


def test_case_when_multiple_conditions_replacement_series(df):
"""
Test output when replacement is a Series
"""
result = Series([np.nan, np.nan, np.nan]).case_when(
[
(np.array([True, False, False]), 1),
(df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])),
],
)
expected = Series([1, 2, np.nan])
tm.assert_series_equal(result, expected)


def test_case_when_non_range_index():
"""
Test output if index is not RangeIndex
"""
rng = np.random.default_rng(seed=123)
dates = date_range("1/1/2000", periods=8)
df = DataFrame(
rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"]
)
result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)])
expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5)
tm.assert_series_equal(result, expected)


def test_case_when_callable():
"""
Test output on a callable
"""
# https://numpy.org/doc/stable/reference/generated/numpy.piecewise.html
x = np.linspace(-2.5, 2.5, 6)
ser = Series(x)
result = ser.case_when(
caselist=[
(lambda df: df < 0, lambda df: -df),
(lambda df: df >= 0, lambda df: df),
]
)
expected = np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x])
tm.assert_series_equal(result, Series(expected))