Skip to content

Commit ee6b0a0

Browse files
authored
ENH: Index[bool] (#45061)
1 parent 601170d commit ee6b0a0

23 files changed

+137
-56
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Other enhancements
3737
- :meth:`to_numeric` now preserves float64 arrays when downcasting would generate values not representable in float32 (:issue:`43693`)
3838
- :meth:`Series.reset_index` and :meth:`DataFrame.reset_index` now support the argument ``allow_duplicates`` (:issue:`44410`)
3939
- :meth:`.GroupBy.min` and :meth:`.GroupBy.max` now supports `Numba <https://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`45428`)
40+
- Implemented a ``bool``-dtype :class:`Index`, passing a bool-dtype arraylike to ``pd.Index`` will now retain ``bool`` dtype instead of casting to ``object`` (:issue:`45061`)
4041
-
4142

4243
.. ---------------------------------------------------------------------------

pandas/_libs/index.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class ObjectEngine(IndexEngine): ...
4242
class DatetimeEngine(Int64Engine): ...
4343
class TimedeltaEngine(DatetimeEngine): ...
4444
class PeriodEngine(Int64Engine): ...
45+
class BoolEngine(UInt8Engine): ...
4546

4647
class BaseMultiIndexCodesEngine:
4748
levels: list[np.ndarray]

pandas/_libs/index.pyx

+7
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,13 @@ cdef class BaseMultiIndexCodesEngine:
802802
include "index_class_helper.pxi"
803803

804804

805+
cdef class BoolEngine(UInt8Engine):
806+
cdef _check_type(self, object val):
807+
if not util.is_bool_object(val):
808+
raise KeyError(val)
809+
return <uint8_t>val
810+
811+
805812
@cython.internal
806813
@cython.freelist(32)
807814
cdef class SharedEngine:

pandas/conftest.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,8 @@ def _create_mi_with_dt64tz_level():
555555
"num_uint8": tm.makeNumericIndex(100, dtype="uint8"),
556556
"num_float64": tm.makeNumericIndex(100, dtype="float64"),
557557
"num_float32": tm.makeNumericIndex(100, dtype="float32"),
558-
"bool": tm.makeBoolIndex(10),
558+
"bool-object": tm.makeBoolIndex(10).astype(object),
559+
"bool-dtype": Index(np.random.randn(10) < 0),
559560
"categorical": tm.makeCategoricalIndex(100),
560561
"interval": tm.makeIntervalIndex(100),
561562
"empty": Index([]),
@@ -630,7 +631,7 @@ def index_flat_unique(request):
630631
key
631632
for key in indices_dict
632633
if not (
633-
key in ["int", "uint", "range", "empty", "repeats"]
634+
key in ["int", "uint", "range", "empty", "repeats", "bool-dtype"]
634635
or key.startswith("num_")
635636
)
636637
and not isinstance(indices_dict[key], MultiIndex)

pandas/core/algorithms.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,6 @@ def _reconstruct_data(
220220
elif is_bool_dtype(dtype):
221221
values = values.astype(dtype, copy=False)
222222

223-
# we only support object dtypes bool Index
224-
if isinstance(original, ABCIndex):
225-
values = values.astype(object, copy=False)
226223
elif dtype is not None:
227224
if is_datetime64_dtype(dtype):
228225
dtype = np.dtype("datetime64[ns]")
@@ -830,7 +827,10 @@ def value_counts(
830827
-------
831828
Series
832829
"""
833-
from pandas.core.series import Series
830+
from pandas import (
831+
Index,
832+
Series,
833+
)
834834

835835
name = getattr(values, "name", None)
836836

@@ -868,7 +868,13 @@ def value_counts(
868868
else:
869869
keys, counts = value_counts_arraylike(values, dropna)
870870

871-
result = Series(counts, index=keys, name=name)
871+
# For backwards compatibility, we let Index do its normal type
872+
# inference, _except_ for if if infers from object to bool.
873+
idx = Index._with_infer(keys)
874+
if idx.dtype == bool and keys.dtype == object:
875+
idx = idx.astype(object)
876+
877+
result = Series(counts, index=idx, name=name)
872878

873879
if sort:
874880
result = result.sort_values(ascending=ascending)

pandas/core/indexes/base.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,10 @@ def __new__(
505505
if data.dtype.kind in ["i", "u", "f"]:
506506
# maybe coerce to a sub-class
507507
arr = data
508+
elif data.dtype.kind == "b":
509+
# No special subclass, and Index._ensure_array won't do this
510+
# for us.
511+
arr = np.asarray(data)
508512
else:
509513
arr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
510514

@@ -702,7 +706,7 @@ def _with_infer(cls, *args, **kwargs):
702706
# "Union[ExtensionArray, ndarray[Any, Any]]"; expected
703707
# "ndarray[Any, Any]"
704708
values = lib.maybe_convert_objects(result._values) # type: ignore[arg-type]
705-
if values.dtype.kind in ["i", "u", "f"]:
709+
if values.dtype.kind in ["i", "u", "f", "b"]:
706710
return Index(values, name=result.name)
707711

708712
return result
@@ -872,9 +876,12 @@ def _engine(
872876
):
873877
return libindex.ExtensionEngine(target_values)
874878

879+
target_values = cast(np.ndarray, target_values)
875880
# to avoid a reference cycle, bind `target_values` to a local variable, so
876881
# `self` is not passed into the lambda.
877-
target_values = cast(np.ndarray, target_values)
882+
if target_values.dtype == bool:
883+
return libindex.BoolEngine(target_values)
884+
878885
# error: Argument 1 to "ExtensionEngine" has incompatible type
879886
# "ndarray[Any, Any]"; expected "ExtensionArray"
880887
return self._engine_type(target_values) # type:ignore[arg-type]
@@ -2680,7 +2687,6 @@ def _is_all_dates(self) -> bool:
26802687
"""
26812688
Whether or not the index values only consist of dates.
26822689
"""
2683-
26842690
if needs_i8_conversion(self.dtype):
26852691
return True
26862692
elif self.dtype != _dtype_obj:
@@ -7302,7 +7308,7 @@ def _maybe_cast_data_without_dtype(
73027308
FutureWarning,
73037309
stacklevel=3,
73047310
)
7305-
if result.dtype.kind in ["b", "c"]:
7311+
if result.dtype.kind in ["c"]:
73067312
return subarr
73077313
result = ensure_wrapped_if_datetimelike(result)
73087314
return result

pandas/core/tools/datetimes.py

+2
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,8 @@ def to_datetime(
10761076
result = convert_listlike(arg, format)
10771077
else:
10781078
result = convert_listlike(np.array([arg]), format)[0]
1079+
if isinstance(arg, bool) and isinstance(result, np.bool_):
1080+
result = bool(result) # TODO: avoid this kludge.
10791081

10801082
# error: Incompatible return value type (got "Union[Timestamp, NaTType,
10811083
# Series, Index]", expected "Union[DatetimeIndex, Series, float, str,

pandas/core/util/hashing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _hash_ndarray(
319319

320320
# First, turn whatever array this is into unsigned 64-bit ints, if we can
321321
# manage it.
322-
elif isinstance(dtype, bool):
322+
elif dtype == bool:
323323
vals = vals.astype("u8")
324324
elif issubclass(dtype.type, (np.datetime64, np.timedelta64)):
325325
vals = vals.view("i8").astype("u8", copy=False)

pandas/tests/base/test_value_counts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def test_value_counts_with_nan(dropna, index_or_series):
284284
obj = klass(values)
285285
res = obj.value_counts(dropna=dropna)
286286
if dropna is True:
287-
expected = Series([1], index=[True])
287+
expected = Series([1], index=Index([True], dtype=obj.dtype))
288288
else:
289289
expected = Series([1, 1, 1], index=[True, pd.NA, np.nan])
290290
tm.assert_series_equal(res, expected)

pandas/tests/indexes/common.py

+5
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def test_ensure_copied_data(self, index):
216216
# RangeIndex cannot be initialized from data
217217
# MultiIndex and CategoricalIndex are tested separately
218218
return
219+
elif index.dtype == object and index.inferred_type == "boolean":
220+
init_kwargs["dtype"] = index.dtype
219221

220222
index_type = type(index)
221223
result = index_type(index.values, copy=True, **init_kwargs)
@@ -522,6 +524,9 @@ def test_fillna(self, index):
522524
# GH 11343
523525
if len(index) == 0:
524526
return
527+
elif index.dtype == bool:
528+
# can't hold NAs
529+
return
525530
elif isinstance(index, NumericIndex) and is_integer_dtype(index.dtype):
526531
return
527532
elif isinstance(index, MultiIndex):

pandas/tests/indexes/multi/test_indexing.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -621,13 +621,22 @@ def test_get_loc_implicit_cast(self, level, dtypes):
621621
idx = MultiIndex.from_product(levels)
622622
assert idx.get_loc(tuple(key)) == 3
623623

624-
def test_get_loc_cast_bool(self):
625-
# GH 19086 : int is casted to bool, but not vice-versa
626-
levels = [[False, True], np.arange(2, dtype="int64")]
624+
@pytest.mark.parametrize("dtype", [bool, object])
625+
def test_get_loc_cast_bool(self, dtype):
626+
# GH 19086 : int is casted to bool, but not vice-versa (for object dtype)
627+
# With bool dtype, we don't cast in either direction.
628+
levels = [Index([False, True], dtype=dtype), np.arange(2, dtype="int64")]
627629
idx = MultiIndex.from_product(levels)
628630

629-
assert idx.get_loc((0, 1)) == 1
630-
assert idx.get_loc((1, 0)) == 2
631+
if dtype is bool:
632+
with pytest.raises(KeyError, match=r"^\(0, 1\)$"):
633+
assert idx.get_loc((0, 1)) == 1
634+
with pytest.raises(KeyError, match=r"^\(1, 0\)$"):
635+
assert idx.get_loc((1, 0)) == 2
636+
else:
637+
# We use python object comparisons, which treat 0 == False and 1 == True
638+
assert idx.get_loc((0, 1)) == 1
639+
assert idx.get_loc((1, 0)) == 2
631640

632641
with pytest.raises(KeyError, match=r"^\(False, True\)$"):
633642
idx.get_loc((False, True))

pandas/tests/indexes/test_any_index.py

+4
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def test_mutability(index):
4949
def test_map_identity_mapping(index):
5050
# GH#12766
5151
result = index.map(lambda x: x)
52+
if index.dtype == object and result.dtype == bool:
53+
assert (index == result).all()
54+
# TODO: could work that into the 'exact="equiv"'?
55+
return # FIXME: doesn't belong in this file anymore!
5256
tm.assert_index_equal(result, index, exact="equiv")
5357

5458

pandas/tests/indexes/test_base.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,21 @@ def test_view_with_args(self, index):
321321
"unicode",
322322
"string",
323323
pytest.param("categorical", marks=pytest.mark.xfail(reason="gh-25464")),
324-
"bool",
324+
"bool-object",
325+
"bool-dtype",
325326
"empty",
326327
],
327328
indirect=True,
328329
)
329330
def test_view_with_args_object_array_raises(self, index):
330-
msg = "Cannot change data-type for object array"
331-
with pytest.raises(TypeError, match=msg):
332-
index.view("i8")
331+
if index.dtype == bool:
332+
msg = "When changing to a larger dtype"
333+
with pytest.raises(ValueError, match=msg):
334+
index.view("i8")
335+
else:
336+
msg = "Cannot change data-type for object array"
337+
with pytest.raises(TypeError, match=msg):
338+
index.view("i8")
333339

334340
@pytest.mark.parametrize("index", ["int", "range"], indirect=True)
335341
def test_astype(self, index):
@@ -397,9 +403,9 @@ def test_is_(self):
397403

398404
def test_asof_numeric_vs_bool_raises(self):
399405
left = Index([1, 2, 3])
400-
right = Index([True, False])
406+
right = Index([True, False], dtype=object)
401407

402-
msg = "Cannot compare dtypes int64 and object"
408+
msg = "Cannot compare dtypes int64 and bool"
403409
with pytest.raises(TypeError, match=msg):
404410
left.asof(right[0])
405411
# TODO: should right.asof(left[0]) also raise?
@@ -591,7 +597,8 @@ def test_append_empty_preserve_name(self, name, expected):
591597
"index, expected",
592598
[
593599
("string", False),
594-
("bool", False),
600+
("bool-object", False),
601+
("bool-dtype", False),
595602
("categorical", False),
596603
("int", True),
597604
("datetime", False),
@@ -606,7 +613,8 @@ def test_is_numeric(self, index, expected):
606613
"index, expected",
607614
[
608615
("string", True),
609-
("bool", True),
616+
("bool-object", True),
617+
("bool-dtype", False),
610618
("categorical", False),
611619
("int", False),
612620
("datetime", False),
@@ -621,7 +629,8 @@ def test_is_object(self, index, expected):
621629
"index, expected",
622630
[
623631
("string", False),
624-
("bool", False),
632+
("bool-object", False),
633+
("bool-dtype", False),
625634
("categorical", False),
626635
("int", False),
627636
("datetime", True),

pandas/tests/indexes/test_common.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def test_constructor_non_hashable_name(self, index_flat):
8888

8989
def test_constructor_unwraps_index(self, index_flat):
9090
a = index_flat
91-
b = type(a)(a)
91+
# Passing dtype is necessary for Index([True, False], dtype=object)
92+
# case.
93+
b = type(a)(a, dtype=a.dtype)
9294
tm.assert_equal(a._data, b._data)
9395

9496
def test_to_flat_index(self, index_flat):
@@ -426,6 +428,9 @@ def test_hasnans_isnans(self, index_flat):
426428
return
427429
elif isinstance(index, NumericIndex) and is_integer_dtype(index.dtype):
428430
return
431+
elif index.dtype == bool:
432+
# values[1] = np.nan below casts to True!
433+
return
429434

430435
values[1] = np.nan
431436

pandas/tests/indexes/test_index_new.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_constructor_dtypes_to_object(self, cast_index, vals):
7474
index = Index(vals)
7575

7676
assert type(index) is Index
77-
assert index.dtype == object
77+
assert index.dtype == bool
7878

7979
def test_constructor_categorical_to_object(self):
8080
# GH#32167 Categorical data and dtype=object should return object-dtype

pandas/tests/indexes/test_numpy_compat.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,18 @@ def test_numpy_ufuncs_basic(index, func):
6868
with tm.external_error_raised((TypeError, AttributeError)):
6969
with np.errstate(all="ignore"):
7070
func(index)
71-
elif isinstance(index, NumericIndex) or (
72-
not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric
71+
elif (
72+
isinstance(index, NumericIndex)
73+
or (not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric)
74+
or index.dtype == bool
7375
):
7476
# coerces to float (e.g. np.sin)
7577
with np.errstate(all="ignore"):
7678
result = func(index)
7779
exp = Index(func(index.values), name=index.name)
7880

7981
tm.assert_index_equal(result, exp)
80-
if type(index) is not Index:
82+
if type(index) is not Index or index.dtype == bool:
8183
# i.e NumericIndex
8284
assert isinstance(result, Float64Index)
8385
else:
@@ -117,8 +119,10 @@ def test_numpy_ufuncs_other(index, func):
117119
with tm.external_error_raised(TypeError):
118120
func(index)
119121

120-
elif isinstance(index, NumericIndex) or (
121-
not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric
122+
elif (
123+
isinstance(index, NumericIndex)
124+
or (not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric)
125+
or index.dtype == bool
122126
):
123127
# Results in bool array
124128
result = func(index)

pandas/tests/indexes/test_setops.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import pytest
1010

1111
from pandas.core.dtypes.cast import find_common_type
12-
from pandas.core.dtypes.common import is_dtype_equal
1312

1413
from pandas import (
1514
CategoricalIndex,
@@ -55,14 +54,20 @@ def test_union_different_types(index_flat, index_flat2, request):
5554

5655
if (
5756
not idx1.is_unique
57+
and not idx2.is_unique
58+
and not idx2.is_monotonic_decreasing
5859
and idx1.dtype.kind == "i"
59-
and is_dtype_equal(idx2.dtype, "boolean")
60+
and idx2.dtype.kind == "b"
6061
) or (
6162
not idx2.is_unique
63+
and not idx1.is_unique
64+
and not idx1.is_monotonic_decreasing
6265
and idx2.dtype.kind == "i"
63-
and is_dtype_equal(idx1.dtype, "boolean")
66+
and idx1.dtype.kind == "b"
6467
):
65-
mark = pytest.mark.xfail(reason="GH#44000 True==1", raises=ValueError)
68+
mark = pytest.mark.xfail(
69+
reason="GH#44000 True==1", raises=ValueError, strict=False
70+
)
6671
request.node.add_marker(mark)
6772

6873
common_dtype = find_common_type([idx1.dtype, idx2.dtype])
@@ -231,7 +236,11 @@ def test_union_base(self, index):
231236
def test_difference_base(self, sort, index):
232237
first = index[2:]
233238
second = index[:4]
234-
if isinstance(index, CategoricalIndex) or index.is_boolean():
239+
if index.is_boolean():
240+
# i think (TODO: be sure) there assumptions baked in about
241+
# the index fixture that don't hold here?
242+
answer = set(first).difference(set(second))
243+
elif isinstance(index, CategoricalIndex):
235244
answer = []
236245
else:
237246
answer = index[4:]

0 commit comments

Comments
 (0)