@@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'),
396
396
397
397
cdef class {{name}}HashTable(HashTable):
398
398
399
- def __cinit__(self, int64_t size_hint=1):
399
+ def __cinit__(self, int64_t size_hint=1, bint uses_mask=False ):
400
400
self.table = kh_init_{{dtype}}()
401
401
size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT)
402
402
kh_resize_{{dtype}}(self.table, size_hint)
403
403
404
+ self.uses_mask = uses_mask
405
+ self.na_position = -1
406
+
404
407
def __len__(self) -> int:
405
- return self.table.size
408
+ return self.table.size + (0 if self.na_position == -1 else 1)
406
409
407
410
def __dealloc__(self):
408
411
if self.table is not NULL:
409
412
kh_destroy_{{dtype}}(self.table)
410
413
self.table = NULL
411
414
412
415
def __contains__(self, object key) -> bool:
416
+ # The caller is responsible to check for compatible NA values in case
417
+ # of masked arrays.
413
418
cdef:
414
419
khiter_t k
415
420
{{c_type}} ckey
421
+
422
+ if self.uses_mask and checknull(key):
423
+ return -1 != self.na_position
424
+
416
425
ckey = {{to_c_type}}(key)
417
426
k = kh_get_{{dtype}}(self.table, ckey)
418
427
return k != self.table.n_buckets
@@ -435,30 +444,73 @@ cdef class {{name}}HashTable(HashTable):
435
444
}
436
445
437
446
cpdef get_item(self, {{dtype}}_t val):
447
+ """Extracts the position of val from the hashtable.
448
+
449
+ Parameters
450
+ ----------
451
+ val : Scalar
452
+ The value that is looked up in the hashtable
453
+
454
+ Returns
455
+ -------
456
+ The position of the requested integer.
457
+ """
458
+
438
459
# Used in core.sorting, IndexEngine.get_loc
460
+ # Caller is responsible for checking for pd.NA
439
461
cdef:
440
462
khiter_t k
441
463
{{c_type}} cval
464
+
442
465
cval = {{to_c_type}}(val)
443
466
k = kh_get_{{dtype}}(self.table, cval)
444
467
if k != self.table.n_buckets:
445
468
return self.table.vals[k]
446
469
else:
447
470
raise KeyError(val)
448
471
472
+ cpdef get_na(self):
473
+ """Extracts the position of na_value from the hashtable.
474
+
475
+ Returns
476
+ -------
477
+ The position of the last na value.
478
+ """
479
+
480
+ if not self.uses_mask:
481
+ raise NotImplementedError
482
+
483
+ if self.na_position == -1:
484
+ raise KeyError("NA")
485
+ return self.na_position
486
+
449
487
cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val):
450
488
# Used in libjoin
489
+ # Caller is responsible for checking for pd.NA
451
490
cdef:
452
491
khiter_t k
453
492
int ret = 0
454
493
{{c_type}} ckey
494
+
455
495
ckey = {{to_c_type}}(key)
456
496
k = kh_put_{{dtype}}(self.table, ckey, &ret)
457
497
if kh_exist_{{dtype}}(self.table, k):
458
498
self.table.vals[k] = val
459
499
else:
460
500
raise KeyError(key)
461
501
502
+ cpdef set_na(self, Py_ssize_t val):
503
+ # Caller is responsible for checking for pd.NA
504
+ cdef:
505
+ khiter_t k
506
+ int ret = 0
507
+ {{c_type}} ckey
508
+
509
+ if not self.uses_mask:
510
+ raise NotImplementedError
511
+
512
+ self.na_position = val
513
+
462
514
{{if dtype == "int64" }}
463
515
# We only use this for int64, can reduce build size and make .pyi
464
516
# more accurate by only implementing it for int64
@@ -480,22 +532,36 @@ cdef class {{name}}HashTable(HashTable):
480
532
{{endif}}
481
533
482
534
@cython.boundscheck(False)
483
- def map_locations(self, const {{dtype}}_t[:] values) -> None:
535
+ def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None ) -> None:
484
536
# Used in libindex, safe_sort
485
537
cdef:
486
538
Py_ssize_t i, n = len(values)
487
539
int ret = 0
488
540
{{c_type}} val
489
541
khiter_t k
542
+ int8_t na_position = self.na_position
543
+
544
+ if self.uses_mask and mask is None:
545
+ raise NotImplementedError # pragma: no cover
490
546
491
547
with nogil:
492
- for i in range(n):
493
- val= {{to_c_type}}(values[i])
494
- k = kh_put_{{dtype}}(self.table, val, &ret)
495
- self.table.vals[k] = i
548
+ if self.uses_mask:
549
+ for i in range(n):
550
+ if mask[i]:
551
+ na_position = i
552
+ else:
553
+ val= {{to_c_type}}(values[i])
554
+ k = kh_put_{{dtype}}(self.table, val, &ret)
555
+ self.table.vals[k] = i
556
+ else:
557
+ for i in range(n):
558
+ val= {{to_c_type}}(values[i])
559
+ k = kh_put_{{dtype}}(self.table, val, &ret)
560
+ self.table.vals[k] = i
561
+ self.na_position = na_position
496
562
497
563
@cython.boundscheck(False)
498
- def lookup(self, const {{dtype}}_t[:] values) -> ndarray:
564
+ def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None ) -> ndarray:
499
565
# -> np.ndarray[np.intp]
500
566
# Used in safe_sort, IndexEngine.get_indexer
501
567
cdef:
@@ -504,15 +570,22 @@ cdef class {{name}}HashTable(HashTable):
504
570
{{c_type}} val
505
571
khiter_t k
506
572
intp_t[::1] locs = np.empty(n, dtype=np.intp)
573
+ int8_t na_position = self.na_position
574
+
575
+ if self.uses_mask and mask is None:
576
+ raise NotImplementedError # pragma: no cover
507
577
508
578
with nogil:
509
579
for i in range(n):
510
- val = {{to_c_type}}(values[i])
511
- k = kh_get_{{dtype}}(self.table, val)
512
- if k != self.table.n_buckets:
513
- locs[i] = self.table.vals[k]
580
+ if self.uses_mask and mask[i]:
581
+ locs[i] = na_position
514
582
else:
515
- locs[i] = -1
583
+ val = {{to_c_type}}(values[i])
584
+ k = kh_get_{{dtype}}(self.table, val)
585
+ if k != self.table.n_buckets:
586
+ locs[i] = self.table.vals[k]
587
+ else:
588
+ locs[i] = -1
516
589
517
590
return np.asarray(locs)
518
591
0 commit comments