ENH: Add support for multiple conditions assign statement #50343

Closed
wants to merge 18 commits into from
Changes from 12 commits
Commits
1f5156e
Add case_when API
ELHoussineT Dec 19, 2022
9e0238f
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Jan 6, 2023
45352da
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Jan 7, 2023
0e053ef
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 8, 2023
3e6dc99
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 9, 2023
cdec601
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 9, 2023
df1845d
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 9, 2023
fd12a30
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 16, 2023
b897b8e
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 16, 2023
d01d6e8
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 16, 2023
da79aa6
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 19, 2023
eb70eaf
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Mar 19, 2023
f2bf4d8
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Apr 2, 2023
a5d678f
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Apr 2, 2023
072bc0c
Revert "fixup! Add case_when API * Used to support conditional assign…
ELHoussineT Apr 13, 2023
747b379
fixup! Add case_when API * Used to support conditional assignment ope…
ELHoussineT Apr 13, 2023
69de3f1
Merge branch 'pandas-dev:main' into conditional-assignment
ELHoussineT May 31, 2023
7d19fb8
Merge branch 'pandas-dev:main' into conditional-assignment
ELHoussineT Jul 19, 2023
20 changes: 20 additions & 0 deletions doc/source/whatsnew/v2.0.0.rst
@@ -14,6 +14,26 @@ including other versions of pandas.
Enhancements
~~~~~~~~~~~~

.. _whatsnew_200.enhancements.case_when:

Assignment based on multiple conditions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The ``pd.case_when`` API has now been added to support assignment based on multiple conditions.
Member

Add a proper link to the API documentation: :func:`case_when`

Also, can you refer to this as a function rather than an API?


.. ipython:: python

    import pandas as pd

    df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
    df.assign(
        new_column=pd.case_when(
            lambda x: x.a == 1, 'first',
            lambda x: (x.a > 1) & (x.b == 5), 'second',
            default='default',
        )
    )

.. _whatsnew_200.enhancements.optional_dependency_management_pip:

Installing optional dependencies with pip extras
2 changes: 2 additions & 0 deletions pandas/__init__.py
@@ -72,6 +72,7 @@
    notnull,
    # indexes
    Index,
    case_when,
    CategoricalIndex,
    RangeIndex,
    MultiIndex,
@@ -231,6 +232,7 @@
__all__ = [
    "ArrowDtype",
    "BooleanDtype",
    "case_when",
    "Categorical",
    "CategoricalDtype",
    "CategoricalIndex",
3 changes: 3 additions & 0 deletions pandas/core/api.py
@@ -42,6 +42,7 @@
    UInt64Dtype,
)
from pandas.core.arrays.string_ import StringDtype
from pandas.core.case_when import case_when
from pandas.core.construction import array
from pandas.core.flags import Flags
from pandas.core.groupby import (
@@ -80,11 +81,13 @@
# DataFrame needs to be imported after NamedAgg to avoid a circular import
from pandas.core.frame import DataFrame  # isort:skip


__all__ = [
    "array",
    "ArrowDtype",
    "bdate_range",
    "BooleanDtype",
    "case_when",
    "Categorical",
    "CategoricalDtype",
    "CategoricalIndex",
167 changes: 167 additions & 0 deletions pandas/core/case_when.py
@@ -0,0 +1,167 @@
from __future__ import annotations

from typing import Any
import warnings

from pandas.util._exceptions import find_stack_level

import pandas as pd
Member

Don't import all of pandas, only what you need from various modules.
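
A minimal sketch of what targeted imports could look like for a module of this shape. The helper `constant_like` is hypothetical and only illustrates the import pattern; which modules the PR would actually need is an assumption. `DataFrame` is imported under `TYPE_CHECKING` because it is used only in annotations here.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

from pandas.core.series import Series

if TYPE_CHECKING:
    # Needed only for annotations, so it is not imported at runtime.
    from pandas.core.frame import DataFrame


def constant_like(obj: DataFrame | Series, value) -> Series:
    """Hypothetical helper: a Series of ``value`` sharing ``obj``'s index."""
    return Series(value, index=obj.index)
```

Importing from the specific submodules keeps the new module from depending on the top-level `pandas` namespace, which is itself still being initialised when `pandas.core` modules load.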

import pandas.core.common as com


def warn_and_override_index(series, series_type, index):
    warnings.warn(
        f"Series {series_type} will be reindexed to match obj index.",
        UserWarning,
        stacklevel=find_stack_level(),
    )
    return pd.Series(series.values, index=index)


def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
    """
    Returns a Series based on multiple conditions assignment.
Member

Can we say "Construct" instead of "Returns"? Also, "multiple conditions assignment" sounds off to me; I would recommend just "multiple conditions".


    This is useful when you want to assign a column based on multiple conditions.
Member

This can be used independently of assigning a column, I suggest this be removed.

    Uses `Series.mask` to perform the assignment.
Member

I think this is an implementation detail.


    The returned Series have the same index as `obj`.

    Parameters
    ----------
    obj : Dataframe or Series on which the conditions will be applied.
Member

The first line should just be the dtype; a description of the argument should appear on the subsequent line.
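
As an illustration of the layout being requested, here is a sketch of the parameter section in numpydoc style, with the type on the first line and the description indented below it. The function `case_when_stub` is a hypothetical placeholder, and the wording is adapted from the docstring under review rather than taken from any merged version.

```python
def case_when_stub(obj, *args, default):
    """
    Construct a Series based on multiple conditions.

    Parameters
    ----------
    obj : DataFrame or Series
        Object on which the conditions will be applied.
    *args : condition and value pairs
        Takes the form ``condition0, value0, condition1, value1, ...``.
    default : Any
        Value used where all conditions evaluate to False.

    Returns
    -------
    Series
    """
    # Layout sketch only; no behaviour is implemented here.
    raise NotImplementedError("docstring layout sketch only")
```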

    args : Variable argument of conditions and expected values.
        Takes the form:
            `condition0`, `value0`, `condition1`, `value1`, ...
        `condition` can be a 1-D boolean array/series or a callable
        that evaluate to a 1-D boolean array/series. See examples below.
    default : Any
        The default value to be used if all conditions evaluate False. This value
        will be used to create the `Series` on which `Series.mask` will be called.
Comment on lines +42 to +43
Member

Series.mask is an implementation detail.

        If this value is not already an array like (i.e. it is not of type `Series`,
        `np.array` or `list`) it will be repeated `obj.shape[0]` times in order to
        create an array like object from it and then apply the `Series.mask`.
Member

Same here


    Returns
    -------
    Series
        Series with the corresponding values based on the conditions, their values
        and the default value.


    Examples
    --------
    >>> df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
    >>> df
       a  b
    0  1  4
    1  2  5
    2  3  6

    >>> pd.case_when(
    ...     df,
    ...     lambda x: x.a == 1,
    ...     'first',
    ...     lambda x: (x.a == 2) & (x.b == 5),
    ...     'second',
    ...     default='default',
    ... )
    0      first
    1     second
    2    default
    dtype: object

    >>> pd.case_when(
    ...     df,
    ...     lambda x: (x.a == 1) & (x.b == 4),
    ...     df.b,
    ...     default=0,
    ... )
    0    4
    1    0
    2    0
    dtype: int64

    >>> pd.case_when(
    ...     df,
    ...     lambda x: (x.a > 1) & (x.b > 1),
    ...     -1,
    ...     default=df.a,
    ... )
    0    1
    1   -1
    2   -1
    Name: a, dtype: int64

    >>> pd.case_when(
    ...     df.a,
    ...     lambda x: x == 1,
    ...     -1,
    ...     default=df.a,
    ... )
    0   -1
    1    2
    2    3
    Name: a, dtype: int64

    >>> pd.case_when(
    ...     df.a,
    ...     df.a > 1,
    ...     -1,
    ...     default=df.a,
    ... )
    0    1
    1   -1
    2   -1
    Name: a, dtype: int64

    The index will always follow that of `obj`. For example:
    >>> df = pd.DataFrame(
    ...     dict(a=[1, 2, 3], b=[4, 5, 6]),
    ...     index=['index 1', 'index 2', 'index 3']
    ... )
    >>> df
             a  b
    index 1  1  4
    index 2  2  5
    index 3  3  6

    >>> pd.case_when(
    ...     df,
    ...     lambda x: (x.a == 1) & (x.b == 4),
    ...     df.b,
    ...     default=0,
    ... )
    index 1    4
    index 2    0
    index 3    0
    dtype: int64
    """
    len_args = len(args)

    if len_args < 2:
        raise ValueError("At least two arguments are required for `case_when`")
    if len_args % 2:
        raise ValueError(
            "The number of conditions and values do not match. "
            f"There are {len_args - len_args//2} conditions "
            f"and {len_args//2} values."
        )

    # construct series on which we will apply `Series.mask`
    series = pd.Series(default, index=obj.index)

    for i in range(0, len_args, 2):
        # get conditions
Member

This comment essentially repeats the code; can you remove it?

        if callable(args[i]):
            conditions = com.apply_if_callable(args[i], obj)
        else:
            conditions = args[i]

        # get replacements
Member

Same.

        replacements = args[i + 1]

        # `Series.mask` call
Member

Same

        series = series.mask(conditions, replacements)
Member

I believe we want to modify the value only for the first condition that evaluates to True (correct me if you think this is wrong!); as written, this modifies it for every condition that matches, so later conditions overwrite earlier ones. One approach is to maintain a `modified` Series that starts as pd.Series(False, index=self.index). Then this line can become series = series.mask(~modified & conditions, replacements). After this line, we also need to update modified |= conditions.
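
The suggestion above can be sketched as follows. This is an illustration under stated assumptions, not the code that landed in the PR: `first_match_case_when` is a hypothetical name, and conditions are assumed to be either callables or boolean Series already aligned with `obj.index`.

```python
import pandas as pd


def first_match_case_when(obj, *args, default):
    """Hypothetical sketch: apply (condition, value) pairs, first match wins."""
    if len(args) < 2 or len(args) % 2:
        raise ValueError("expected condition/value pairs")

    series = pd.Series(default, index=obj.index)
    # Tracks rows already claimed by an earlier condition.
    modified = pd.Series(False, index=obj.index)

    for condition, replacement in zip(args[::2], args[1::2]):
        if callable(condition):
            condition = condition(obj)
        # Only touch rows that match now and were not matched before.
        series = series.mask(~modified & condition, replacement)
        modified = modified | condition

    return series


df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
result = first_match_case_when(
    df,
    lambda x: x.a > 0, "positive",
    lambda x: x.a == 1, "one",
    default="default",
)
```

Without the `modified` mask, the second pair would overwrite the first row with `"one"`; with it, every row keeps `"positive"` because the first condition already matched.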

Author

Good point. Updated (f2bf4d8).

Now, if multiple conditions are met, the value of the first one is used. For example:

>>> df
         a  b
index 1  1  4
index 2  2  5
index 3  3  6

>>> pd.case_when(
...     df,
...     lambda x: x.a > 0,
...     1,
...     lambda x: x.a == 1,
...     -1,
...     default='default'
... )
index 1    1
index 2    1
index 3    1
dtype: object


    return series
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
@@ -99,6 +99,7 @@ class TestPDApi(Base):
    funcs = [
        "array",
        "bdate_range",
        "case_when",
        "concat",
        "crosstab",
        "cut",
61 changes: 61 additions & 0 deletions pandas/tests/test_case_when.py
@@ -0,0 +1,61 @@
import numpy as np
import pytest # noqa
Member

Why is this needed? Include the error code when using noqa.


import pandas as pd
import pandas._testing as tm


class TestCaseWhen:
Member

If not making use of the class, then use test functions instead of methods.
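
A sketch of the function-style layout being asked for. Since `pd.case_when` from this PR is not assumed to exist in an installed pandas, the example exercises `Series.mask` directly; the test name is hypothetical.

```python
import pandas as pd
import pandas._testing as tm


def test_mask_applies_conditions_in_order():
    # Module-level test function: no class and no ``self`` needed.
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    result = pd.Series(-1, index=df.index)
    result = result.mask(df.a == 1, 1)
    result = result.mask((df.a > 1) & (df.b == 5), 2)
    expected = pd.Series([1, 2, -1])
    tm.assert_series_equal(result, expected)
```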

    def test_case_when_multiple_conditions_callable(self):
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                lambda x: x.a == 1,
                1,
                lambda x: (x.a > 1) & (x.b == 5),
                2,
            )
        )
        expected = df.assign(new_column=[1, 2, np.nan])
        tm.assert_frame_equal(result, expected)

    def test_case_when_multiple_conditions_array_series(self):
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                [True, False, False],
                1,
                pd.Series([False, True, False]),
                2,
            )
        )
        expected = df.assign(new_column=[1, 2, np.nan])
        tm.assert_frame_equal(result, expected)

    def test_case_when_multiple_conditions_callable_default(self):
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                lambda x: x.a == 1,
                1,
                lambda x: (x.a > 1) & (x.b == 5),
                2,
                default=-1,
            )
        )
        expected = df.assign(new_column=[1, 2, -1])
        tm.assert_frame_equal(result, expected)

    def test_case_when_multiple_conditions_callable_default_series(self):
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                lambda x: x.a == 1,
                1,
                lambda x: (x.a > 1) & (x.b == 5),
                2,
                default=df.b,
            )
        )
        expected = df.assign(new_column=[1, 2, 6])
        tm.assert_frame_equal(result, expected)