Commit baad046

TomAugspurger authored and jreback committed
BUG/Perf: Support ExtensionArrays in where (#24114)
Closes #24077
1 parent c5a4711 commit baad046

15 files changed, +323 -11 lines changed

doc/source/whatsnew/v0.24.0.rst

+6
@@ -675,6 +675,7 @@ changes were made:
 - ``SparseDataFrame.combine`` and ``DataFrame.combine_first`` no longer supports combining a sparse column with a dense column while preserving the sparse subtype. The result will be an object-dtype SparseArray.
 - Setting :attr:`SparseArray.fill_value` to a fill value with a different dtype is now allowed.
 - ``DataFrame[column]`` is now a :class:`Series` with sparse values, rather than a :class:`SparseSeries`, when slicing a single column with sparse values (:issue:`23559`).
+- The result of :meth:`Series.where` is now a ``Series`` with sparse values, like with other extension arrays (:issue:`24077`)

 Some new warnings are issued for operations that require or are likely to materialize a large dense array:

@@ -1113,6 +1114,8 @@ Deprecations
 - :func:`pandas.types.is_datetimetz` is deprecated in favor of `pandas.types.is_datetime64tz` (:issue:`23917`)
 - Creating a :class:`TimedeltaIndex` or :class:`DatetimeIndex` by passing range arguments `start`, `end`, and `periods` is deprecated in favor of :func:`timedelta_range` and :func:`date_range` (:issue:`23919`)
 - Passing a string alias like ``'datetime64[ns, UTC]'`` as the `unit` parameter to :class:`DatetimeTZDtype` is deprecated. Use :class:`DatetimeTZDtype.construct_from_string` instead (:issue:`23990`).
+- In :meth:`Series.where` with Categorical data, providing an ``other`` that is not present in the categories is deprecated. Convert the categorical to a different dtype or add the ``other`` to the categories first (:issue:`24077`).
+

 .. _whatsnew_0240.deprecations.datetimelike_int_ops:

@@ -1223,6 +1226,7 @@ Performance Improvements
 - Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
 - Fixed a performance regression on Windows with Python 3.7 of :func:`pd.read_csv` (:issue:`23516`)
 - Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
+- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)

 .. _whatsnew_0240.docs:

@@ -1249,6 +1253,7 @@ Categorical
 - In meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
 - Bug when resampling :meth:`Dataframe.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
 - Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
+- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)

 Datetimelike
 ^^^^^^^^^^^^
@@ -1285,6 +1290,7 @@ Datetimelike
 - Bug in :class:`DatetimeIndex` where calling ``np.array(dtindex, dtype=object)`` would incorrectly return an array of ``long`` objects (:issue:`23524`)
 - Bug in :class:`Index` where passing a timezone-aware :class:`DatetimeIndex` and `dtype=object` would incorrectly raise a ``ValueError`` (:issue:`23524`)
 - Bug in :class:`Index` where calling ``np.array(dtindex, dtype=object)`` on a timezone-naive :class:`DatetimeIndex` would return an array of ``datetime`` objects instead of :class:`Timestamp` objects, potentially losing nanosecond portions of the timestamps (:issue:`23524`)
+- Bug in :class:`Categorical.__setitem__` not allowing setting with another ``Categorical`` when both are unordered and have the same categories, but in a different order (:issue:`24142`)
 - Bug in :func:`date_range` where using dates with millisecond resolution or higher could return incorrect values or the wrong number of values in the index (:issue:`24110`)

 Timedelta

pandas/core/arrays/base.py

+2
@@ -220,6 +220,8 @@ def __setitem__(self, key, value):
         # example, a string like '2018-01-01' is coerced to a datetime
         # when setting on a datetime64ns array. In general, if the
         # __init__ method coerces that value, then so should __setitem__
+        # Note, also, that Series/DataFrame.where internally use __setitem__
+        # on a copy of the data.
         raise NotImplementedError(_not_implemented_message.format(
             type(self), '__setitem__')
         )
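
Reviewer note: the comment added in this hunk says that `Series.where` and `DataFrame.where` call `__setitem__` on a copy of the data. The sketch below illustrates that idea only. `ToyArray` and `where` are hypothetical names for a list-backed stand-in, not pandas classes, and the toy supports integer positions and a scalar `other` only.

```python
class ToyArray:
    """Hypothetical list-backed array; not part of pandas."""

    def __init__(self, data):
        self._data = list(data)

    def copy(self):
        return ToyArray(self._data)

    def __setitem__(self, index, value):
        # Assumption: only integer positions are supported in this toy model.
        self._data[index] = value

    def to_list(self):
        return list(self._data)


def where(values, cond, other):
    """Keep values[i] where cond[i] is true, otherwise use the scalar `other`."""
    result = values.copy()  # the caller's array is never mutated
    for i, keep in enumerate(cond):
        if not keep:
            # Assignment goes through __setitem__, so whatever validation the
            # array performs on assignment also applies to `where`.
            result[i] = other
    return result


original = ToyArray(['a', 'b', 'c'])
replaced = where(original, [True, False, True], 'z')
print(replaced.to_list())
print(original.to_list())
```

Because every replacement goes through `__setitem__`, an array type that rejects a value on assignment would reject it in `where` as well, which seems to be the point of the added comment.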

pandas/core/arrays/categorical.py

+11 -1
@@ -2078,11 +2078,21 @@ def __setitem__(self, key, value):
             `Categorical` does not have the same categories
         """

+        if isinstance(value, (ABCIndexClass, ABCSeries)):
+            value = value.array
+
         # require identical categories set
         if isinstance(value, Categorical):
-            if not value.categories.equals(self.categories):
+            if not is_dtype_equal(self, value):
                 raise ValueError("Cannot set a Categorical with another, "
                                  "without identical categories")
+            if not self.categories.equals(value.categories):
+                new_codes = _recode_for_categories(
+                    value.codes, value.categories, self.categories
+                )
+                value = Categorical.from_codes(new_codes,
+                                               categories=self.categories,
+                                               ordered=self.ordered)

         rvalue = value if is_list_like(value) else [value]
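
Reviewer note: when the two categoricals hold the same categories in a different order, the hunk above translates the codes of `value` into the target's ordering before assignment. The sketch below is a plain-Python illustration of that recoding step. `recode_for_categories` here is a hypothetical stand-in; the private pandas helper `_recode_for_categories` works on arrays and its implementation may differ.

```python
def recode_for_categories(codes, old_categories, new_categories):
    """Translate integer codes from one category ordering to another.

    A code of -1 marks a missing value and is passed through unchanged.
    """
    position = {category: i for i, category in enumerate(new_categories)}
    # Old categories absent from the new set map to -1 (missing).
    old_to_new = [position.get(category, -1) for category in old_categories]
    return [-1 if code == -1 else old_to_new[code] for code in codes]


# `value` holds ['b', 'a'] with categories ['b', 'a'], so its codes are [0, 1].
# The target uses categories ['a', 'b'].
new_codes = recode_for_categories([0, 1], ['b', 'a'], ['a', 'b'])
print(new_codes)
```

After recoding, the values are unchanged ('b' then 'a'), but each code now indexes into the target's category list.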

pandas/core/arrays/sparse.py

+2
@@ -706,6 +706,8 @@ def __array__(self, dtype=None, copy=True):

     def __setitem__(self, key, value):
         # I suppose we could allow setting of non-fill_value elements.
+        # TODO(SparseArray.__setitem__): remove special cases in
+        # ExtensionBlock.where
         msg = "SparseArray does not support item assignment via setitem"
         raise TypeError(msg)
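
Reviewer note: since `SparseArray.__setitem__` only raises `TypeError`, `ExtensionBlock.where` in this commit falls back to rebuilding the array from an element-wise selection. A minimal sketch of that fallback pattern follows. `ReadOnlyArray` and `where_with_fallback` are hypothetical names, and the toy `_from_sequence` takes no `dtype` argument.

```python
class ReadOnlyArray:
    """Hypothetical array whose __setitem__ always raises TypeError."""

    def __init__(self, data):
        self._data = tuple(data)

    def copy(self):
        return ReadOnlyArray(self._data)

    def __iter__(self):
        return iter(self._data)

    def __setitem__(self, key, value):
        raise TypeError("ReadOnlyArray does not support item assignment")

    @classmethod
    def _from_sequence(cls, scalars):
        return cls(scalars)


def where_with_fallback(values, cond, other):
    """Try copy-then-setitem first; rebuild from scalars if assignment fails."""
    try:
        result = values.copy()
        for i, keep in enumerate(cond):
            if not keep:
                result[i] = other
    except (NotImplementedError, TypeError):
        # Slow path: choose element by element, then construct a new array.
        picked = [v if keep else other for v, keep in zip(values, cond)]
        result = type(values)._from_sequence(picked)
    return result


result = where_with_fallback(ReadOnlyArray([1, 2, 3]), [True, False, True], 0)
print(tuple(result))
```

The fast path is attempted first so that arrays with a working `__setitem__` avoid materializing an intermediate sequence.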

pandas/core/indexes/category.py

+4 -1
@@ -501,10 +501,13 @@ def _can_reindex(self, indexer):

     @Appender(_index_shared_docs['where'])
     def where(self, cond, other=None):
+        # TODO: Investigate an alternative implementation with
+        # 1. copy the underlying Categorical
+        # 2. setitem with `cond` and `other`
+        # 3. Rebuild CategoricalIndex.
         if other is None:
             other = self._na_value
         values = np.where(cond, self.values, other)
-
         cat = Categorical(values, dtype=self.dtype)
         return self._shallow_copy(cat, **self._get_attributes_dict())

pandas/core/internals/blocks.py

+80 -2
@@ -28,7 +28,8 @@
 from pandas.core.dtypes.dtypes import (
     CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
 from pandas.core.dtypes.generic import (
-    ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
+    ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
+    ABCSeries)
 from pandas.core.dtypes.missing import (
     _isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)

@@ -1886,7 +1887,6 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
         new_values = self.values.take(indexer, fill_value=fill_value,
                                       allow_fill=True)

-        # if we are a 1-dim object, then always place at 0
         if self.ndim == 1 and new_mgr_locs is None:
             new_mgr_locs = [0]
         else:
@@ -1967,6 +1967,57 @@ def shift(self, periods, axis=0):
                                            placement=self.mgr_locs,
                                            ndim=self.ndim)]

+    def where(self, other, cond, align=True, errors='raise',
+              try_cast=False, axis=0, transpose=False):
+        # Extract the underlying arrays.
+        if isinstance(other, (ABCIndexClass, ABCSeries)):
+            other = other.array
+
+        elif isinstance(other, ABCDataFrame):
+            # ExtensionArrays are 1-D, so if we get here then
+            # `other` should be a DataFrame with a single column.
+            assert other.shape[1] == 1
+            other = other.iloc[:, 0].array
+
+        if isinstance(cond, ABCDataFrame):
+            assert cond.shape[1] == 1
+            cond = cond.iloc[:, 0].array
+
+        elif isinstance(cond, (ABCIndexClass, ABCSeries)):
+            cond = cond.array
+
+        if lib.is_scalar(other) and isna(other):
+            # The default `other` for Series / Frame is np.nan
+            # we want to replace that with the correct NA value
+            # for the type
+            other = self.dtype.na_value
+
+        if is_sparse(self.values):
+            # TODO(SparseArray.__setitem__): remove this if condition
+            # We need to re-infer the type of the data after doing the
+            # where, for cases where the subtypes don't match
+            dtype = None
+        else:
+            dtype = self.dtype
+
+        try:
+            result = self.values.copy()
+            icond = ~cond
+            if lib.is_scalar(other):
+                result[icond] = other
+            else:
+                result[icond] = other[icond]
+        except (NotImplementedError, TypeError):
+            # NotImplementedError for class not implementing `__setitem__`
+            # TypeError for SparseArray, which implements just to raise
+            # a TypeError
+            result = self._holder._from_sequence(
+                np.where(cond, self.values, other),
+                dtype=dtype,
+            )
+
+        return self.make_block_same_class(result, placement=self.mgr_locs)
+
     @property
     def _ftype(self):
         return getattr(self.values, '_pandas_ftype', Block._ftype)
@@ -2658,6 +2709,33 @@ def concat_same_type(self, to_concat, placement=None):
             values, placement=placement or slice(0, len(values), 1),
             ndim=self.ndim)

+    def where(self, other, cond, align=True, errors='raise',
+              try_cast=False, axis=0, transpose=False):
+        # TODO(CategoricalBlock.where):
+        # This can all be deleted in favor of ExtensionBlock.where once
+        # we enforce the deprecation.
+        object_msg = (
+            "Implicitly converting categorical to object-dtype ndarray. "
+            "One or more of the values in 'other' are not present in this "
+            "categorical's categories. A future version of pandas will raise "
+            "a ValueError when 'other' contains different categories.\n\n"
+            "To preserve the current behavior, add the new categories to "
+            "the categorical before calling 'where', or convert the "
+            "categorical to a different dtype."
+        )
+        try:
+            # Attempt to preserve categorical dtype.
+            result = super(CategoricalBlock, self).where(
+                other, cond, align, errors, try_cast, axis, transpose
+            )
+        except (TypeError, ValueError):
+            warnings.warn(object_msg, FutureWarning, stacklevel=6)
+            result = self.astype(object).where(other, cond, align=align,
+                                               errors=errors,
+                                               try_cast=try_cast,
+                                               axis=axis, transpose=transpose)
+        return result
+

 class DatetimeBlock(DatetimeLikeBlockMixin, Block):
     __slots__ = ()
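
Reviewer note: `CategoricalBlock.where` above first tries the dtype-preserving path and, on `TypeError` or `ValueError`, emits a `FutureWarning` and retries on object data. The sketch below mimics that control flow with plain lists. `strict_where` and `where_with_deprecation` are hypothetical helpers, and the warning text is shortened.

```python
import warnings


def strict_where(values, categories, cond, other):
    """Dtype-preserving path: reject replacements outside the categories."""
    if other not in categories:
        raise ValueError("'other' is not one of the categories")
    return [v if keep else other for v, keep in zip(values, cond)]


def where_with_deprecation(values, categories, cond, other):
    """Return (result, kind); kind is 'category' or 'object'."""
    try:
        return strict_where(values, categories, cond, other), 'category'
    except (TypeError, ValueError):
        warnings.warn(
            "Implicitly converting categorical to object; "
            "a future version will raise instead.",
            FutureWarning,
            stacklevel=2,
        )
        # Object fallback: any replacement value is accepted.
        return [v if keep else other for v, keep in zip(values, cond)], 'object'


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    kept = where_with_deprecation(['a', 'b', 'c'], ['a', 'b', 'c'],
                                  [True, False, True], 'a')
    widened = where_with_deprecation(['a', 'b', 'c'], ['a', 'b', 'c'],
                                     [True, False, True], 'd')
print(kept, widened, len(caught))
```

Only the second call, whose replacement `'d'` is not a category, takes the fallback and warns. Once the deprecation is enforced, that branch would raise instead, as the TODO in the hunk indicates.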

pandas/tests/arrays/categorical/test_indexing.py

+94
@@ -3,6 +3,7 @@
 import numpy as np
 import pytest

+import pandas as pd
 from pandas import Categorical, CategoricalIndex, Index, PeriodIndex, Series
 import pandas.core.common as com
 from pandas.tests.arrays.categorical.common import TestCategorical
@@ -43,6 +44,45 @@ def test_setitem(self):

         tm.assert_categorical_equal(c, expected)

+    @pytest.mark.parametrize('other', [
+        pd.Categorical(['b', 'a']),
+        pd.Categorical(['b', 'a'], categories=['b', 'a']),
+    ])
+    def test_setitem_same_but_unordered(self, other):
+        # GH-24142
+        target = pd.Categorical(['a', 'b'], categories=['a', 'b'])
+        mask = np.array([True, False])
+        target[mask] = other[mask]
+        expected = pd.Categorical(['b', 'b'], categories=['a', 'b'])
+        tm.assert_categorical_equal(target, expected)
+
+    @pytest.mark.parametrize('other', [
+        pd.Categorical(['b', 'a'], categories=['b', 'a', 'c']),
+        pd.Categorical(['b', 'a'], categories=['a', 'b', 'c']),
+        pd.Categorical(['a', 'a'], categories=['a']),
+        pd.Categorical(['b', 'b'], categories=['b']),
+    ])
+    def test_setitem_different_unordered_raises(self, other):
+        # GH-24142
+        target = pd.Categorical(['a', 'b'], categories=['a', 'b'])
+        mask = np.array([True, False])
+        with pytest.raises(ValueError):
+            target[mask] = other[mask]
+
+    @pytest.mark.parametrize('other', [
+        pd.Categorical(['b', 'a']),
+        pd.Categorical(['b', 'a'], categories=['b', 'a'], ordered=True),
+        pd.Categorical(['b', 'a'], categories=['a', 'b', 'c'], ordered=True),
+    ])
+    def test_setitem_same_ordered_rasies(self, other):
+        # Gh-24142
+        target = pd.Categorical(['a', 'b'], categories=['a', 'b'],
+                                ordered=True)
+        mask = np.array([True, False])
+
+        with pytest.raises(ValueError):
+            target[mask] = other[mask]
+

 class TestCategoricalIndexing(object):


@@ -122,6 +162,60 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
         tm.assert_numpy_array_equal(expected, result)
         tm.assert_numpy_array_equal(exp_miss, res_miss)

+    def test_where_unobserved_nan(self):
+        ser = pd.Series(pd.Categorical(['a', 'b']))
+        result = ser.where([True, False])
+        expected = pd.Series(pd.Categorical(['a', None],
+                                            categories=['a', 'b']))
+        tm.assert_series_equal(result, expected)
+
+        # all NA
+        ser = pd.Series(pd.Categorical(['a', 'b']))
+        result = ser.where([False, False])
+        expected = pd.Series(pd.Categorical([None, None],
+                                            categories=['a', 'b']))
+        tm.assert_series_equal(result, expected)
+
+    def test_where_unobserved_categories(self):
+        ser = pd.Series(
+            Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
+        )
+        result = ser.where([True, True, False], other='b')
+        expected = pd.Series(
+            Categorical(['a', 'b', 'b'], categories=ser.cat.categories)
+        )
+        tm.assert_series_equal(result, expected)
+
+    def test_where_other_categorical(self):
+        ser = pd.Series(
+            Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
+        )
+        other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
+        result = ser.where([True, False, True], other)
+        expected = pd.Series(Categorical(['a', 'c', 'c'], dtype=ser.dtype))
+        tm.assert_series_equal(result, expected)
+
+    def test_where_warns(self):
+        ser = pd.Series(Categorical(['a', 'b', 'c']))
+        with tm.assert_produces_warning(FutureWarning):
+            result = ser.where([True, False, True], 'd')
+
+        expected = pd.Series(np.array(['a', 'd', 'c'], dtype='object'))
+        tm.assert_series_equal(result, expected)
+
+    def test_where_ordered_differs_rasies(self):
+        ser = pd.Series(
+            Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
+                        ordered=True)
+        )
+        other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
+                            ordered=True)
+        with tm.assert_produces_warning(FutureWarning):
+            result = ser.where([True, False, True], other)
+
+        expected = pd.Series(np.array(['a', 'c', 'c'], dtype=object))
+        tm.assert_series_equal(result, expected)
+

 @pytest.mark.parametrize("index", [True, False])
 def test_mask_with_boolean(index):

pandas/tests/arrays/interval/test_interval.py

+13 -1
@@ -2,7 +2,8 @@
 import numpy as np
 import pytest

-from pandas import Index, IntervalIndex, date_range, timedelta_range
+import pandas as pd
+from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
 from pandas.core.arrays import IntervalArray
 import pandas.util.testing as tm

@@ -50,6 +51,17 @@ def test_set_closed(self, closed, new_closed):
         expected = IntervalArray.from_breaks(range(10), closed=new_closed)
         tm.assert_extension_array_equal(result, expected)

+    @pytest.mark.parametrize('other', [
+        Interval(0, 1, closed='right'),
+        IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
+    ])
+    def test_where_raises(self, other):
+        ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4],
+                                                  closed='left'))
+        match = "'value.closed' is 'right', expected 'left'."
+        with pytest.raises(ValueError, match=match):
+            ser.where([True, False, True], other=other)
+

 class TestSetitem(object):


pandas/tests/arrays/sparse/test_array.py

+2 -2
@@ -357,10 +357,10 @@ def setitem():
         def setslice():
             self.arr[1:5] = 2

-        with pytest.raises(TypeError, match="item assignment"):
+        with pytest.raises(TypeError, match="assignment via setitem"):
             setitem()

-        with pytest.raises(TypeError, match="item assignment"):
+        with pytest.raises(TypeError, match="assignment via setitem"):
             setslice()

     def test_constructor_from_too_large_array(self):

pandas/tests/arrays/test_period.py

+15
@@ -197,6 +197,21 @@ def test_sub_period():
         arr - other


+# ----------------------------------------------------------------------------
+# Methods
+
+@pytest.mark.parametrize('other', [
+    pd.Period('2000', freq='H'),
+    period_array(['2000', '2001', '2000'], freq='H')
+])
+def test_where_different_freq_raises(other):
+    ser = pd.Series(period_array(['2000', '2001', '2002'], freq='D'))
+    cond = np.array([True, False, True])
+    with pytest.raises(IncompatibleFrequency,
+                       match="Input has different freq=H"):
+        ser.where(cond, other)
+
+
 # ----------------------------------------------------------------------------
 # Printing
