Skip to content

Commit 4f44f84

Browse files
TomAugspurgervictor
authored and
victor
committed
is_bool_dtype for ExtensionArrays (pandas-dev#22667)
Closes pandas-dev#22665 Closes pandas-dev#22326
1 parent 9cd137a commit 4f44f84

File tree

10 files changed

+276
-8
lines changed

10 files changed

+276
-8
lines changed

doc/source/whatsnew/v0.24.0.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ ExtensionType Changes
485485
- ``ExtensionArray`` has gained the abstract method ``.dropna()`` (:issue:`21185`)
486486
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
487487
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)
488+
- An ``ExtensionArray`` with a boolean dtype now works correctly as a boolean indexer. :func:`pandas.api.types.is_bool_dtype` now properly considers such arrays boolean (:issue:`22326`)
488489
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
489490
- The ``ExtensionArray`` constructor, ``_from_sequence``, now takes the keyword arg ``copy=False`` (:issue:`21185`)
490491
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
@@ -616,7 +617,8 @@ Categorical
616617
^^^^^^^^^^^
617618

618619
- Bug in :meth:`Categorical.from_codes` where ``NaN`` values in ``codes`` were silently converted to ``0`` (:issue:`21767`). In the future this will raise a ``ValueError``. Also changes the behavior of ``.from_codes([1.1, 2.0])``.
619-
- Constructing a :class:`pd.CategoricalIndex` with empty values and boolean categories was raising a ``ValueError`` after a change to dtype coercion (:issue:`22702`).
620+
- Bug when indexing with a boolean-valued ``Categorical``. Now a boolean-valued ``Categorical`` is treated as a boolean mask (:issue:`22665`)
621+
- Constructing a :class:`CategoricalIndex` with empty values and boolean categories was raising a ``ValueError`` after a change to dtype coercion (:issue:`22702`).
620622

621623
Datetimelike
622624
^^^^^^^^^^^^

pandas/core/common.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from pandas import compat
1616
from pandas.compat import iteritems, PY36, OrderedDict
1717
from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass
18-
from pandas.core.dtypes.common import is_integer
18+
from pandas.core.dtypes.common import (
19+
is_integer, is_bool_dtype, is_extension_array_dtype, is_array_like
20+
)
1921
from pandas.core.dtypes.inference import _iterable_not_string
2022
from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
2123
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
@@ -100,17 +102,45 @@ def maybe_box_datetimelike(value):
100102

101103

102104
def is_bool_indexer(key):
103-
if isinstance(key, (ABCSeries, np.ndarray, ABCIndex)):
105+
# type: (Any) -> bool
106+
"""
107+
Check whether `key` is a valid boolean indexer.
108+
109+
Parameters
110+
----------
111+
key : Any
112+
Only list-likes may be considered boolean indexers.
113+
All other types are not considered a boolean indexer.
114+
For array-like input, boolean ndarrays or ExtensionArrays
115+
with ``_is_boolean`` set are considered boolean indexers.
116+
117+
Returns
118+
-------
119+
bool
120+
121+
Raises
122+
------
123+
ValueError
124+
When the array is an object-dtype ndarray or ExtensionArray
125+
and contains missing values.
126+
"""
127+
na_msg = 'cannot index with vector containing NA / NaN values'
128+
if (isinstance(key, (ABCSeries, np.ndarray, ABCIndex)) or
129+
(is_array_like(key) and is_extension_array_dtype(key.dtype))):
104130
if key.dtype == np.object_:
105131
key = np.asarray(values_from_object(key))
106132

107133
if not lib.is_bool_array(key):
108134
if isna(key).any():
109-
raise ValueError('cannot index with vector containing '
110-
'NA / NaN values')
135+
raise ValueError(na_msg)
111136
return False
112137
return True
113-
elif key.dtype == np.bool_:
138+
elif is_bool_dtype(key.dtype):
139+
# an ndarray with bool-dtype by definition has no missing values.
140+
# So we only need to check for NAs in ExtensionArrays
141+
if is_extension_array_dtype(key.dtype):
142+
if np.any(key.isna()):
143+
raise ValueError(na_msg)
114144
return True
115145
elif isinstance(key, list):
116146
try:

pandas/core/dtypes/base.py

+20
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,25 @@ def _is_numeric(self):
106106
"""
107107
return False
108108

109+
@property
def _is_boolean(self):
    # type: () -> bool
    """
    Whether this dtype should be considered boolean.

    By default, ExtensionDtypes are assumed to be non-boolean.
    Setting this to True will affect the behavior of several places,
    e.g.

    * is_bool_dtype
    * boolean indexing

    Returns
    -------
    bool
    """
    return False
127+
109128

110129
class ExtensionDtype(_DtypeOpsMixin):
111130
"""A custom data type, to be paired with an ExtensionArray.
@@ -125,6 +144,7 @@ class ExtensionDtype(_DtypeOpsMixin):
125144
pandas operations
126145
127146
* _is_numeric
147+
* _is_boolean
128148
129149
Optionally one can override construct_array_type for construction
130150
with the name of this dtype via the Registry. See

pandas/core/dtypes/common.py

+17
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,11 @@ def is_bool_dtype(arr_or_dtype):
16191619
-------
16201620
boolean : Whether or not the array or dtype is of a boolean dtype.
16211621
1622+
Notes
1623+
-----
1624+
An ExtensionArray is considered boolean when the ``_is_boolean``
1625+
attribute of its dtype is set to True.
1626+
16221627
Examples
16231628
--------
16241629
>>> is_bool_dtype(str)
@@ -1635,6 +1640,8 @@ def is_bool_dtype(arr_or_dtype):
16351640
False
16361641
>>> is_bool_dtype(np.array([True, False]))
16371642
True
1643+
>>> is_bool_dtype(pd.Categorical([True, False]))
1644+
True
16381645
"""
16391646

16401647
if arr_or_dtype is None:
@@ -1645,6 +1652,13 @@ def is_bool_dtype(arr_or_dtype):
16451652
# this isn't even a dtype
16461653
return False
16471654

1655+
if isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex)):
1656+
arr_or_dtype = arr_or_dtype.dtype
1657+
1658+
if isinstance(arr_or_dtype, CategoricalDtype):
1659+
arr_or_dtype = arr_or_dtype.categories
1660+
# now we use the special definition for Index
1661+
16481662
if isinstance(arr_or_dtype, ABCIndexClass):
16491663

16501664
# TODO(jreback)
@@ -1653,6 +1667,9 @@ def is_bool_dtype(arr_or_dtype):
16531667
# guess this
16541668
return (arr_or_dtype.is_object and
16551669
arr_or_dtype.inferred_type == 'boolean')
1670+
elif is_extension_array_dtype(arr_or_dtype):
1671+
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
1672+
return dtype._is_boolean
16561673

16571674
return issubclass(tipo, np.bool_)
16581675

pandas/core/dtypes/dtypes.py

+6
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,12 @@ def ordered(self):
462462
"""Whether the categories have an ordered relationship"""
463463
return self._ordered
464464

465+
@property
def _is_boolean(self):
    """Whether the categories of this dtype are of a boolean dtype."""
    # Imported locally, as in the original implementation.
    from pandas.core.dtypes.common import is_bool_dtype

    categories = self.categories
    return is_bool_dtype(categories)
470+
465471

466472
class DatetimeTZDtypeType(type):
467473
"""

pandas/tests/arrays/categorical/test_indexing.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66

77
import pandas.util.testing as tm
8-
from pandas import Categorical, Index, CategoricalIndex, PeriodIndex
8+
from pandas import Categorical, Index, CategoricalIndex, PeriodIndex, Series
9+
import pandas.core.common as com
910
from pandas.tests.arrays.categorical.common import TestCategorical
1011

1112

@@ -121,3 +122,27 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
121122

122123
tm.assert_numpy_array_equal(expected, result)
123124
tm.assert_numpy_array_equal(exp_miss, res_miss)
125+
126+
127+
@pytest.mark.parametrize("index", [True, False])
def test_mask_with_boolean(index):
    """A boolean-valued Categorical / CategoricalIndex acts as a mask."""
    mask = Categorical([True, False, True])
    mask = CategoricalIndex(mask) if index else mask
    ser = Series(range(3))

    assert com.is_bool_indexer(mask)

    # Masking with the categorical must match masking with the same
    # values cast to object dtype.
    tm.assert_series_equal(ser[mask], ser[mask.astype('object')])
138+
139+
140+
@pytest.mark.parametrize("index", [True, False])
def test_mask_with_boolean_raises(index):
    """Indexing with a boolean Categorical that holds a missing value raises."""
    mask = Categorical([True, False, None])
    mask = CategoricalIndex(mask) if index else mask
    ser = Series(range(3))

    with tm.assert_raises_regex(ValueError, 'NA / NaN'):
        ser[mask]

pandas/tests/dtypes/test_dtypes.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
is_dtype_equal, is_datetime64_ns_dtype,
1818
is_datetime64_dtype, is_interval_dtype,
1919
is_datetime64_any_dtype, is_string_dtype,
20-
_coerce_to_dtype)
20+
_coerce_to_dtype, is_bool_dtype)
2121
import pandas.util.testing as tm
2222

2323

@@ -126,6 +126,18 @@ def test_tuple_categories(self):
126126
result = CategoricalDtype(categories)
127127
assert all(result.categories == categories)
128128

129+
@pytest.mark.parametrize("categories, expected", [
    ([True, False], True),
    ([True, False, None], True),
    ([True, False, "a", "b'"], False),
    ([0, 1], False),
])
def test_is_boolean(self, categories, expected):
    """``_is_boolean`` and ``is_bool_dtype`` agree for a Categorical."""
    values = Categorical(categories)
    cat_dtype = values.dtype

    # The dtype property, the array and the dtype object must all agree.
    assert cat_dtype._is_boolean is expected
    assert is_bool_dtype(values) is expected
    assert is_bool_dtype(cat_dtype) is expected
140+
129141

130142
class TestDatetimeTZDtype(Base):
131143

pandas/tests/extension/arrow/__init__.py

Whitespace-only changes.

pandas/tests/extension/arrow/bool.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Rudimentary Apache Arrow-backed ExtensionArray.
2+
3+
At the moment, just a boolean array / type is implemented.
4+
Eventually, we'll want to parametrize the type and support
5+
multiple dtypes. Not all methods are implemented yet, and the
6+
current implementation is not efficient.
7+
"""
8+
import copy
9+
import itertools
10+
11+
import numpy as np
12+
import pyarrow as pa
13+
import pandas as pd
14+
from pandas.api.extensions import (
15+
ExtensionDtype, ExtensionArray, take, register_extension_dtype
16+
)
17+
18+
19+
@register_extension_dtype
class ArrowBoolDtype(ExtensionDtype):
    """Boolean extension dtype paired with ``ArrowBoolArray``."""

    type = np.bool_
    kind = 'b'
    name = 'arrow_bool'
    na_value = pa.NULL

    @classmethod
    def construct_from_string(cls, string):
        """Return an instance when ``string`` equals ``cls.name``.

        Raises
        ------
        TypeError
            If ``string`` is anything other than the dtype name.
        """
        if string == cls.name:
            return cls()
        else:
            raise TypeError("Cannot construct a '{}' from "
                            "'{}'".format(cls, string))

    @classmethod
    def construct_array_type(cls):
        """Return the array class associated with this dtype."""
        return ArrowBoolArray

    @property
    def _is_boolean(self):
        # Must be a property: consumers read ``dtype._is_boolean`` as a
        # plain attribute, and a bare method would only be truthy because
        # bound-method objects are truthy.
        return True
41+
42+
43+
class ArrowBoolArray(ExtensionArray):
    """Rudimentary boolean ExtensionArray wrapping a ``pa.ChunkedArray``.

    Several operations round-trip through ``to_pandas()``, so this is not
    an efficient implementation.
    """

    def __init__(self, values):
        # Only an already-chunked Arrow array is accepted here; the
        # ``from_scalars`` / ``from_array`` constructors build one.
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError

        assert values.type == pa.bool_()
        self._data = values
        self._dtype = ArrowBoolDtype()

    def __repr__(self):
        return "ArrowBoolArray({})".format(repr(self._data))

    @classmethod
    def from_scalars(cls, values):
        """Build an array from scalars, converting through ``np.asarray``."""
        arr = pa.chunked_array([pa.array(np.asarray(values))])
        return cls(arr)

    @classmethod
    def from_array(cls, arr):
        """Wrap a single ``pa.Array`` as a one-chunk array."""
        assert isinstance(arr, pa.Array)
        return cls(pa.chunked_array([arr]))

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        # ``dtype`` and ``copy`` are accepted but not used.
        return cls.from_scalars(scalars)

    def __getitem__(self, item):
        return self._data.to_pandas()[item]

    def __len__(self):
        return len(self._data)

    @property
    def dtype(self):
        return self._dtype

    @property
    def nbytes(self):
        """Sum of the sizes of all non-None buffers across every chunk."""
        return sum(x.size for chunk in self._data.chunks
                   for x in chunk.buffers()
                   if x is not None)

    def isna(self):
        return pd.isna(self._data.to_pandas())

    def take(self, indices, allow_fill=False, fill_value=None):
        data = self._data.to_pandas()

        # Fall back to the dtype's NA value when filling without an
        # explicit fill value.
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indices, fill_value=fill_value,
                      allow_fill=allow_fill)
        return self._from_sequence(result, dtype=self.dtype)

    def copy(self, deep=False):
        """Return a copy of this array as a new ``ArrowBoolArray``.

        Parameters
        ----------
        deep : bool, default False
            Use ``copy.deepcopy`` on the underlying data instead of
            ``copy.copy``.
        """
        # Re-wrap the copied data so callers get an ArrowBoolArray back
        # rather than the bare pyarrow ChunkedArray.
        if deep:
            data = copy.deepcopy(self._data)
        else:
            data = copy.copy(self._data)
        return type(self)(data)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate arrays by chaining their chunks into one array."""
        chunks = list(itertools.chain.from_iterable(x._data.chunks
                                                    for x in to_concat))
        arr = pa.chunked_array(chunks)
        return cls(arr)
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import pytest
3+
import pandas as pd
4+
import pandas.util.testing as tm
5+
from pandas.tests.extension import base
6+
7+
pytest.importorskip('pyarrow', minversion="0.10.0")
8+
9+
from .bool import ArrowBoolDtype, ArrowBoolArray
10+
11+
12+
@pytest.fixture
def dtype():
    """Fixture returning a new ``ArrowBoolDtype`` instance."""
    arrow_dtype = ArrowBoolDtype()
    return arrow_dtype
15+
16+
17+
@pytest.fixture
def data():
    """Fixture returning a length-100 ``ArrowBoolArray`` of random booleans."""
    random_bools = np.random.randint(0, 2, size=100, dtype=bool)
    return ArrowBoolArray.from_scalars(random_bools)
21+
22+
23+
class BaseArrowTests(object):
    """Empty base class mixed into the Arrow extension test classes."""
25+
26+
27+
class TestDtype(BaseArrowTests, base.BaseDtypeTests):
    """Dtype tests for the Arrow-backed boolean extension type."""

    def test_array_type_with_arg(self, data, dtype):
        # Skipped; tracked in GH-22666.
        reason = "GH-22666"
        pytest.skip(reason)
30+
31+
32+
class TestInterface(BaseArrowTests, base.BaseInterfaceTests):
    """Interface tests for the Arrow-backed boolean extension array."""

    def test_repr(self, data):
        # ``pytest.skip`` raises on its own, so no ``raise`` is needed.
        pytest.skip("TODO")
35+
36+
37+
class TestConstructors(BaseArrowTests, base.BaseConstructorsTests):
    """Constructor tests for the Arrow-backed boolean extension array."""

    def test_from_dtype(self, data):
        # Skipped; tracked in GH-22666.
        reason = "GH-22666"
        pytest.skip(reason)
40+
41+
42+
def test_is_bool_dtype(data):
    """The Arrow boolean array is a boolean dtype and a valid mask."""
    assert pd.api.types.is_bool_dtype(data)
    assert pd.core.common.is_bool_indexer(data)

    ser = pd.Series(range(len(data)))

    # Masking with the extension array must match masking with its
    # NumPy conversion.
    tm.assert_series_equal(ser[data], ser[np.asarray(data)])

0 commit comments

Comments
 (0)