Skip to content

Commit 04029ac

Browse files
committed
BUG: EA-backed boolean indexers
Closes #22665 Closes #22326
1 parent fe35002 commit 04029ac

File tree

4 files changed

+31
-4
lines changed

4 files changed

+31
-4
lines changed

doc/source/whatsnew/v0.24.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ ExtensionType Changes
484484
- ``ExtensionArray`` has gained the abstract methods ``.dropna()`` (:issue:`21185`)
485485
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
486486
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)
487+
- An ``ExtensionArray`` with a boolean dtype now works correctly as a boolean indexer. :meth:`pandas.api.types.is_bool_dtype` now properly considers them boolean (:issue:`22326`)
487488
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
488489
- The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)
489490
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
@@ -608,6 +609,7 @@ Categorical
608609
^^^^^^^^^^^
609610

610611
- Bug in :meth:`Categorical.from_codes` where ``NaN`` values in ``codes`` were silently converted to ``0`` (:issue:`21767`). In the future this will raise a ``ValueError``. Also changes the behavior of ``.from_codes([1.1, 2.0])``.
612+
- Bug when indexing with a boolean-valued ``Categorical``. Now categoricals are treated as a boolean mask (:issue:`22665`)
611613

612614
Datetimelike
613615
^^^^^^^^^^^^

pandas/core/common.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from pandas import compat
1616
from pandas.compat import iteritems, PY36, OrderedDict
1717
from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass
18-
from pandas.core.dtypes.common import is_integer
18+
from pandas.core.dtypes.common import (
19+
is_integer, is_bool_dtype, is_extension_array_dtype, is_array_like
20+
)
1921
from pandas.core.dtypes.inference import _iterable_not_string
2022
from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
2123
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
@@ -100,7 +102,8 @@ def maybe_box_datetimelike(value):
100102

101103

102104
def is_bool_indexer(key):
103-
if isinstance(key, (ABCSeries, np.ndarray, ABCIndex)):
105+
if (isinstance(key, (ABCSeries, np.ndarray, ABCIndex)) or
106+
(is_array_like(key) and is_extension_array_dtype(key.dtype))):
104107
if key.dtype == np.object_:
105108
key = np.asarray(values_from_object(key))
106109

@@ -110,7 +113,7 @@ def is_bool_indexer(key):
110113
'NA / NaN values')
111114
return False
112115
return True
113-
elif key.dtype == np.bool_:
116+
elif is_bool_dtype(key.dtype):
114117
return True
115118
elif isinstance(key, list):
116119
try:

pandas/core/dtypes/common.py

+12
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,8 @@ def is_bool_dtype(arr_or_dtype):
16081608
False
16091609
>>> is_bool_dtype(np.array([True, False]))
16101610
True
1611+
>>> is_bool_dtype(pd.Categorical([True, False]))
1612+
True
16111613
"""
16121614

16131615
if arr_or_dtype is None:
@@ -1618,6 +1620,13 @@ def is_bool_dtype(arr_or_dtype):
16181620
# this isn't even a dtype
16191621
return False
16201622

1623+
if isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex)):
1624+
arr_or_dtype = arr_or_dtype.dtype
1625+
1626+
if isinstance(arr_or_dtype, CategoricalDtype):
1627+
arr_or_dtype = arr_or_dtype.categories
1628+
# now we use the special definition for Index
1629+
16211630
if isinstance(arr_or_dtype, ABCIndexClass):
16221631

16231632
# TODO(jreback)
@@ -1626,6 +1635,9 @@ def is_bool_dtype(arr_or_dtype):
16261635
# guess this
16271636
return (arr_or_dtype.is_object and
16281637
arr_or_dtype.inferred_type == 'boolean')
1638+
elif is_extension_array_dtype(arr_or_dtype):
1639+
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
1640+
return issubclass(dtype.type, np.bool_)
16291641

16301642
return issubclass(tipo, np.bool_)
16311643

pandas/tests/arrays/categorical/test_indexing.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66

77
import pandas.util.testing as tm
8-
from pandas import Categorical, Index, CategoricalIndex, PeriodIndex
8+
from pandas import Categorical, Index, CategoricalIndex, PeriodIndex, Series
9+
from pandas.core.common import is_bool_indexer
910
from pandas.tests.arrays.categorical.common import TestCategorical
1011

1112

@@ -121,3 +122,12 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
121122

122123
tm.assert_numpy_array_equal(expected, result)
123124
tm.assert_numpy_array_equal(exp_miss, res_miss)
125+
126+
127+
def test_mask_with_boolean():
128+
s = Series(range(3))
129+
idx = CategoricalIndex([True, False, True])
130+
assert is_bool_indexer(idx)
131+
result = s[idx]
132+
expected = s[idx.astype('object')]
133+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)