@@ -47,6 +47,40 @@ cdef inline bint is_definitely_invalid_key(object val):
47
47
return False
48
48
49
49
50
cdef ndarray _get_bool_indexer(ndarray values, object val):
    """
    Build a boolean mask of the positions in `values` that match `val`.

    Parameters
    ----------
    values : ndarray
        The array to search.
    val : object
        The key to look for.

    Returns
    -------
    ndarray[bool]
        Same length as `values`. When `val` is not NA this is the result
        of `values == val`.
    """
    # _check_type is not invoked here; the caller must already have done so.
    cdef:
        ndarray[uint8_t, ndim=1, cast=True] mask
        Py_ssize_t pos, n
        object elem

    if values.descr.type_num != cnp.NPY_OBJECT:
        # Non-object dtype: a NaN key cannot be located with ==, so use
        # np.isnan for it and plain equality for everything else.
        if util.is_nan(val):
            mask = np.isnan(values)
        else:
            mask = values == val

    elif checknull(val):
        # Object dtype with an NA key: compare entry by entry so that only
        # the _matching_ kind of NA is flagged.
        n = len(values)
        mask = np.empty(n, dtype=np.uint8)

        for pos in range(n):
            elem = values[pos]
            mask[pos] = is_matching_na(elem, val)

    else:
        # Object dtype with a regular (non-NA) key.
        mask = values == val

    return mask.view(bool)
82
+
83
+
50
84
# Don't populate hash tables in monotonic indexes larger than this
51
85
_SIZE_CUTOFF = 1 _000_000
52
86
@@ -83,12 +117,13 @@ cdef class IndexEngine:
83
117
if is_definitely_invalid_key(val):
84
118
raise TypeError (f" '{val}' is an invalid key" )
85
119
120
+ self ._check_type(val)
121
+
86
122
if self .over_size_threshold and self .is_monotonic_increasing:
87
123
if not self .is_unique:
88
124
return self ._get_loc_duplicates(val)
89
125
values = self .values
90
126
91
- self ._check_type(val)
92
127
loc = self ._searchsorted_left(val)
93
128
if loc >= len (values):
94
129
raise KeyError (val)
@@ -100,8 +135,6 @@ cdef class IndexEngine:
100
135
if not self .unique:
101
136
return self ._get_loc_duplicates(val)
102
137
103
- self ._check_type(val)
104
-
105
138
try :
106
139
return self .mapping.get_item(val)
107
140
except (TypeError , ValueError , OverflowError ):
@@ -148,17 +181,9 @@ cdef class IndexEngine:
148
181
cdef:
149
182
ndarray[uint8_t, ndim= 1 , cast= True ] indexer
150
183
151
- indexer = self . _get_bool_indexer(val)
184
+ indexer = _get_bool_indexer(self .values, val)
152
185
return self ._unpack_bool_indexer(indexer, val)
153
186
154
- cdef ndarray _get_bool_indexer(self , object val):
155
- """
156
- Return a ndarray[bool] of locations where val matches self.values.
157
-
158
- If val is not NA, this is equivalent to `self.values == val`
159
- """
160
- raise NotImplementedError (" Implemented by subclasses" )
161
-
162
187
cdef _unpack_bool_indexer(self ,
163
188
ndarray[uint8_t, ndim= 1 , cast= True ] indexer,
164
189
object val):
@@ -253,16 +278,13 @@ cdef class IndexEngine:
253
278
254
279
values = self .values
255
280
self .mapping = self ._make_hash_table(len (values))
256
- self ._call_map_locations (values)
281
+ self .mapping.map_locations (values)
257
282
258
283
if len (self .mapping) == len (values):
259
284
self .unique = 1
260
285
261
286
self .need_unique_check = 0
262
287
263
- cdef void _call_map_locations(self , ndarray values):
264
- self .mapping.map_locations(values)
265
-
266
288
def clear_mapping (self ):
267
289
self .mapping = None
268
290
self .need_monotonic_check = 1
@@ -430,25 +452,6 @@ cdef class ObjectEngine(IndexEngine):
430
452
raise KeyError (val) from err
431
453
return loc
432
454
433
- cdef ndarray _get_bool_indexer(self , object val):
434
- # We need to check for equality and for matching NAs
435
- cdef:
436
- ndarray values = self .values
437
-
438
- if not checknull(val):
439
- return values == val
440
-
441
- cdef:
442
- ndarray[uint8_t] result = np.empty(len (values), dtype = np.uint8)
443
- Py_ssize_t i
444
- object item
445
-
446
- for i in range (len (values)):
447
- item = values[i]
448
- result[i] = is_matching_na(item, val)
449
-
450
- return result.view(bool )
451
-
452
455
453
456
cdef class DatetimeEngine(Int64Engine):
454
457
0 commit comments