ENH: case_when function #55306

Closed · wants to merge 19 commits
Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
@@ -138,7 +138,7 @@ jobs:
run: echo "sdist_name=$(cd ./dist && ls -d */)" >> "$GITHUB_ENV"

- name: Build wheels
uses: pypa/[email protected].1
uses: pypa/[email protected].0
with:
package-dir: ./dist/${{ matrix.buildplat[1] == 'macosx_*' && env.sdist_name || needs.build_sdist.outputs.sdist_file }}
env:
2 changes: 2 additions & 0 deletions .gitignore
@@ -129,8 +129,10 @@ doc/source/index.rst
doc/build/html/index.html
# Windows specific leftover:
doc/tmp.sv
doc/tmp.csv
env/
doc/source/savefig/
doc/source/_build

# Interactive terminal generated files #
########################################
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -84,7 +84,7 @@ repos:
'--filter=-readability/casting,-runtime/int,-build/include_subdir,-readability/fn_size'
]
- repo: https://github.com/pylint-dev/pylint
rev: v3.0.0b0
rev: v3.0.0a7
hooks:
- id: pylint
stages: [manual]
13 changes: 13 additions & 0 deletions doc/source/reference/arrays.rst
@@ -134,6 +134,11 @@ is the missing value for datetime data.

Timestamp

.. autosummary::
:toctree: api/

NaT

Properties
~~~~~~~~~~
.. autosummary::
@@ -252,6 +257,11 @@ is the missing value for timedelta data.

Timedelta

.. autosummary::
:toctree: api/

NaT

Properties
~~~~~~~~~~
.. autosummary::
@@ -455,6 +465,7 @@ pandas provides this through :class:`arrays.IntegerArray`.
UInt16Dtype
UInt32Dtype
UInt64Dtype
NA

.. _api.arrays.float_na:

@@ -473,6 +484,7 @@ Nullable float

Float32Dtype
Float64Dtype
NA

.. _api.arrays.categorical:

@@ -609,6 +621,7 @@ with a bool :class:`numpy.ndarray`.
:template: autosummary/class_without_autosummary.rst

BooleanDtype
NA


.. Dtype attributes which are manually listed in their docstrings: including
1 change: 0 additions & 1 deletion doc/source/reference/index.rst
@@ -53,7 +53,6 @@ are mentioned in the documentation.
options
extensions
testing
missing_value

.. This is to prevent warnings in the doc build. We don't want to encourage
.. these methods.
24 changes: 0 additions & 24 deletions doc/source/reference/missing_value.rst

This file was deleted.

2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.1.2.rst
@@ -15,7 +15,7 @@ Fixed regressions
~~~~~~~~~~~~~~~~~
- Fixed bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.MonthBegin` (:issue:`55271`)
- Fixed bug where PDEP-6 warning about setting an item of an incompatible dtype was being shown when creating a new conditional column (:issue:`55025`)
- Fixed regression in :meth:`DataFrame.join` where result has missing values and dtype is arrow backed string (:issue:`55348`)
-

.. ---------------------------------------------------------------------------
.. _whatsnew_212.bug_fixes:
27 changes: 23 additions & 4 deletions doc/source/whatsnew/v2.2.0.rst
@@ -14,6 +14,27 @@ including other versions of pandas.
Enhancements
~~~~~~~~~~~~


.. _whatsnew_220.enhancements.case_when:

Create Series based on one or more conditions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`)

.. ipython:: python

import pandas as pd

df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
df.assign(
new_column=pd.case_when(
df.a == 1, 'first', # condition, replacement
df.a.gt(1) & df.b.eq(5), 'second',
default='default', # optional
)
)

.. _whatsnew_220.enhancements.calamine:

Calamine engine for :func:`read_excel`
@@ -78,7 +99,6 @@ Other enhancements
- :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`)
- :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`)
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
-

.. ---------------------------------------------------------------------------
@@ -220,7 +240,6 @@ Other Deprecations
- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_parquet` except ``path``. (:issue:`54229`)
- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_pickle` except ``path``. (:issue:`54229`)
- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_string` except ``buf``. (:issue:`54229`)
- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_xml` except ``path_or_buffer``. (:issue:`54229`)
- Deprecated automatic downcasting of object-dtype results in :meth:`Series.replace` and :meth:`DataFrame.replace`, explicitly call ``result = result.infer_objects(copy=False)`` instead. To opt in to the future version, use ``pd.set_option("future.no_silent_downcasting", True)`` (:issue:`54710`)
- Deprecated downcasting behavior in :meth:`Series.where`, :meth:`DataFrame.where`, :meth:`Series.mask`, :meth:`DataFrame.mask`, :meth:`Series.clip`, :meth:`DataFrame.clip`; in a future version these will not infer object-dtype columns to non-object dtype, or all-round floats to integer dtype. Call ``result.infer_objects(copy=False)`` on the result for object inference, or explicitly cast floats to ints. To opt in to the future version, use ``pd.set_option("future.no_silent_downcasting", True)`` (:issue:`53656`)
- Deprecated including the groups in computations when using :meth:`DataFrameGroupBy.apply` and :meth:`DataFrameGroupBy.resample`; pass ``include_groups=False`` to exclude the groups (:issue:`7155`)
@@ -283,7 +302,7 @@ Numeric

Conversion
^^^^^^^^^^
- Bug in :meth:`Series.convert_dtypes` not converting all NA column to ``null[pyarrow]`` (:issue:`55346`)
-
-

Strings
@@ -312,7 +331,7 @@ Missing

MultiIndex
^^^^^^^^^^
- Bug in :meth:`MultiIndex.get_indexer` not raising ``ValueError`` when ``method`` provided and index is non-monotonic (:issue:`53452`)
-
-

I/O
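For comparison, the conditional assignment shown in the v2.2.0 whatsnew example above can be approximated with numpy.select, which pairs a list of boolean conditions with a list of replacement values and fills unmatched rows with a default. A minimal sketch of that equivalent logic, reusing the same toy frame as the example (illustration only, not code from this PR):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Conditions are checked in order: the first match supplies the value,
    # and rows that match none of the conditions receive the default.
    conditions = [df.a == 1, df.a.gt(1) & df.b.eq(5)]
    replacements = ["first", "second"]
    result = df.assign(new_column=np.select(conditions, replacements, default="default"))
    # new_column is expected to come out as ["first", "second", "default"]

The proposed pd.case_when expresses the same pairing inline, as alternating condition/replacement arguments plus an optional default keyword argument.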
2 changes: 2 additions & 0 deletions pandas/__init__.py
@@ -73,6 +73,7 @@
notnull,
# indexes
Index,
case_when,
CategoricalIndex,
RangeIndex,
MultiIndex,
@@ -252,6 +253,7 @@
__all__ = [
"ArrowDtype",
"BooleanDtype",
"case_when",
"Categorical",
"CategoricalDtype",
"CategoricalIndex",
2 changes: 1 addition & 1 deletion pandas/_libs/hashtable.pyi
@@ -240,7 +240,7 @@ def value_count(
values: np.ndarray,
dropna: bool,
mask: npt.NDArray[np.bool_] | None = ...,
) -> tuple[np.ndarray, npt.NDArray[np.int64], int]: ... # np.ndarray[same-as-values]
) -> tuple[np.ndarray, npt.NDArray[np.int64]]: ... # np.ndarray[same-as-values]

# arr and values should have same dtype
def ismember(
41 changes: 17 additions & 24 deletions pandas/_libs/hashtable_func_helper.pxi.in
@@ -36,7 +36,7 @@ cdef value_count_{{dtype}}(ndarray[{{dtype}}] values, bint dropna, const uint8_t
cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8_t[:] mask=None):
{{endif}}
cdef:
Py_ssize_t i = 0, na_counter = 0, na_add = 0
Py_ssize_t i = 0
Py_ssize_t n = len(values)
kh_{{ttype}}_t *table

@@ -49,6 +49,9 @@ cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8
bint uses_mask = mask is not None
bint isna_entry = False

if uses_mask and not dropna:
raise NotImplementedError("uses_mask not implemented with dropna=False")

# we track the order in which keys are first seen (GH39009),
# khash-map isn't insertion-ordered, thus:
# table maps keys to counts
@@ -79,31 +82,25 @@ cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8
for i in range(n):
val = {{to_c_type}}(values[i])

if uses_mask:
isna_entry = mask[i]

if dropna:
if not uses_mask:
if uses_mask:
isna_entry = mask[i]
else:
isna_entry = is_nan_{{c_type}}(val)

if not dropna or not isna_entry:
if uses_mask and isna_entry:
na_counter += 1
k = kh_get_{{ttype}}(table, val)
if k != table.n_buckets:
table.vals[k] += 1
else:
k = kh_get_{{ttype}}(table, val)
if k != table.n_buckets:
table.vals[k] += 1
else:
k = kh_put_{{ttype}}(table, val, &ret)
table.vals[k] = 1
result_keys.append(val)
k = kh_put_{{ttype}}(table, val, &ret)
table.vals[k] = 1
result_keys.append(val)
{{endif}}

# collect counts in the order corresponding to result_keys:
if na_counter > 0:
na_add = 1
cdef:
int64_t[::1] result_counts = np.empty(table.size + na_add, dtype=np.int64)
int64_t[::1] result_counts = np.empty(table.size, dtype=np.int64)

for i in range(table.size):
{{if dtype == 'object'}}
@@ -113,13 +110,9 @@ cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8
{{endif}}
result_counts[i] = table.vals[k]

if na_counter > 0:
result_counts[table.size] = na_counter
result_keys.append(val)

kh_destroy_{{ttype}}(table)

return result_keys.to_array(), result_counts.base, na_counter
return result_keys.to_array(), result_counts.base


@cython.wraparound(False)
@@ -406,10 +399,10 @@ def mode(ndarray[htfunc_t] values, bint dropna, const uint8_t[:] mask=None):
ndarray[htfunc_t] modes

int64_t[::1] counts
int64_t count, _, max_count = -1
int64_t count, max_count = -1
Py_ssize_t nkeys, k, j = 0

keys, counts, _ = value_count(values, dropna, mask=mask)
keys, counts = value_count(values, dropna, mask=mask)
nkeys = len(keys)

modes = np.empty(nkeys, dtype=values.dtype)
10 changes: 7 additions & 3 deletions pandas/_libs/tslib.pyx
@@ -455,18 +455,18 @@ cpdef array_to_datetime(
set out_tzoffset_vals = set()
tzinfo tz_out = None
bint found_tz = False, found_naive = False
cnp.flatiter it = cnp.PyArray_IterNew(values)
cnp.broadcast mi

# specify error conditions
assert is_raise or is_ignore or is_coerce

result = np.empty((<object>values).shape, dtype="M8[ns]")
mi = cnp.PyArray_MultiIterNew2(result, values)
iresult = result.view("i8").ravel()

for i in range(n):
# Analogous to `val = values[i]`
val = cnp.PyArray_GETITEM(values, cnp.PyArray_ITER_DATA(it))
cnp.PyArray_ITER_NEXT(it)
val = <object>(<PyObject**>cnp.PyArray_MultiIter_DATA(mi, 1))[0]

try:
if checknull_with_nat_and_na(val):
@@ -511,6 +511,7 @@
if parse_today_now(val, &iresult[i], utc):
# We can't _quite_ dispatch this to convert_str_to_tsobject
# bc there isn't a nice way to pass "utc"
cnp.PyArray_MultiIter_NEXT(mi)
continue

_ts = convert_str_to_tsobject(
@@ -539,10 +540,13 @@
else:
raise TypeError(f"{type(val)} is not convertible to datetime")

cnp.PyArray_MultiIter_NEXT(mi)

except (TypeError, OverflowError, ValueError) as ex:
ex.args = (f"{ex}, at position {i}",)
if is_coerce:
iresult[i] = NPY_NAT
cnp.PyArray_MultiIter_NEXT(mi)
continue
elif is_raise:
raise