Skip to content

Commit 27db005

Browse files
committed
API: Added ExtensionArray.where
We need some way to do `.where` on EA objects for DatetimeArray. Adding it to the interface is, I think, the easiest way. Initially I started to write a version on ExtensionBlock, but it proved unwieldy to write one that performed well for all types. It *may* be possible to do using `_ndarray_values`, but we'd need a few more things around that (missing values, converting an arbitrary array to the "same" `_ndarray_values`, error handling, re-constructing). It seemed easier to push this down to the array. The implementation on ExtensionArray is readable, but likely slow since it'll involve a conversion to object-dtype. Closes pandas-dev#24077
1 parent 669cb27 commit 27db005

File tree

15 files changed

+289
-7
lines changed

15 files changed

+289
-7
lines changed

doc/source/whatsnew/v0.24.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
994994
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185`).
995995
- Slicing a single row of a ``DataFrame`` with multiple ExtensionArrays of the same type now preserves the dtype, rather than coercing to object (:issue:`22784`)
996996
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
997+
- Added :meth:`pandas.api.extensions.ExtensionArray.where` (:issue:`24077`)
997998
- Bug when concatenating multiple ``Series`` with different extension dtypes not casting to object dtype (:issue:`22994`)
998999
- Series backed by an ``ExtensionArray`` now work with :func:`util.hash_pandas_object` (:issue:`23066`)
9991000
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
@@ -1236,6 +1237,7 @@ Performance Improvements
12361237
- Improved performance of :meth:`DatetimeIndex.normalize` and :meth:`Timestamp.normalize` for timezone naive or UTC datetimes (:issue:`23634`)
12371238
- Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
12381239
- Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
1240+
- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)
12391241

12401242
.. _whatsnew_0240.docs:
12411243

pandas/core/arrays/base.py

+35
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class ExtensionArray(object):
6464
* unique
6565
* factorize / _values_for_factorize
6666
* argsort / _values_for_argsort
67+
* where
6768
6869
The remaining methods implemented on this class should be performant,
6970
as they only compose abstract methods. Still, a more efficient
@@ -661,6 +662,40 @@ def take(self, indices, allow_fill=False, fill_value=None):
661662
# pandas.api.extensions.take
662663
raise AbstractMethodError(self)
663664

665+
def where(self, cond, other):
    """
    Keep values where the condition is True, replace them elsewhere.

    Parameters
    ----------
    cond : ndarray or ExtensionArray
        Mask selecting, position by position, between this array (True)
        and `other` (False).
    other : ndarray, ExtensionArray, or scalar
        Source of the replacement values used wherever `cond` is False.

    Returns
    -------
    ExtensionArray
        Same dtype as the original.

    See Also
    --------
    Series.where : Similar method for Series.
    DataFrame.where : Similar method for DataFrame.

    Notes
    -----
    `cond` and `other` *cannot* be a Series, Index, or callable.
    When used from, e.g., :meth:`Series.where`, pandas will unbox
    Series and Indexes, and will apply callables before they arrive here.

    The selection itself is delegated to :func:`numpy.where`; the result
    is then rebuilt with ``_from_sequence`` using this array's dtype.
    """
    array_type = type(self)
    selected = np.where(cond, self, other)
    return array_type._from_sequence(selected, dtype=self.dtype, copy=False)
698+
664699
def copy(self, deep=False):
665700
# type: (bool) -> ExtensionArray
666701
"""

pandas/core/arrays/categorical.py

+28
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,34 @@ def take_nd(self, indexer, allow_fill=None, fill_value=None):
19061906

19071907
take = take_nd
19081908

1909+
def where(self, cond, other):
    """
    Replace values where the condition is False, preserving the dtype.

    The selection is done on the integer codes, so the result always has
    the same categories and orderedness as this Categorical.

    Parameters
    ----------
    cond : array-like of bool
        True keeps the existing value, False takes it from `other`.
    other : scalar, Categorical, or array-like
        Replacement value(s). A missing scalar (or missing entries of an
        array-like) become missing values in the result. Every other
        value must be one of this Categorical's categories.

    Returns
    -------
    Categorical

    Raises
    ------
    ValueError
        If `other` holds a non-missing value that is not one of the
        categories.
    TypeError
        If `other` is categorical but its dtype does not match.
    """
    # n.b. this now preserves the type
    codes = self._codes

    if is_scalar(other):
        if isna(other):
            # -1 is the code used for missing values
            other = -1
        else:
            item = self.categories.get_indexer([other]).item()
            if item == -1:
                raise ValueError("The value '{}' is not present in "
                                 "this Categorical's categories".format(other))
            other = item
    elif is_categorical_dtype(other):
        if not is_dtype_equal(self, other):
            raise TypeError("The type of 'other' does not match.")
        # get the codes from other that match our categories
        other = _get_codes_for_values(other, self.categories)
    else:
        # Array-like of *values*: translate to codes before mixing them
        # with our own codes. Using the raw values here would silently
        # reinterpret them as codes.
        other = np.asarray(other)
        missing = isna(other)
        other_codes = self.categories.get_indexer(other)
        if ((other_codes == -1) & ~missing).any():
            raise ValueError("'other' contains values that are not present "
                             "in this Categorical's categories")
        other = other_codes

    new_codes = np.where(cond, codes, other)
    return type(self).from_codes(new_codes,
                                 categories=self.categories,
                                 ordered=self.ordered)
1936+
19091937
def _slice(self, slicer):
19101938
"""
19111939
Return a slice of myself.

pandas/core/arrays/interval.py

+12
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,18 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None,
777777

778778
return self._shallow_copy(left_take, right_take)
779779

780+
def where(self, cond, other):
    """
    Replace intervals where the condition is False.

    The left and right bounds are selected independently with
    ``numpy.where`` and recombined into a new array.

    Parameters
    ----------
    cond : array-like of bool
        True keeps the existing interval, False takes it from `other`.
    other : missing scalar, or object with ``left`` / ``right``
        A missing scalar is used for both bounds. Anything else is first
        passed to ``_check_closed_matches`` and then supplies its
        ``left`` and ``right`` attributes as the replacement bounds.

    Returns
    -------
    Result of ``self._shallow_copy`` on the selected bounds.
    """
    if not (is_scalar(other) and isna(other)):
        self._check_closed_matches(other, name='other')
        fill_left, fill_right = other.left, other.right
    else:
        fill_left = fill_right = other

    new_left = np.where(cond, self.left, fill_left)
    new_right = np.where(cond, self.right, fill_right)
    return self._shallow_copy(new_left, new_right)
791+
780792
def value_counts(self, dropna=True):
781793
"""
782794
Returns a Series containing counts of each interval.

pandas/core/arrays/period.py

+22
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66

7+
from pandas._libs import lib
78
from pandas._libs.tslibs import NaT, iNaT, period as libperiod
89
from pandas._libs.tslibs.fields import isleapyear_arr
910
from pandas._libs.tslibs.period import (
@@ -241,6 +242,11 @@ def _generate_range(cls, start, end, periods, freq, fields):
241242

242243
return subarr, freq
243244

245+
def _check_compatible_with(self, other):
246+
if self.freqstr != other.freqstr:
247+
msg = DIFFERENT_FREQ_INDEX.format(self.freqstr, other.freqstr)
248+
raise IncompatibleFrequency(msg)
249+
244250
# --------------------------------------------------------------------
245251
# Data / Attributes
246252

@@ -341,6 +347,22 @@ def to_timestamp(self, freq=None, how='start'):
341347
# --------------------------------------------------------------------
342348
# Array-like / EA-Interface Methods
343349

350+
def where(self, cond, other):
    """
    Replace periods where the condition is False.

    The selection is done on the int64 ordinals (``asi8``).

    Parameters
    ----------
    cond : array-like of bool
        True keeps the existing value, False takes it from `other`.
    other : Period, missing scalar, or array of the same type
        A missing scalar becomes ``iNaT``. A Period, or an array of this
        type, is checked with ``_check_compatible_with`` and then
        contributes its ordinal(s). Any other non-scalar is handed to
        ``numpy.where`` unchanged.

    Returns
    -------
    Array of the same type and dtype as this one.

    Raises
    ------
    TypeError
        If `other` is a scalar that is neither missing nor a Period.
        Such a value used to be silently interpreted as a raw ordinal.
    IncompatibleFrequency
        Propagated from ``_check_compatible_with``.
    """
    # TODO(DatetimeArray): move to DatetimeLikeArrayMixin
    # n.b. _ndarray_values candidate.
    i8 = self.asi8
    if lib.is_scalar(other):
        if isna(other):
            other = iNaT
        elif isinstance(other, Period):
            self._check_compatible_with(other)
            other = other.ordinal
        else:
            raise TypeError(
                "'other' must be a Period, a missing value or a {}; "
                "got {}".format(type(self).__name__, type(other).__name__))
    elif isinstance(other, type(self)):
        self._check_compatible_with(other)
        other = other.asi8
    result = np.where(cond, i8, other)
    return type(self)._simple_new(result, dtype=self.dtype)
365+
344366
def _formatter(self, boxed=False):
345367
if boxed:
346368
return str

pandas/core/arrays/sparse.py

+14
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,20 @@ def take(self, indices, allow_fill=False, fill_value=None):
10631063
return type(self)(result, fill_value=self.fill_value, kind=self.kind,
10641064
**kwargs)
10651065

1066+
def where(self, cond, other):
    """
    Replace values where the condition is False.

    The result's subtype is ``numpy.result_type`` of this array's subtype
    and `other` (the value itself for a scalar, its subtype for an array
    of this type, and its ``dtype`` otherwise).

    Parameters
    ----------
    cond : array-like of bool
        True keeps the existing value, False takes it from `other`.
    other : scalar, array of this type, or object with a ``dtype``

    Returns
    -------
    Array of the same type, built from the ``numpy.where`` result.
    """
    own_subtype = self.dtype.subtype
    if is_scalar(other):
        other_kind = other
    elif isinstance(other, type(self)):
        other_kind = other.dtype.subtype
    else:
        other_kind = other.dtype
    result_dtype = np.result_type(own_subtype, other_kind)

    new_dtype = self.dtype.update_dtype(result_dtype)
    # TODO: avoid converting to dense.
    dense = np.where(cond, self, other)
    return type(self)(dense, dtype=new_dtype)
1079+
10661080
def _take_with_fill(self, indices, fill_value=None):
10671081
if fill_value is None:
10681082
fill_value = self.dtype.na_value

pandas/core/dtypes/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ class _DtypeOpsMixin(object):
2626
na_value = np.nan
2727
_metadata = ()
2828

29+
@property
30+
def _ndarray_na_value(self):
31+
"""
Private property, internal to pandas.

This base implementation always raises ``AbstractMethodError``.

NOTE(review): presumably the missing-value sentinel used in an array's
``_ndarray_values`` representation -- confirm against implementers.
"""
32+
raise AbstractMethodError(self)
33+
2934
def __eq__(self, other):
3035
"""Check whether 'other' is equal to self.
3136

pandas/core/indexes/category.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -501,11 +501,7 @@ def _can_reindex(self, indexer):
501501

502502
@Appender(_index_shared_docs['where'])
def where(self, cond, other=None):
    # Delegate the selection to the underlying array's own ``where``, then
    # rewrap the result with this index's attributes.
    replaced = self.values.where(cond, other=other)
    attributes = self._get_attributes_dict()
    return self._shallow_copy(replaced, **attributes)
510506

511507
def reindex(self, target, method=None, level=None, limit=None,

pandas/core/internals/blocks.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from pandas.core.dtypes.dtypes import (
2929
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
3030
from pandas.core.dtypes.generic import (
31-
ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
31+
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
32+
ABCSeries)
3233
from pandas.core.dtypes.missing import (
3334
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)
3435

@@ -1967,6 +1968,30 @@ def shift(self, periods, axis=0):
19671968
placement=self.mgr_locs,
19681969
ndim=self.ndim)]
19691970

1971+
def where(self, other, cond, align=True, errors='raise',
          try_cast=False, axis=0, transpose=False):
    """
    Apply ``where`` by delegating to the block's underlying array.

    Parameters
    ----------
    other : scalar, array, Series, Index, or single-column DataFrame
        Replacement values. A missing scalar is swapped for this
        block's ``dtype.na_value``.
    cond : array, Series, Index, or single-column DataFrame
        Mask passed on to the array's ``where``.
    align, errors, try_cast, axis, transpose
        Not used by this implementation.

    Returns
    -------
    Block of the same class at the same placement.
    """
    def _unbox(obj):
        # Reduce pandas containers to the array they wrap; anything else
        # is passed through untouched.
        if isinstance(obj, ABCDataFrame):
            assert obj.shape[1] == 1
            return obj.iloc[:, 0].array
        if isinstance(obj, (ABCIndexClass, ABCSeries)):
            return obj.array
        return obj

    cond = _unbox(cond)
    other = _unbox(other)

    if lib.is_scalar(other) and isna(other):
        # The default `other` for Series / Frame is np.nan
        # we want to replace that with the correct NA value
        # for the type
        other = self.dtype.na_value

    new_values = self.values.where(cond, other)
    return self.make_block_same_class(new_values, placement=self.mgr_locs)
1994+
19701995
@property
19711996
def _ftype(self):
19721997
return getattr(self.values, '_pandas_ftype', Block._ftype)

pandas/tests/arrays/categorical/test_indexing.py

+26
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,32 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
122122
tm.assert_numpy_array_equal(expected, result)
123123
tm.assert_numpy_array_equal(exp_miss, res_miss)
124124

125+
def test_where_raises(self):
    # A scalar replacement that is not one of the categories is rejected.
    cat = Categorical(['a', 'b', 'c'])
    mask = [True, False, True]
    with pytest.raises(ValueError, match="The value 'd'"):
        cat.where(mask, 'd')
129+
130+
def test_where_unobserved_categories(self):
    # 'd' is a category that never occurs in the data; the result must
    # keep the full category list.
    cat = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
    mask = [True, True, False]
    result = cat.where(mask, other='b')
    expected = Categorical(['a', 'b', 'b'], categories=cat.categories)
    tm.assert_categorical_equal(result, expected)
135+
136+
def test_where_other_categorical(self):
    # `other` has the same categories in a different order; the result
    # is expected to carry the dtype of the calling array.
    cat = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
    replacement = Categorical(['b', 'c', 'a'],
                              categories=['a', 'c', 'b', 'd'])
    mask = [True, False, True]
    result = cat.where(mask, replacement)
    expected = Categorical(['a', 'c', 'c'], dtype=cat.dtype)
    tm.assert_categorical_equal(result, expected)
142+
143+
def test_where_ordered_differs_rasies(self):
    # Both arrays are ordered but list the categories in a different
    # order, which is expected to be rejected with a TypeError.
    cat = Categorical(['a', 'b', 'c'],
                      categories=['d', 'c', 'b', 'a'],
                      ordered=True)
    replacement = Categorical(['b', 'c', 'a'],
                              categories=['a', 'c', 'b', 'd'],
                              ordered=True)
    mask = [True, False, True]
    with pytest.raises(TypeError, match="The type of"):
        cat.where(mask, replacement)
150+
125151

126152
@pytest.mark.parametrize("index", [True, False])
127153
def test_mask_with_boolean(index):

pandas/tests/arrays/interval/test_interval.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import pytest
44

5-
from pandas import Index, IntervalIndex, date_range, timedelta_range
5+
from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
66
from pandas.core.arrays import IntervalArray
77
import pandas.util.testing as tm
88

@@ -50,6 +50,16 @@ def test_set_closed(self, closed, new_closed):
5050
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
5151
tm.assert_extension_array_equal(result, expected)
5252

53+
@pytest.mark.parametrize('other', [
    Interval(0, 1, closed='right'),
    IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
])
def test_where_raises(self, other):
    # Replacement values closed on a different side must be rejected,
    # for both a scalar Interval and an IntervalArray.
    left_closed = IntervalArray.from_breaks([1, 2, 3, 4], closed='left')
    expected_msg = "'other.closed' is 'right', expected 'left'."
    mask = [True, False, True]
    with pytest.raises(ValueError, match=expected_msg):
        left_closed.where(mask, other=other)
62+
5363

5464
class TestSetitem(object):
5565

pandas/tests/arrays/test_period.py

+15
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,21 @@ def test_sub_period():
197197
arr - other
198198

199199

200+
# ----------------------------------------------------------------------------
201+
# Methods
202+
203+
@pytest.mark.parametrize('other', [
    pd.Period('2000', freq='H'),
    period_array(['2000', '2001', '2000'], freq='H')
])
def test_where_different_freq_raises(other):
    # Hourly replacements (scalar or array) cannot be mixed into a
    # daily array.
    daily = period_array(['2000', '2001', '2002'], freq='D')
    mask = np.array([True, False, True])
    with pytest.raises(IncompatibleFrequency,
                       match="Input has different freq=H"):
        daily.where(mask, other)
213+
214+
200215
# ----------------------------------------------------------------------------
201216
# Printing
202217

pandas/tests/extension/base/methods.py

+63
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,66 @@ def test_hash_pandas_object_works(self, data, as_frame):
198198
a = pd.util.hash_pandas_object(data)
199199
b = pd.util.hash_pandas_object(data)
200200
self.assert_equal(a, b)
201+
202+
@pytest.mark.parametrize("as_frame", [True, False])
def test_where_series(self, data, na_value, as_frame):
    # Exercise ``where`` on a Series (or a one-column frame when
    # `as_frame` is set), first with the default `other` and then with
    # an array `other`.
    assert data[0] != data[1]
    cls = type(data)
    a, b = data[:2]

    def boxed(values):
        # Build the Series, or the one-column frame, holding `values`.
        obj = pd.Series(cls._from_sequence(values, dtype=data.dtype))
        return obj.to_frame(name='a') if as_frame else obj

    ser = boxed([a, a, b, b])
    cond = np.array([True, True, False, False])
    if as_frame:
        # TODO: alignment is broken for ndarray `cond`
        cond = pd.DataFrame({"a": cond})

    result = ser.where(cond)
    expected = boxed([a, a, na_value, na_value])
    self.assert_equal(result, expected)

    # array other
    cond = np.array([True, False, True, True])
    other = cls._from_sequence([a, b, a, b], dtype=data.dtype)
    if as_frame:
        # TODO: alignment is broken for array `other`
        other = pd.DataFrame({"a": other})
        # TODO: alignment is broken for ndarray `cond`
        cond = pd.DataFrame({"a": cond})

    result = ser.where(cond, other)
    expected = boxed([a, b, b, b])
    self.assert_equal(result, expected)
238+
#
239+
# def test_where_frame(self, data, na_value):
240+
# assert data[0] != data[1]
241+
# cls = type(data)
242+
# a, b = data[:2]
243+
#
244+
# df = pd.DataFrame({
245+
# "A": cls._from_sequence([a, a, b, b], dtype=data.dtype)
246+
# })
247+
#
248+
# cond = np.array([True, True, False, False]).reshape(-1, 1)
249+
# result = df.where(cond)
250+
# expected = pd.DataFrame({
251+
# "A": cls._from_sequence([a, a, na_value, na_value],
252+
# dtype=data.dtype)
253+
# })
254+
# self.assert_frame_equal(result, expected)
255+
#
256+
# other = cls._from_sequence([a, b, a, b])
257+
# cond = np.array([True, False, True, True])
258+
# result = ser.where(cond, other)
259+
# expected = pd.Series(cls._from_sequence([a, b, b, b]))
260+
# self.assert_series_equal(result, expected)
261+
#
262+
# # df = ser.to_frame(name='A')
263+
# # result = terr

0 commit comments

Comments
 (0)