Skip to content

Commit 2d27815

Browse files
authored
REF: share _get_bool_indexer (#43803)
1 parent 2e5ed86 commit 2d27815

File tree

2 files changed

+38
-58
lines changed

2 files changed

+38
-58
lines changed

pandas/_libs/index.pyx

+38-35
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,40 @@ cdef inline bint is_definitely_invalid_key(object val):
4747
return False
4848

4949

50+
cdef ndarray _get_bool_indexer(ndarray values, object val):
51+
"""
52+
Return a ndarray[bool] of locations where val matches self.values.
53+
54+
If val is not NA, this is equivalent to `self.values == val`
55+
"""
56+
# Caller is responsible for ensuring _check_type has already been called
57+
cdef:
58+
ndarray[uint8_t, ndim=1, cast=True] indexer
59+
Py_ssize_t i
60+
object item
61+
62+
if values.descr.type_num == cnp.NPY_OBJECT:
63+
# i.e. values.dtype == object
64+
if not checknull(val):
65+
indexer = values == val
66+
67+
else:
68+
# We need to check for _matching_ NA values
69+
indexer = np.empty(len(values), dtype=np.uint8)
70+
71+
for i in range(len(values)):
72+
item = values[i]
73+
indexer[i] = is_matching_na(item, val)
74+
75+
else:
76+
if util.is_nan(val):
77+
indexer = np.isnan(values)
78+
else:
79+
indexer = values == val
80+
81+
return indexer.view(bool)
82+
83+
5084
# Don't populate hash tables in monotonic indexes larger than this
5185
_SIZE_CUTOFF = 1_000_000
5286

@@ -83,12 +117,13 @@ cdef class IndexEngine:
83117
if is_definitely_invalid_key(val):
84118
raise TypeError(f"'{val}' is an invalid key")
85119

120+
self._check_type(val)
121+
86122
if self.over_size_threshold and self.is_monotonic_increasing:
87123
if not self.is_unique:
88124
return self._get_loc_duplicates(val)
89125
values = self.values
90126

91-
self._check_type(val)
92127
loc = self._searchsorted_left(val)
93128
if loc >= len(values):
94129
raise KeyError(val)
@@ -100,8 +135,6 @@ cdef class IndexEngine:
100135
if not self.unique:
101136
return self._get_loc_duplicates(val)
102137

103-
self._check_type(val)
104-
105138
try:
106139
return self.mapping.get_item(val)
107140
except (TypeError, ValueError, OverflowError):
@@ -148,17 +181,9 @@ cdef class IndexEngine:
148181
cdef:
149182
ndarray[uint8_t, ndim=1, cast=True] indexer
150183

151-
indexer = self._get_bool_indexer(val)
184+
indexer = _get_bool_indexer(self.values, val)
152185
return self._unpack_bool_indexer(indexer, val)
153186

154-
cdef ndarray _get_bool_indexer(self, object val):
155-
"""
156-
Return a ndarray[bool] of locations where val matches self.values.
157-
158-
If val is not NA, this is equivalent to `self.values == val`
159-
"""
160-
raise NotImplementedError("Implemented by subclasses")
161-
162187
cdef _unpack_bool_indexer(self,
163188
ndarray[uint8_t, ndim=1, cast=True] indexer,
164189
object val):
@@ -253,16 +278,13 @@ cdef class IndexEngine:
253278

254279
values = self.values
255280
self.mapping = self._make_hash_table(len(values))
256-
self._call_map_locations(values)
281+
self.mapping.map_locations(values)
257282

258283
if len(self.mapping) == len(values):
259284
self.unique = 1
260285

261286
self.need_unique_check = 0
262287

263-
cdef void _call_map_locations(self, ndarray values):
264-
self.mapping.map_locations(values)
265-
266288
def clear_mapping(self):
267289
self.mapping = None
268290
self.need_monotonic_check = 1
@@ -430,25 +452,6 @@ cdef class ObjectEngine(IndexEngine):
430452
raise KeyError(val) from err
431453
return loc
432454

433-
cdef ndarray _get_bool_indexer(self, object val):
434-
# We need to check for equality and for matching NAs
435-
cdef:
436-
ndarray values = self.values
437-
438-
if not checknull(val):
439-
return values == val
440-
441-
cdef:
442-
ndarray[uint8_t] result = np.empty(len(values), dtype=np.uint8)
443-
Py_ssize_t i
444-
object item
445-
446-
for i in range(len(values)):
447-
item = values[i]
448-
result[i] = is_matching_na(item, val)
449-
450-
return result.view(bool)
451-
452455

453456
cdef class DatetimeEngine(Int64Engine):
454457

pandas/_libs/index_class_helper.pxi.in

-23
Original file line numberDiff line numberDiff line change
@@ -42,28 +42,5 @@ cdef class {{name}}Engine(IndexEngine):
4242
raise KeyError(val)
4343
{{endif}}
4444

45-
cdef void _call_map_locations(self, ndarray[{{dtype}}_t] values):
46-
self.mapping.map_locations(values)
47-
48-
cdef ndarray _get_bool_indexer(self, object val):
49-
cdef:
50-
ndarray[uint8_t, ndim=1, cast=True] indexer
51-
ndarray[{{dtype}}_t, ndim=1] values
52-
53-
self._check_type(val)
54-
55-
values = self.values
56-
57-
{{if name in {'Float64', 'Float32'} }}
58-
if util.is_nan(val):
59-
indexer = np.isnan(values)
60-
else:
61-
indexer = values == val
62-
{{else}}
63-
indexer = values == val
64-
{{endif}}
65-
66-
return indexer
67-
6845

6946
{{endfor}}

0 commit comments

Comments
 (0)