-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add support for multiple conditions assign statement #50343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
1f5156e
9e0238f
45352da
0e053ef
3e6dc99
cdec601
df1845d
fd12a30
b897b8e
d01d6e8
da79aa6
eb70eaf
f2bf4d8
a5d678f
072bc0c
747b379
69de3f1
7d19fb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
import warnings | ||
|
||
from pandas.util._exceptions import find_stack_level | ||
|
||
import pandas as pd | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't import all of pandas; import only what you need from the relevant modules. |
||
import pandas.core.common as com | ||
|
||
|
||
def warn_and_override_index(series, series_type, index):
    """
    Warn that ``series`` is being re-labelled and return it on a new index.

    The values keep their current positional order; only the labels change,
    so no alignment against the old index takes place.

    Parameters
    ----------
    series : Series
        Values to carry over. Its current index is discarded.
    series_type : str
        Label interpolated into the warning message to identify which input
        is being re-labelled.
    index : Index or array-like
        Labels for the returned Series.

    Returns
    -------
    Series
        New Series holding the values of ``series`` labelled by ``index``.
        The name of ``series`` is not carried over.

    Warns
    -----
    UserWarning
        Always emitted, to tell the caller that the original index is ignored.
    """
    warnings.warn(
        f"Series {series_type} will be reindexed to match obj index.",
        UserWarning,
        stacklevel=find_stack_level(),
    )
    # ``.array`` rather than ``.values``: ``.values`` converts tz-aware
    # datetimes to a naive ndarray, silently changing the dtype of the result.
    return pd.Series(series.array, index=index)
|
||
|
||
def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
    """
    Construct a Series based on multiple conditions.

    Conditions are checked in the order given and, for each row, the value
    paired with the first condition that is True is used. Rows for which no
    condition is True take the default value.

    The returned Series has the same index as `obj`.

    Parameters
    ----------
    obj : DataFrame or Series
        Object on which the conditions will be applied.
    *args
        Alternating conditions and values, in the form
        `condition0`, `value0`, `condition1`, `value1`, ...
        Each `condition` can be a 1-D boolean array/Series or a callable
        that is passed `obj` and evaluates to a 1-D boolean array/Series.
        See examples below.
    default : Any
        The value to be used where all conditions evaluate False. It may be
        a scalar, which is repeated for every row of `obj`, or an array-like
        (such as a `Series`, `np.array` or `list`) supplying one value per
        row.

    Returns
    -------
    Series
        Series with the corresponding values based on the conditions, their
        values and the default value.

    Raises
    ------
    ValueError
        If fewer than two positional arguments follow `obj`, or if the
        conditions and values cannot be paired up.

    Examples
    --------
    >>> df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
    >>> df
       a  b
    0  1  4
    1  2  5
    2  3  6

    >>> pd.case_when(
    ...     df,
    ...     lambda x: x.a == 1,
    ...     'first',
    ...     lambda x: (x.a == 2) & (x.b == 5),
    ...     'second',
    ...     default='default',
    ... )
    0      first
    1     second
    2    default
    dtype: object

    >>> pd.case_when(
    ...     df,
    ...     lambda x: (x.a == 1) & (x.b == 4),
    ...     df.b,
    ...     default=0,
    ... )
    0    4
    1    0
    2    0
    dtype: int64

    >>> pd.case_when(
    ...     df,
    ...     lambda x: (x.a > 1) & (x.b > 1),
    ...     -1,
    ...     default=df.a,
    ... )
    0    1
    1   -1
    2   -1
    Name: a, dtype: int64

    >>> pd.case_when(
    ...     df.a,
    ...     lambda x: x == 1,
    ...     -1,
    ...     default=df.a,
    ... )
    0   -1
    1    2
    2    3
    Name: a, dtype: int64

    >>> pd.case_when(
    ...     df.a,
    ...     df.a > 1,
    ...     -1,
    ...     default=df.a,
    ... )
    0    1
    1   -1
    2   -1
    Name: a, dtype: int64

    The index will always follow that of `obj`. For example:

    >>> df = pd.DataFrame(
    ...     dict(a=[1, 2, 3], b=[4, 5, 6]),
    ...     index=['index 1', 'index 2', 'index 3']
    ... )
    >>> df
             a  b
    index 1  1  4
    index 2  2  5
    index 3  3  6

    >>> pd.case_when(
    ...     df,
    ...     lambda x: (x.a == 1) & (x.b == 4),
    ...     df.b,
    ...     default=0,
    ... )
    index 1    4
    index 2    0
    index 3    0
    dtype: int64
    """
    len_args = len(args)

    if len_args < 2:
        raise ValueError("At least two arguments are required for `case_when`")
    if len_args % 2:
        raise ValueError(
            "The number of conditions and values do not match. "
            f"There are {len_args - len_args//2} conditions "
            f"and {len_args//2} values."
        )

    # Start from the default everywhere; matching rows are overwritten below.
    result = pd.Series(default, index=obj.index)

    # Apply the pairs from last to first so that, where several conditions are
    # True for the same row, the earliest one is written last and therefore
    # wins. Applying them in forward order would let later conditions clobber
    # earlier matches.
    pairs = list(zip(args[::2], args[1::2]))
    for condition, replacement in reversed(pairs):
        # apply_if_callable leaves non-callables untouched.
        mask = com.apply_if_callable(condition, obj)
        result = result.mask(mask, replacement)

    return result
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,6 +99,7 @@ class TestPDApi(Base): | |
funcs = [ | ||
"array", | ||
"bdate_range", | ||
"case_when", | ||
"concat", | ||
"crosstab", | ||
"cut", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
import pytest # noqa | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this needed? Include the error code when using noqa. |
||
|
||
import pandas as pd | ||
import pandas._testing as tm | ||
|
||
|
||
class TestCaseWhen:
    """Tests for ``pd.case_when`` called as ``case_when(obj, *args, default=...)``."""

    def test_case_when_multiple_conditions_callable(self):
        # Rows matching no condition fall back to the NaN default.
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                df,
                lambda x: x.a == 1,
                1,
                lambda x: (x.a > 1) & (x.b == 5),
                2,
                default=np.nan,
            )
        )
        expected = df.assign(new_column=[1, 2, np.nan])
        tm.assert_frame_equal(result, expected)

    def test_case_when_multiple_conditions_array_series(self):
        # Conditions may be given directly as a boolean list or Series.
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                df,
                [True, False, False],
                1,
                pd.Series([False, True, False]),
                2,
                default=np.nan,
            )
        )
        expected = df.assign(new_column=[1, 2, np.nan])
        tm.assert_frame_equal(result, expected)

    def test_case_when_multiple_conditions_callable_default(self):
        # A scalar default fills every row that matches no condition.
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                df,
                lambda x: x.a == 1,
                1,
                lambda x: (x.a > 1) & (x.b == 5),
                2,
                default=-1,
            )
        )
        expected = df.assign(new_column=[1, 2, -1])
        tm.assert_frame_equal(result, expected)

    def test_case_when_multiple_conditions_callable_default_series(self):
        # A Series default supplies a per-row fallback value.
        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
        result = df.assign(
            new_column=pd.case_when(
                df,
                lambda x: x.a == 1,
                1,
                lambda x: (x.a > 1) & (x.b == 5),
                2,
                default=df.b,
            )
        )
        expected = df.assign(new_column=[1, 2, 6])
        tm.assert_frame_equal(result, expected)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a proper link to the API documentation:
:func:`case_when`
Also, can you refer to this as a function rather than an API.