Skip to content

Commit f271445

Browse files
authored
ENH: mask support for hastable functions for indexing (#48396)
* ENH: mask support for hastable functions for indexing * Fix mypy * Adjust test * Add comment * Add docstring * Refactor into own functions * Fix typing
1 parent 0e88f4f commit f271445

File tree

5 files changed

+219
-16
lines changed

5 files changed

+219
-16
lines changed

pandas/_libs/hashtable.pxd

+48
Original file line numberDiff line numberDiff line change
@@ -41,75 +41,123 @@ cdef class HashTable:
4141

4242
cdef class UInt64HashTable(HashTable):
4343
cdef kh_uint64_t *table
44+
cdef int64_t na_position
45+
cdef bint uses_mask
4446

4547
cpdef get_item(self, uint64_t val)
4648
cpdef set_item(self, uint64_t key, Py_ssize_t val)
49+
cpdef get_na(self)
50+
cpdef set_na(self, Py_ssize_t val)
4751

4852
cdef class Int64HashTable(HashTable):
4953
cdef kh_int64_t *table
54+
cdef int64_t na_position
55+
cdef bint uses_mask
5056

5157
cpdef get_item(self, int64_t val)
5258
cpdef set_item(self, int64_t key, Py_ssize_t val)
59+
cpdef get_na(self)
60+
cpdef set_na(self, Py_ssize_t val)
5361

5462
cdef class UInt32HashTable(HashTable):
5563
cdef kh_uint32_t *table
64+
cdef int64_t na_position
65+
cdef bint uses_mask
5666

5767
cpdef get_item(self, uint32_t val)
5868
cpdef set_item(self, uint32_t key, Py_ssize_t val)
69+
cpdef get_na(self)
70+
cpdef set_na(self, Py_ssize_t val)
5971

6072
cdef class Int32HashTable(HashTable):
6173
cdef kh_int32_t *table
74+
cdef int64_t na_position
75+
cdef bint uses_mask
6276

6377
cpdef get_item(self, int32_t val)
6478
cpdef set_item(self, int32_t key, Py_ssize_t val)
79+
cpdef get_na(self)
80+
cpdef set_na(self, Py_ssize_t val)
6581

6682
cdef class UInt16HashTable(HashTable):
6783
cdef kh_uint16_t *table
84+
cdef int64_t na_position
85+
cdef bint uses_mask
6886

6987
cpdef get_item(self, uint16_t val)
7088
cpdef set_item(self, uint16_t key, Py_ssize_t val)
89+
cpdef get_na(self)
90+
cpdef set_na(self, Py_ssize_t val)
7191

7292
cdef class Int16HashTable(HashTable):
7393
cdef kh_int16_t *table
94+
cdef int64_t na_position
95+
cdef bint uses_mask
7496

7597
cpdef get_item(self, int16_t val)
7698
cpdef set_item(self, int16_t key, Py_ssize_t val)
99+
cpdef get_na(self)
100+
cpdef set_na(self, Py_ssize_t val)
77101

78102
cdef class UInt8HashTable(HashTable):
79103
cdef kh_uint8_t *table
104+
cdef int64_t na_position
105+
cdef bint uses_mask
80106

81107
cpdef get_item(self, uint8_t val)
82108
cpdef set_item(self, uint8_t key, Py_ssize_t val)
109+
cpdef get_na(self)
110+
cpdef set_na(self, Py_ssize_t val)
83111

84112
cdef class Int8HashTable(HashTable):
85113
cdef kh_int8_t *table
114+
cdef int64_t na_position
115+
cdef bint uses_mask
86116

87117
cpdef get_item(self, int8_t val)
88118
cpdef set_item(self, int8_t key, Py_ssize_t val)
119+
cpdef get_na(self)
120+
cpdef set_na(self, Py_ssize_t val)
89121

90122
cdef class Float64HashTable(HashTable):
91123
cdef kh_float64_t *table
124+
cdef int64_t na_position
125+
cdef bint uses_mask
92126

93127
cpdef get_item(self, float64_t val)
94128
cpdef set_item(self, float64_t key, Py_ssize_t val)
129+
cpdef get_na(self)
130+
cpdef set_na(self, Py_ssize_t val)
95131

96132
cdef class Float32HashTable(HashTable):
97133
cdef kh_float32_t *table
134+
cdef int64_t na_position
135+
cdef bint uses_mask
98136

99137
cpdef get_item(self, float32_t val)
100138
cpdef set_item(self, float32_t key, Py_ssize_t val)
139+
cpdef get_na(self)
140+
cpdef set_na(self, Py_ssize_t val)
101141

102142
cdef class Complex64HashTable(HashTable):
103143
cdef kh_complex64_t *table
144+
cdef int64_t na_position
145+
cdef bint uses_mask
104146

105147
cpdef get_item(self, complex64_t val)
106148
cpdef set_item(self, complex64_t key, Py_ssize_t val)
149+
cpdef get_na(self)
150+
cpdef set_na(self, Py_ssize_t val)
107151

108152
cdef class Complex128HashTable(HashTable):
109153
cdef kh_complex128_t *table
154+
cdef int64_t na_position
155+
cdef bint uses_mask
110156

111157
cpdef get_item(self, complex128_t val)
112158
cpdef set_item(self, complex128_t key, Py_ssize_t val)
159+
cpdef get_na(self)
160+
cpdef set_na(self, Py_ssize_t val)
113161

114162
cdef class PyObjectHashTable(HashTable):
115163
cdef kh_pymap_t *table

pandas/_libs/hashtable.pyi

+5-3
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,18 @@ class ObjectVector:
110110

111111
class HashTable:
112112
# NB: The base HashTable class does _not_ actually have these methods;
113-
# we are putting the here for the sake of mypy to avoid
113+
# we are putting them here for the sake of mypy to avoid
114114
# reproducing them in each subclass below.
115-
def __init__(self, size_hint: int = ...) -> None: ...
115+
def __init__(self, size_hint: int = ..., uses_mask: bool = ...) -> None: ...
116116
def __len__(self) -> int: ...
117117
def __contains__(self, key: Hashable) -> bool: ...
118118
def sizeof(self, deep: bool = ...) -> int: ...
119119
def get_state(self) -> dict[str, int]: ...
120120
# TODO: `item` type is subclass-specific
121121
def get_item(self, item): ... # TODO: return type?
122-
def set_item(self, item) -> None: ...
122+
def set_item(self, item, val) -> None: ...
123+
def get_na(self): ... # TODO: return type?
124+
def set_na(self, val) -> None: ...
123125
def map_locations(
124126
self,
125127
values: np.ndarray, # np.ndarray[subclass-specific]

pandas/_libs/hashtable_class_helper.pxi.in

+86-13
Original file line numberDiff line numberDiff line change
@@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'),
396396

397397
cdef class {{name}}HashTable(HashTable):
398398

399-
def __cinit__(self, int64_t size_hint=1):
399+
def __cinit__(self, int64_t size_hint=1, bint uses_mask=False):
400400
self.table = kh_init_{{dtype}}()
401401
size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT)
402402
kh_resize_{{dtype}}(self.table, size_hint)
403403

404+
self.uses_mask = uses_mask
405+
self.na_position = -1
406+
404407
def __len__(self) -> int:
405-
return self.table.size
408+
return self.table.size + (0 if self.na_position == -1 else 1)
406409

407410
def __dealloc__(self):
408411
if self.table is not NULL:
409412
kh_destroy_{{dtype}}(self.table)
410413
self.table = NULL
411414

412415
def __contains__(self, object key) -> bool:
416+
# The caller is responsible to check for compatible NA values in case
417+
# of masked arrays.
413418
cdef:
414419
khiter_t k
415420
{{c_type}} ckey
421+
422+
if self.uses_mask and checknull(key):
423+
return -1 != self.na_position
424+
416425
ckey = {{to_c_type}}(key)
417426
k = kh_get_{{dtype}}(self.table, ckey)
418427
return k != self.table.n_buckets
@@ -435,30 +444,73 @@ cdef class {{name}}HashTable(HashTable):
435444
}
436445

437446
cpdef get_item(self, {{dtype}}_t val):
447+
"""Extracts the position of val from the hashtable.
448+
449+
Parameters
450+
----------
451+
val : Scalar
452+
The value that is looked up in the hashtable
453+
454+
Returns
455+
-------
456+
The position of the requested integer.
457+
"""
458+
438459
# Used in core.sorting, IndexEngine.get_loc
460+
# Caller is responsible for checking for pd.NA
439461
cdef:
440462
khiter_t k
441463
{{c_type}} cval
464+
442465
cval = {{to_c_type}}(val)
443466
k = kh_get_{{dtype}}(self.table, cval)
444467
if k != self.table.n_buckets:
445468
return self.table.vals[k]
446469
else:
447470
raise KeyError(val)
448471

472+
cpdef get_na(self):
473+
"""Extracts the position of na_value from the hashtable.
474+
475+
Returns
476+
-------
477+
The position of the last na value.
478+
"""
479+
480+
if not self.uses_mask:
481+
raise NotImplementedError
482+
483+
if self.na_position == -1:
484+
raise KeyError("NA")
485+
return self.na_position
486+
449487
cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val):
450488
# Used in libjoin
489+
# Caller is responsible for checking for pd.NA
451490
cdef:
452491
khiter_t k
453492
int ret = 0
454493
{{c_type}} ckey
494+
455495
ckey = {{to_c_type}}(key)
456496
k = kh_put_{{dtype}}(self.table, ckey, &ret)
457497
if kh_exist_{{dtype}}(self.table, k):
458498
self.table.vals[k] = val
459499
else:
460500
raise KeyError(key)
461501

502+
cpdef set_na(self, Py_ssize_t val):
503+
# Caller is responsible for checking for pd.NA
504+
cdef:
505+
khiter_t k
506+
int ret = 0
507+
{{c_type}} ckey
508+
509+
if not self.uses_mask:
510+
raise NotImplementedError
511+
512+
self.na_position = val
513+
462514
{{if dtype == "int64" }}
463515
# We only use this for int64, can reduce build size and make .pyi
464516
# more accurate by only implementing it for int64
@@ -480,22 +532,36 @@ cdef class {{name}}HashTable(HashTable):
480532
{{endif}}
481533

482534
@cython.boundscheck(False)
483-
def map_locations(self, const {{dtype}}_t[:] values) -> None:
535+
def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None:
484536
# Used in libindex, safe_sort
485537
cdef:
486538
Py_ssize_t i, n = len(values)
487539
int ret = 0
488540
{{c_type}} val
489541
khiter_t k
542+
int8_t na_position = self.na_position
543+
544+
if self.uses_mask and mask is None:
545+
raise NotImplementedError # pragma: no cover
490546

491547
with nogil:
492-
for i in range(n):
493-
val= {{to_c_type}}(values[i])
494-
k = kh_put_{{dtype}}(self.table, val, &ret)
495-
self.table.vals[k] = i
548+
if self.uses_mask:
549+
for i in range(n):
550+
if mask[i]:
551+
na_position = i
552+
else:
553+
val= {{to_c_type}}(values[i])
554+
k = kh_put_{{dtype}}(self.table, val, &ret)
555+
self.table.vals[k] = i
556+
else:
557+
for i in range(n):
558+
val= {{to_c_type}}(values[i])
559+
k = kh_put_{{dtype}}(self.table, val, &ret)
560+
self.table.vals[k] = i
561+
self.na_position = na_position
496562

497563
@cython.boundscheck(False)
498-
def lookup(self, const {{dtype}}_t[:] values) -> ndarray:
564+
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray:
499565
# -> np.ndarray[np.intp]
500566
# Used in safe_sort, IndexEngine.get_indexer
501567
cdef:
@@ -504,15 +570,22 @@ cdef class {{name}}HashTable(HashTable):
504570
{{c_type}} val
505571
khiter_t k
506572
intp_t[::1] locs = np.empty(n, dtype=np.intp)
573+
int8_t na_position = self.na_position
574+
575+
if self.uses_mask and mask is None:
576+
raise NotImplementedError # pragma: no cover
507577

508578
with nogil:
509579
for i in range(n):
510-
val = {{to_c_type}}(values[i])
511-
k = kh_get_{{dtype}}(self.table, val)
512-
if k != self.table.n_buckets:
513-
locs[i] = self.table.vals[k]
580+
if self.uses_mask and mask[i]:
581+
locs[i] = na_position
514582
else:
515-
locs[i] = -1
583+
val = {{to_c_type}}(values[i])
584+
k = kh_get_{{dtype}}(self.table, val)
585+
if k != self.table.n_buckets:
586+
locs[i] = self.table.vals[k]
587+
else:
588+
locs[i] = -1
516589

517590
return np.asarray(locs)
518591

0 commit comments

Comments
 (0)