Skip to content

BUG/Perf: Support ExtensionArrays in where #24114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Dec 10, 2018
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c4604df
API: Added ExtensionArray.where
TomAugspurger Dec 3, 2018
56470c3
Fixups:
TomAugspurger Dec 5, 2018
6f79282
32-bit compat
TomAugspurger Dec 5, 2018
a69dbb3
warn for categorical
TomAugspurger Dec 5, 2018
911a2da
debug 32-bit issue
TomAugspurger Dec 5, 2018
badb5be
compat, revert
TomAugspurger Dec 6, 2018
edff47e
32-bit compat
TomAugspurger Dec 6, 2018
4715ef6
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 6, 2018
d90f384
deprecation note for categorical
TomAugspurger Dec 6, 2018
5e14414
where versionadded
TomAugspurger Dec 6, 2018
e9665b8
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
033ac9c
Setitem-based where
TomAugspurger Dec 7, 2018
1271d3d
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
9e0d87d
update docs, cleanup
TomAugspurger Dec 7, 2018
e05a597
wip
TomAugspurger Dec 7, 2018
796332c
cleanup
TomAugspurger Dec 7, 2018
cad0c4c
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
6edd286
py2 compat
TomAugspurger Dec 7, 2018
30775f0
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
4de8bb5
Updated
TomAugspurger Dec 7, 2018
ce04a75
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 9, 2018
f98a82c
Clarify
TomAugspurger Dec 9, 2018
bcfb8f8
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 10, 2018
8d9b20b
Simplify error message
TomAugspurger Dec 10, 2018
c0351fd
sparse whatsnew
TomAugspurger Dec 10, 2018
539d3cb
updates
TomAugspurger Dec 10, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,8 @@ Deprecations
- :func:`pandas.types.is_datetimetz` is deprecated in favor of `pandas.types.is_datetime64tz` (:issue:`23917`)
- Creating a :class:`TimedeltaIndex` or :class:`DatetimeIndex` by passing range arguments `start`, `end`, and `periods` is deprecated in favor of :func:`timedelta_range` and :func:`date_range` (:issue:`23919`)
- Passing a string alias like ``'datetime64[ns, UTC]'`` as the `unit` parameter to :class:`DatetimeTZDtype` is deprecated. Use :class:`DatetimeTZDtype.construct_from_string` instead (:issue:`23990`).
- In :meth:`Series.where` with Categorical data, providing an ``other`` that is not present in the categories is deprecated. Convert the categorical to a different dtype or add the ``other`` to the categories first (:issue:`24077`).


.. _whatsnew_0240.deprecations.datetimelike_int_ops:

Expand Down Expand Up @@ -1244,6 +1246,7 @@ Performance Improvements
- Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
- Fixed a performance regression on Windows with Python 3.7 of :func:`pd.read_csv` (:issue:`23516`)
- Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)

.. _whatsnew_0240.docs:

Expand All @@ -1270,6 +1273,7 @@ Categorical
- In meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
- Bug when resampling :meth:`Dataframe.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
- Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)

Datetimelike
^^^^^^^^^^^^
Expand Down Expand Up @@ -1305,6 +1309,7 @@ Datetimelike
- Bug in :class:`DatetimeIndex` where calling ``np.array(dtindex, dtype=object)`` would incorrectly return an array of ``long`` objects (:issue:`23524`)
- Bug in :class:`Index` where passing a timezone-aware :class:`DatetimeIndex` and `dtype=object` would incorrectly raise a ``ValueError`` (:issue:`23524`)
- Bug in :class:`Index` where calling ``np.array(dtindex, dtype=object)`` on a timezone-naive :class:`DatetimeIndex` would return an array of ``datetime`` objects instead of :class:`Timestamp` objects, potentially losing nanosecond portions of the timestamps (:issue:`23524`)
- Bug in :class:`Categorical.__setitem__` not allowing setting with another ``Categorical`` when both are undordered and have the same categories, but in a different order (:issue:`24142`)

Timedelta
^^^^^^^^^
Expand Down
3 changes: 3 additions & 0 deletions pandas/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_range_parameters(data):
reduce = functools.reduce
long = int
unichr = chr
import reprlib

# This was introduced in Python 3.3, but we don't support
# Python 3.x < 3.5, so checking PY3 is safe.
Expand Down Expand Up @@ -271,6 +272,7 @@ class to receive bound method
class_types = type,
text_type = str
binary_type = bytes
import reprlib

def u(s):
return s
Expand Down Expand Up @@ -323,6 +325,7 @@ def set_function_name(f, name, cls):
class_types = (type, types.ClassType)
text_type = unicode
binary_type = str
import repr as reprlib

def u(s):
return unicode(s, "unicode_escape")
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def __setitem__(self, key, value):
# example, a string like '2018-01-01' is coerced to a datetime
# when setting on a datetime64ns array. In general, if the
# __init__ method coerces that value, then so should __setitem__
# Note, also, that Series/DataFrame.where internally use __setitem__
# on a copy of the data.
raise NotImplementedError(_not_implemented_message.format(
type(self), '__setitem__')
)
Expand Down
12 changes: 11 additions & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,11 +2078,21 @@ def __setitem__(self, key, value):
`Categorical` does not have the same categories
"""

if isinstance(value, (ABCIndexClass, ABCSeries)):
value = value.array

# require identical categories set
if isinstance(value, Categorical):
if not value.categories.equals(self.categories):
if not is_dtype_equal(self, value):
raise ValueError("Cannot set a Categorical with another, "
"without identical categories")
if not self.categories.equals(value.categories):
new_codes = _recode_for_categories(
value.codes, value.categories, self.categories
)
value = Categorical.from_codes(new_codes,
categories=self.categories,
ordered=self.ordered)

rvalue = value if is_list_like(value) else [value]

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,8 @@ def __array__(self, dtype=None, copy=True):

def __setitem__(self, key, value):
# I suppose we could allow setting of non-fill_value elements.
# TODO(SparseArray.__setitem__): remove special cases in
# ExtensionBlock.where
msg = "SparseArray does not support item assignment via setitem"
raise TypeError(msg)

Expand Down
5 changes: 4 additions & 1 deletion pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,13 @@ def _can_reindex(self, indexer):

@Appender(_index_shared_docs['where'])
def where(self, cond, other=None):
# TODO: Investigate an alternative implementation with
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# 1. copy the underyling Categorical
# 2. setitem with `cond` and `other`
# 3. Rebuild CategoricalIndex.
if other is None:
other = self._na_value
values = np.where(cond, self.values, other)

cat = Categorical(values, dtype=self.dtype)
return self._shallow_copy(cat, **self._get_attributes_dict())

Expand Down
87 changes: 85 additions & 2 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
from pandas.core.dtypes.dtypes import (
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
from pandas.core.dtypes.generic import (
ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
ABCSeries)
from pandas.core.dtypes.missing import (
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)

import pandas.core.algorithms as algos
from pandas.core.arrays import Categorical, ExtensionArray
from pandas.core.arrays import Categorical, ExtensionArray, SparseArray
from pandas.core.base import PandasObject
import pandas.core.common as com
from pandas.core.indexes.datetimes import DatetimeIndex
Expand Down Expand Up @@ -1967,6 +1968,57 @@ def shift(self, periods, axis=0):
placement=self.mgr_locs,
ndim=self.ndim)]

def where(self, other, cond, align=True, errors='raise',
try_cast=False, axis=0, transpose=False):
if isinstance(other, (ABCIndexClass, ABCSeries)):
other = other.array

if isinstance(cond, ABCDataFrame):
assert cond.shape[1] == 1
cond = cond.iloc[:, 0].array

if isinstance(other, ABCDataFrame):
assert other.shape[1] == 1
other = other.iloc[:, 0].array

if isinstance(cond, (ABCIndexClass, ABCSeries)):
cond = cond.array

if lib.is_scalar(other) and isna(other):
# The default `other` for Series / Frame is np.nan
# we want to replace that with the correct NA value
# for the type
other = self.dtype.na_value

if is_sparse(self.values):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, we fail in the

            result = self._holder._from_sequence(
                 np.where(cond, self.values, other),
                 dtype=dtype,

since the where may change the dtype, if NaN is introduced.

Implementing SparseArray.__setitem__ would allow us to remove this block.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be an overriding method in Sparse then, not here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have a SparseBlock anymore. I can add one back if you want, but I figured it'd be easier not to since implementing SparseArray.__setitem__ will remove the need for this, and we'd just have to remove SparseBlock again.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is pretty hacky. This was why we had originally a .get_values() methon on Sparse to do things like this. We need something to give back the underlying type of the object, which is useful for Categorical as well. Would rather create a generalized soln than hack it like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we don't need this. I think we can just re-infer the dtype from the output of np.where.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so is this changing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing from master? Yes, in the sense that it'll return a SparseArray. But it still densifies when np.where is called.

If you mean "is this changing in the future", yes it'll be removed when SparseArray.__setitem__ is implemented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh ok, can you add a TODO comment

# ugly workaround for ensure that the dtype is OK
# after we insert NaNs.
if is_sparse(other):
otype = other.dtype.subtype
else:
otype = other
dtype = self.dtype.update_dtype(
np.result_type(self.values.dtype.subtype, otype)
)
else:
dtype = self.dtype

# rough heuristic to see if the other array implements setitem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again you don't actually need to do this here, rather override in the appropriate class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will still need the check for extension, even if we create SparseBlock again.

if (self._holder.__setitem__ == ExtensionArray.__setitem__
or self._holder.__setitem__ == SparseArray.__setitem__):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what the heck is this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The general block is to check if the block implements __setitem__. That specific line is backwards compat for SparseArray, which implements __setitem__ to raise a TypeError instead of a NotImplementedError.

I suppose it'd be cleaner to do this in a try / except block...

result = self._holder._from_sequence(
np.where(cond, self.values, other),
dtype=dtype,
)
else:
result = self.values.copy()
icond = ~cond
if lib.is_scalar(other):
result[icond] = other
else:
result[icond] = other[icond]
return self.make_block_same_class(result, placement=self.mgr_locs)

@property
def _ftype(self):
return getattr(self.values, '_pandas_ftype', Block._ftype)
Expand Down Expand Up @@ -2646,6 +2698,37 @@ def concat_same_type(self, to_concat, placement=None):
values, placement=placement or slice(0, len(values), 1),
ndim=self.ndim)

def where(self, other, cond, align=True, errors='raise',
try_cast=False, axis=0, transpose=False):
# This can all be deleted in favor of ExtensionBlock.where once
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add TODO(EA) or someting here so we know to remove this

# we enforce the deprecation.
object_msg = (
"Implicitly converting categorical to object-dtype ndarray. "
"The values `{}' are not present in this categorical's "
"categories. A future version of pandas will raise a ValueError "
"when 'other' contains different categories.\n\n"
"To preserve the current behavior, add the new categories to "
"the categorical before calling 'where', or convert the "
"categorical to a different dtype."
)
try:
# Attempt to do preserve categorical dtype.
result = super(CategoricalBlock, self).where(
other, cond, align, errors, try_cast, axis, transpose
)
except (TypeError, ValueError):
if lib.is_scalar(other):
msg = object_msg.format(other)
else:
msg = compat.reprlib.repr(other)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don't blow up with a long message for large categoricals. I messed it up though, one sec.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed all this stuff and just print out the text of the message.

With a bit of effort we could figure out exactly which of the new values is causing the fallback of object, but that'd take some work (we don't know the exact type /dtype of other here, so there will be a lot of conditions). Not a high priority.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k cool


warnings.warn(msg, FutureWarning, stacklevel=6)
result = self.astype(object).where(other, cond, align=align,
errors=errors,
try_cast=try_cast,
axis=axis, transpose=transpose)
return result


class DatetimeBlock(DatetimeLikeBlockMixin, Block):
__slots__ = ()
Expand Down
94 changes: 94 additions & 0 deletions pandas/tests/arrays/categorical/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest

import pandas as pd
from pandas import Categorical, CategoricalIndex, Index, PeriodIndex, Series
import pandas.core.common as com
from pandas.tests.arrays.categorical.common import TestCategorical
Expand Down Expand Up @@ -43,6 +44,45 @@ def test_setitem(self):

tm.assert_categorical_equal(c, expected)

@pytest.mark.parametrize('other', [
pd.Categorical(['b', 'a']),
pd.Categorical(['b', 'a'], categories=['b', 'a']),
])
def test_setitem_same_but_unordered(self, other):
# GH-24142
target = pd.Categorical(['a', 'b'], categories=['a', 'b'])
mask = np.array([True, False])
target[mask] = other[mask]
expected = pd.Categorical(['b', 'b'], categories=['a', 'b'])
tm.assert_categorical_equal(target, expected)

@pytest.mark.parametrize('other', [
pd.Categorical(['b', 'a'], categories=['b', 'a', 'c']),
pd.Categorical(['b', 'a'], categories=['a', 'b', 'c']),
pd.Categorical(['a', 'a'], categories=['a']),
pd.Categorical(['b', 'b'], categories=['b']),
])
def test_setitem_different_unordered_raises(self, other):
# GH-24142
target = pd.Categorical(['a', 'b'], categories=['a', 'b'])
mask = np.array([True, False])
with pytest.raises(ValueError):
target[mask] = other[mask]

@pytest.mark.parametrize('other', [
pd.Categorical(['b', 'a']),
pd.Categorical(['b', 'a'], categories=['b', 'a'], ordered=True),
pd.Categorical(['b', 'a'], categories=['a', 'b', 'c'], ordered=True),
])
def test_setitem_same_ordered_rasies(self, other):
# Gh-24142
target = pd.Categorical(['a', 'b'], categories=['a', 'b'],
ordered=True)
mask = np.array([True, False])

with pytest.raises(ValueError):
target[mask] = other[mask]


class TestCategoricalIndexing(object):

Expand Down Expand Up @@ -122,6 +162,60 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
tm.assert_numpy_array_equal(expected, result)
tm.assert_numpy_array_equal(exp_miss, res_miss)

def test_where_unobserved_nan(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is where all of the where tests are?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There weren't any previously (we used to fall back to object).

ser = pd.Series(pd.Categorical(['a', 'b']))
result = ser.where([True, False])
expected = pd.Series(pd.Categorical(['a', None],
categories=['a', 'b']))
tm.assert_series_equal(result, expected)

# all NA
ser = pd.Series(pd.Categorical(['a', 'b']))
result = ser.where([False, False])
expected = pd.Series(pd.Categorical([None, None],
categories=['a', 'b']))
tm.assert_series_equal(result, expected)

def test_where_unobserved_categories(self):
ser = pd.Series(
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
)
result = ser.where([True, True, False], other='b')
expected = pd.Series(
Categorical(['a', 'b', 'b'], categories=ser.cat.categories)
)
tm.assert_series_equal(result, expected)

def test_where_other_categorical(self):
ser = pd.Series(
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
)
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
result = ser.where([True, False, True], other)
expected = pd.Series(Categorical(['a', 'c', 'c'], dtype=ser.dtype))
tm.assert_series_equal(result, expected)

def test_where_warns(self):
ser = pd.Series(Categorical(['a', 'b', 'c']))
with tm.assert_produces_warning(FutureWarning):
result = ser.where([True, False, True], 'd')

expected = pd.Series(np.array(['a', 'd', 'c'], dtype='object'))
tm.assert_series_equal(result, expected)

def test_where_ordered_differs_rasies(self):
ser = pd.Series(
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
ordered=True)
)
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
ordered=True)
with tm.assert_produces_warning(FutureWarning):
result = ser.where([True, False, True], other)

expected = pd.Series(np.array(['a', 'c', 'c'], dtype=object))
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("index", [True, False])
def test_mask_with_boolean(index):
Expand Down
14 changes: 13 additions & 1 deletion pandas/tests/arrays/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
import pytest

from pandas import Index, IntervalIndex, date_range, timedelta_range
import pandas as pd
from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
from pandas.core.arrays import IntervalArray
import pandas.util.testing as tm

Expand Down Expand Up @@ -50,6 +51,17 @@ def test_set_closed(self, closed, new_closed):
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
tm.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize('other', [
Interval(0, 1, closed='right'),
IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
])
def test_where_raises(self, other):
ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4],
closed='left'))
match = "'value.closed' is 'right', expected 'left'."
with pytest.raises(ValueError, match=match):
ser.where([True, False, True], other=other)


class TestSetitem(object):

Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/arrays/sparse/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,10 @@ def setitem():
def setslice():
self.arr[1:5] = 2

with pytest.raises(TypeError, match="item assignment"):
with pytest.raises(TypeError, match="assignment via setitem"):
setitem()

with pytest.raises(TypeError, match="item assignment"):
with pytest.raises(TypeError, match="assignment via setitem"):
setslice()

def test_constructor_from_too_large_array(self):
Expand Down
Loading