Skip to content

Commit 7ec7351

Browse files
committed
Squashed commit of the following:
commit 56470c3 Author: Tom Augspurger <[email protected]> Date: Wed Dec 5 11:39:48 2018 -0600 Fixups: * Ensure data generated OK. * Remove erroneous comments about alignment. That was user error. commit c4604df Author: Tom Augspurger <[email protected]> Date: Mon Dec 3 14:23:25 2018 -0600 API: Added ExtensionArray.where We need some way to do `.where` on EA object for DatetimeArray. Adding it to the interface is, I think, the easiest way. Initially I started to write a version on ExtensionBlock, but it proved to be unwieldy. to write a version that performed well for all types. It *may* be possible to do using `_ndarray_values` but we'd need a few more things around that (missing values, converting an arbitrary array to the "same' ndarary_values, error handling, re-constructing). It seemed easier to push this down to the array. The implementation on ExtensionArray is readable, but likely slow since it'll involve a conversion to object-dtype. Closes pandas-dev#24077
1 parent 165f3fd commit 7ec7351

File tree

17 files changed

+302
-21
lines changed

17 files changed

+302
-21
lines changed

doc/source/whatsnew/v0.24.0.rst

+3
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
994994
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
995995
- Slicing a single row of a ``DataFrame`` with multiple ExtensionArrays of the same type now preserves the dtype, rather than coercing to object (:issue:`22784`)
996996
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
997+
- Added :meth:`pandas.api.extensions.ExtensionArray.where` (:issue:`24077`)
997998
- Bug when concatenating multiple ``Series`` with different extension dtypes not casting to object dtype (:issue:`22994`)
998999
- Series backed by an ``ExtensionArray`` now work with :func:`util.hash_pandas_object` (:issue:`23066`)
9991000
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
@@ -1236,6 +1237,7 @@ Performance Improvements
12361237
- Improved performance of :meth:`DatetimeIndex.normalize` and :meth:`Timestamp.normalize` for timezone naive or UTC datetimes (:issue:`23634`)
12371238
- Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
12381239
- Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
1240+
- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)
12391241

12401242
.. _whatsnew_0240.docs:
12411243

@@ -1262,6 +1264,7 @@ Categorical
12621264
- In meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
12631265
- Bug when resampling :meth:`Dataframe.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
12641266
- Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
1267+
- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)
12651268

12661269
Datetimelike
12671270
^^^^^^^^^^^^

pandas/core/arrays/base.py

+35
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class ExtensionArray(object):
6464
* unique
6565
* factorize / _values_for_factorize
6666
* argsort / _values_for_argsort
67+
* where
6768
6869
The remaining methods implemented on this class should be performant,
6970
as they only compose abstract methods. Still, a more efficient
@@ -661,6 +662,40 @@ def take(self, indices, allow_fill=False, fill_value=None):
661662
# pandas.api.extensions.take
662663
raise AbstractMethodError(self)
663664

665+
def where(self, cond, other):
666+
"""
667+
Replace values where the condition is False.
668+
669+
Parameters
670+
----------
671+
cond : ndarray or ExtensionArray
672+
The mask indicating which values should be kept (True)
673+
or replaced from `other` (False).
674+
675+
other : ndarray, ExtensionArray, or scalar
676+
Entries where `cond` is False are replaced with
677+
corresponding value from `other`.
678+
679+
Notes
680+
-----
681+
Note that `cond` and `other` *cannot* be a Series, Index, or callable.
682+
When used from, e.g., :meth:`Series.where`, pandas will unbox
683+
Series and Indexes, and will apply callables before they arrive here.
684+
685+
Returns
686+
-------
687+
ExtensionArray
688+
Same dtype as the original.
689+
690+
See Also
691+
--------
692+
Series.where : Similar method for Series.
693+
DataFrame.where : Similar method for DataFrame.
694+
"""
695+
return type(self)._from_sequence(np.where(cond, self, other),
696+
dtype=self.dtype,
697+
copy=False)
698+
664699
def copy(self, deep=False):
665700
# type: (bool) -> ExtensionArray
666701
"""

pandas/core/arrays/categorical.py

+43
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=E1101,W0232
22

3+
import reprlib
34
import textwrap
45
from warnings import warn
56

@@ -1906,6 +1907,48 @@ def take_nd(self, indexer, allow_fill=None, fill_value=None):
19061907

19071908
take = take_nd
19081909

1910+
def where(self, cond, other):
1911+
# n.b. this now preserves the type
1912+
codes = self._codes
1913+
object_msg = (
1914+
"Implicitly converting categorical to object-dtype ndarray. "
1915+
"The values `{}' are not present in this categorical's "
1916+
"categories. A future version of pandas will raise a ValueError "
1917+
"when 'other' contains different categories.\n\n"
1918+
"To preserve the current behavior, add the new categories to "
1919+
"the categorical before calling 'where', or convert the "
1920+
"categorical to a different dtype."
1921+
)
1922+
1923+
if is_scalar(other) and isna(other):
1924+
other = -1
1925+
elif is_scalar(other):
1926+
item = self.categories.get_indexer([other]).item()
1927+
1928+
if item == -1:
1929+
# note: when removing this, also remove CategoricalBlock.where
1930+
warn(object_msg.format(other), FutureWarning, stacklevel=2)
1931+
return np.where(cond, self, other)
1932+
1933+
other = item
1934+
1935+
elif is_categorical_dtype(other):
1936+
if not is_dtype_equal(self, other):
1937+
extra = list(other.categories.difference(self.categories))
1938+
warn(object_msg.format(reprlib.repr(extra)), FutureWarning,
1939+
stacklevel=2)
1940+
return np.where(cond, self, other)
1941+
other = _get_codes_for_values(other, self.categories)
1942+
# get the codes from other that match our categories
1943+
pass
1944+
else:
1945+
other = np.where(isna(other), -1, other)
1946+
1947+
new_codes = np.where(cond, codes, other)
1948+
return type(self).from_codes(new_codes,
1949+
categories=self.categories,
1950+
ordered=self.ordered)
1951+
19091952
def _slice(self, slicer):
19101953
"""
19111954
Return a slice of myself.

pandas/core/arrays/interval.py

+11
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,17 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None,
777777

778778
return self._shallow_copy(left_take, right_take)
779779

780+
def where(self, cond, other):
781+
if is_scalar(other) and isna(other):
782+
lother = rother = other
783+
else:
784+
self._check_closed_matches(other, name='other')
785+
lother = other.left
786+
rother = other.right
787+
left = np.where(cond, self.left, lother)
788+
right = np.where(cond, self.right, rother)
789+
return self._shallow_copy(left, right)
790+
780791
def value_counts(self, dropna=True):
781792
"""
782793
Returns a Series containing counts of each interval.

pandas/core/arrays/period.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66

7+
from pandas._libs import lib
78
from pandas._libs.tslibs import NaT, iNaT, period as libperiod
89
from pandas._libs.tslibs.fields import isleapyear_arr
910
from pandas._libs.tslibs.period import (
@@ -242,16 +243,6 @@ def _generate_range(cls, start, end, periods, freq, fields):
242243

243244
return subarr, freq
244245

245-
# -----------------------------------------------------------------
246-
# DatetimeLike Interface
247-
def _unbox_scalar(self, value):
248-
assert isinstance(value, self._scalar_type), value
249-
return value.ordinal
250-
251-
def _scalar_from_string(self, value):
252-
assert isinstance(value, self._scalar_type), value
253-
return Period(value, freq=self.freq)
254-
255246
def _check_compatible_with(self, other):
256247
if self.freqstr != other.freqstr:
257248
msg = DIFFERENT_FREQ_INDEX.format(self.freqstr, other.freqstr)
@@ -357,6 +348,22 @@ def to_timestamp(self, freq=None, how='start'):
357348
# --------------------------------------------------------------------
358349
# Array-like / EA-Interface Methods
359350

351+
def where(self, cond, other):
352+
# TODO(DatetimeArray): move to DatetimeLikeArrayMixin
353+
# n.b. _ndarray_values candidate.
354+
i8 = self.asi8
355+
if lib.is_scalar(other):
356+
if isna(other):
357+
other = iNaT
358+
elif isinstance(other, Period):
359+
self._check_compatible_with(other)
360+
other = other.ordinal
361+
elif isinstance(other, type(self)):
362+
self._check_compatible_with(other)
363+
other = other.asi8
364+
result = np.where(cond, i8, other)
365+
return type(self)._simple_new(result, dtype=self.dtype)
366+
360367
def _formatter(self, boxed=False):
361368
if boxed:
362369
return str

pandas/core/arrays/sparse.py

+14
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,20 @@ def take(self, indices, allow_fill=False, fill_value=None):
10631063
return type(self)(result, fill_value=self.fill_value, kind=self.kind,
10641064
**kwargs)
10651065

1066+
def where(self, cond, other):
1067+
if is_scalar(other):
1068+
result_dtype = np.result_type(self.dtype.subtype, other)
1069+
elif isinstance(other, type(self)):
1070+
result_dtype = np.result_type(self.dtype.subtype,
1071+
other.dtype.subtype)
1072+
else:
1073+
result_dtype = np.result_type(self.dtype.subtype, other.dtype)
1074+
1075+
dtype = self.dtype.update_dtype(result_dtype)
1076+
# TODO: avoid converting to dense.
1077+
values = np.where(cond, self, other)
1078+
return type(self)(values, dtype=dtype)
1079+
10661080
def _take_with_fill(self, indices, fill_value=None):
10671081
if fill_value is None:
10681082
fill_value = self.dtype.na_value

pandas/core/dtypes/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ class _DtypeOpsMixin(object):
2626
na_value = np.nan
2727
_metadata = ()
2828

29+
@property
30+
def _ndarray_na_value(self):
31+
"""Private method internal to pandas"""
32+
raise AbstractMethodError(self)
33+
2934
def __eq__(self, other):
3035
"""Check whether 'other' is equal to self.
3136

pandas/core/indexes/category.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -501,11 +501,7 @@ def _can_reindex(self, indexer):
501501

502502
@Appender(_index_shared_docs['where'])
503503
def where(self, cond, other=None):
504-
if other is None:
505-
other = self._na_value
506-
values = np.where(cond, self.values, other)
507-
508-
cat = Categorical(values, dtype=self.dtype)
504+
cat = self.values.where(cond, other=other)
509505
return self._shallow_copy(cat, **self._get_attributes_dict())
510506

511507
def reindex(self, target, method=None, level=None, limit=None,

pandas/core/internals/blocks.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
from pandas.core.dtypes.dtypes import (
3030
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
3131
from pandas.core.dtypes.generic import (
32-
ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
32+
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
33+
ABCSeries)
3334
from pandas.core.dtypes.inference import is_scalar
3435
from pandas.core.dtypes.missing import (
3536
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)
@@ -1970,6 +1971,30 @@ def shift(self, periods, axis=0):
19701971
placement=self.mgr_locs,
19711972
ndim=self.ndim)]
19721973

1974+
def where(self, other, cond, align=True, errors='raise',
1975+
try_cast=False, axis=0, transpose=False):
1976+
if isinstance(other, (ABCIndexClass, ABCSeries)):
1977+
other = other.array
1978+
1979+
if isinstance(cond, ABCDataFrame):
1980+
assert cond.shape[1] == 1
1981+
cond = cond.iloc[:, 0].array
1982+
1983+
if isinstance(other, ABCDataFrame):
1984+
assert other.shape[1] == 1
1985+
other = other.iloc[:, 0].array
1986+
1987+
if isinstance(cond, (ABCIndexClass, ABCSeries)):
1988+
cond = cond.array
1989+
1990+
if lib.is_scalar(other) and isna(other):
1991+
# The default `other` for Series / Frame is np.nan
1992+
# we want to replace that with the correct NA value
1993+
# for the type
1994+
other = self.dtype.na_value
1995+
result = self.values.where(cond, other)
1996+
return self.make_block_same_class(result, placement=self.mgr_locs)
1997+
19731998
@property
19741999
def _ftype(self):
19752000
return getattr(self.values, '_pandas_ftype', Block._ftype)
@@ -2675,6 +2700,17 @@ def concat_same_type(self, to_concat, placement=None):
26752700
values, placement=placement or slice(0, len(values), 1),
26762701
ndim=self.ndim)
26772702

2703+
def where(self, other, cond, align=True, errors='raise',
2704+
try_cast=False, axis=0, transpose=False):
2705+
result = super(CategoricalBlock, self).where(
2706+
other, cond, align, errors, try_cast, axis, transpose
2707+
)
2708+
if result.values.dtype != self.values.dtype:
2709+
# For backwards compatability, we allow upcasting to object.
2710+
# This fallback will be removed in the future.
2711+
result = result.astype(object)
2712+
return result
2713+
26782714

26792715
class DatetimeBlock(DatetimeLikeBlockMixin, Block):
26802716
__slots__ = ()

pandas/tests/arrays/categorical/test_indexing.py

+32
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,38 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
122122
tm.assert_numpy_array_equal(expected, result)
123123
tm.assert_numpy_array_equal(exp_miss, res_miss)
124124

125+
def test_where_unobserved_categories(self):
126+
arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
127+
result = arr.where([True, True, False], other='b')
128+
expected = Categorical(['a', 'b', 'b'], categories=arr.categories)
129+
tm.assert_categorical_equal(result, expected)
130+
131+
def test_where_other_categorical(self):
132+
arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
133+
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
134+
result = arr.where([True, False, True], other)
135+
expected = Categorical(['a', 'c', 'c'], dtype=arr.dtype)
136+
tm.assert_categorical_equal(result, expected)
137+
138+
def test_where_warns(self):
139+
arr = Categorical(['a', 'b', 'c'])
140+
with tm.assert_produces_warning(FutureWarning):
141+
result = arr.where([True, False, True], 'd')
142+
143+
expected = np.array(['a', 'd', 'c'], dtype='object')
144+
tm.assert_numpy_array_equal(result, expected)
145+
146+
def test_where_ordered_differs_rasies(self):
147+
arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
148+
ordered=True)
149+
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
150+
ordered=True)
151+
with tm.assert_produces_warning(FutureWarning):
152+
result = arr.where([True, False, True], other)
153+
154+
expected = np.array(['a', 'c', 'c'], dtype=object)
155+
tm.assert_numpy_array_equal(result, expected)
156+
125157

126158
@pytest.mark.parametrize("index", [True, False])
127159
def test_mask_with_boolean(index):

pandas/tests/arrays/interval/test_interval.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import pytest
44

5-
from pandas import Index, IntervalIndex, date_range, timedelta_range
5+
from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
66
from pandas.core.arrays import IntervalArray
77
import pandas.util.testing as tm
88

@@ -50,6 +50,16 @@ def test_set_closed(self, closed, new_closed):
5050
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
5151
tm.assert_extension_array_equal(result, expected)
5252

53+
@pytest.mark.parametrize('other', [
54+
Interval(0, 1, closed='right'),
55+
IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
56+
])
57+
def test_where_raises(self, other):
58+
arr = IntervalArray.from_breaks([1, 2, 3, 4], closed='left')
59+
match = "'other.closed' is 'right', expected 'left'."
60+
with pytest.raises(ValueError, match=match):
61+
arr.where([True, False, True], other=other)
62+
5363

5464
class TestSetitem(object):
5565

pandas/tests/arrays/test_period.py

+15
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,21 @@ def test_sub_period():
199199
arr - other
200200

201201

202+
# ----------------------------------------------------------------------------
203+
# Methods
204+
205+
@pytest.mark.parametrize('other', [
206+
pd.Period('2000', freq='H'),
207+
period_array(['2000', '2001', '2000'], freq='H')
208+
])
209+
def test_where_different_freq_raises(other):
210+
arr = period_array(['2000', '2001', '2002'], freq='D')
211+
cond = np.array([True, False, True])
212+
with pytest.raises(IncompatibleFrequency,
213+
match="Input has different freq=H"):
214+
arr.where(cond, other)
215+
216+
202217
# ----------------------------------------------------------------------------
203218
# Printing
204219

0 commit comments

Comments
 (0)