@@ -36,6 +36,7 @@ from pandas._libs import (
36
36
)
37
37
from pandas._libs.missing import (
38
38
checknull,
39
+ is_matching_na,
39
40
isnaobj,
40
41
)
41
42
@@ -270,6 +271,25 @@ cdef class IndexEngine:
270
271
self._ensure_mapping_populated()
271
272
return self.mapping.lookup(values )
272
273
274
def get_stargets(self, ndarray targets) -> set:
    """
    Collect the distinct elements of ``targets``.

    Parameters
    ----------
    targets : ndarray
        Labels that are being looked up.

    Returns
    -------
    set
        One entry per distinct element of ``targets``.
    """
    unique_targets = set()
    unique_targets.update(targets)
    return unique_targets
276
+
277
def convert_val_if_nan(self, object val) -> object:
    """
    Map a null value to the sentinel ``-1``; pass anything else through.

    Null values are awkward as members of sets and keys of dicts, so they
    are replaced by a single stand-in before being used that way.

    Parameters
    ----------
    val : object
        Scalar to normalise.

    Returns
    -------
    object
        ``-1`` when ``checknull(val)`` is true, otherwise ``val`` unchanged.
    """
    return -1 if checknull(val) else val
284
+
285
def should_update_d(self, object target, object val) -> bool:
    """
    Decide whether index value ``val`` is a genuine match for ``target``.

    Null values are looked up under the sentinel ``-1``, so a ``-1`` lookup
    can surface both real ``-1`` entries and null entries. This check keeps
    a candidate only when it truly corresponds to the target, so each
    position is reported once.

    Parameters
    ----------
    target : object
        The label being searched for.
    val : object
        The candidate value taken from the index.

    Returns
    -------
    bool
        True when both are the same kind of null, or when both are non-null
        and compare equal; False otherwise.
    """
    # Test the NA case first: comparing some NA scalars with ``==`` does not
    # produce a plain bool (e.g. ``pd.NA == pd.NA`` is ``pd.NA``), and taking
    # the truth value of that result raises TypeError.
    if is_matching_na(target, val):
        return True

    # A null never equals a non-null value or a null of a different kind,
    # so skip the ``==`` comparison entirely for the remaining null cases.
    if checknull(target) or checknull(val):
        return False

    return bool(target == val)
292
+
273
293
def get_indexer_non_unique (self , ndarray targets ):
274
294
"""
275
295
Return an indexer suitable for taking from a non unique index
@@ -293,11 +313,7 @@ cdef class IndexEngine:
293
313
294
314
self ._ensure_mapping_populated()
295
315
values = np.array(self ._get_index_values(), copy = False )
296
- values_mask = isnaobj(values)
297
- targets_mask = isnaobj(targets)
298
- stargets = set (targets)
299
- if - 1 not in stargets and targets_mask.any():
300
- stargets.add(- 1 )
316
+ stargets = self .get_stargets(targets)
301
317
n = len (values)
302
318
n_t = len (targets)
303
319
if n > 10 _000:
@@ -328,22 +344,15 @@ cdef class IndexEngine:
328
344
if stargets:
329
345
# otherwise, map by iterating through all items in the index
330
346
for i in range (n):
331
- if values_mask[i]:
332
- val = - 1
333
- else :
334
- val = values[i]
347
+ val = self .convert_val_if_nan(values[i])
335
348
336
349
if val in stargets:
337
350
if val not in d:
338
351
d[val] = []
339
352
d[val].append(i)
340
353
341
354
for i in range (n_t):
342
- nan_target = targets_mask[i]
343
- if nan_target:
344
- val = - 1
345
- else :
346
- val = targets[i]
355
+ val = self .convert_val_if_nan(targets[i])
347
356
348
357
# found
349
358
if val in d:
@@ -354,21 +363,7 @@ cdef class IndexEngine:
354
363
n_alloc += 10 _000
355
364
result = np.resize(result, n_alloc)
356
365
357
- # -1 in targets could be either -1 or nan
358
- # ensures values in d[-1] to be included only once
359
- if val == - 1 :
360
- nan_val = values_mask[j]
361
- # nan
362
- if nan_target:
363
- if nan_val:
364
- result[count] = j
365
- count += 1
366
- # -1
367
- else :
368
- if not nan_val:
369
- result[count] = j
370
- count += 1
371
- else :
366
+ if self .should_update_d(targets[i], values[j]):
372
367
result[count] = j
373
368
count += 1
374
369
@@ -419,6 +414,13 @@ cdef class ObjectEngine(IndexEngine):
419
414
cdef _make_hash_table(self , Py_ssize_t n):
420
415
return _hash.PyObjectHashTable(n)
421
416
417
def get_stargets(self, ndarray targets) -> set:
    """
    Collect the distinct elements of ``targets``, plus the NA sentinel.

    Parameters
    ----------
    targets : ndarray
        Labels that are being looked up.

    Returns
    -------
    set
        The distinct elements of ``targets``; ``-1`` is added as a stand-in
        whenever ``isnaobj`` flags at least one element as NA-like.
    """
    unique_targets = set(targets)

    # the sentinel is already present, nothing to add
    if -1 in unique_targets:
        return unique_targets

    # account for NA-like targets
    if isnaobj(targets).any():
        unique_targets.add(-1)

    return unique_targets
422
424
423
425
cdef class DatetimeEngine(Int64Engine):
424
426
@@ -490,6 +492,14 @@ cdef class DatetimeEngine(Int64Engine):
490
492
except KeyError :
491
493
raise KeyError (val)
492
494
495
def get_stargets(self, ndarray targets) -> set:
    """
    Collect the distinct elements of ``targets``, plus the NaT sentinel.

    Parameters
    ----------
    targets : ndarray
        Labels that are being looked up.

    Returns
    -------
    set
        The distinct elements of ``targets``; ``-1`` is added as a stand-in
        whenever ``isnaobj`` flags at least one element as null.
    """
    # NOTE(review): get_indexer_non_unique on this class hands over an "i8"
    # view of the targets -- confirm isnaobj can flag NaT in that form.
    unique_targets = set(targets)

    # the sentinel is already present, nothing to add
    if -1 in unique_targets:
        return unique_targets

    # account for NaTs
    if isnaobj(targets).any():
        unique_targets.add(-1)

    return unique_targets
502
+
493
503
def get_indexer_non_unique(self, ndarray targets):
    """
    Delegate the non-unique lookup to the parent engine on an int64 view.

    Parameters
    ----------
    targets : ndarray
        Labels to look up.

    Returns
    -------
    Whatever the parent class's ``get_indexer_non_unique`` returns for the
    reinterpreted targets.
    """
    # we may get datetime64[ns] or timedelta64[ns], cast these to int64
    # ("i8" must be spelled without surrounding whitespace to be a valid
    # dtype string)
    return super().get_indexer_non_unique(targets.view("i8"))
0 commit comments