-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: mask support for hashtable functions for indexing #48396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a7f1abf
b216f6c
f276f6f
82f8a40
f2c271b
b4b263e
69c6407
d08a7ae
44d0797
7bb6914
4670624
3861add
4642852
0a0ad44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'), | |
|
||
cdef class {{name}}HashTable(HashTable): | ||
|
||
def __cinit__(self, int64_t size_hint=1, bint uses_mask=False):
    # Allocate the underlying khash table and pre-size it. The requested
    # size is converted to a bucket count and capped at SIZE_HINT_LIMIT.
    self.table = kh_init_{{dtype}}()
    size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT)
    kh_resize_{{dtype}}(self.table, size_hint)

    # uses_mask: whether NA entries are tracked via a separate mask rather
    # than being stored as keys in the table.
    self.uses_mask = uses_mask
    # -1 is the sentinel for "no NA position recorded yet".
    self.na_position = -1
|
||
def __len__(self) -> int:
    # The NA entry is not stored in the khash table itself, so it is counted
    # separately: add one when an NA position has been recorded
    # (na_position != -1).
    return self.table.size + (0 if self.na_position == -1 else 1)
|
||
def __dealloc__(self):
    # Nothing to release if the table was never allocated or is already freed.
    if self.table is NULL:
        return
    kh_destroy_{{dtype}}(self.table)
    # Reset the pointer so the table cannot be destroyed twice.
    self.table = NULL
|
||
def __contains__(self, object key) -> bool: | ||
# The caller is responsible to check for compatible NA values in case | ||
# of masked arrays. | ||
cdef: | ||
khiter_t k | ||
{{c_type}} ckey | ||
|
||
if self.uses_mask and checknull(key): | ||
Review comment: Does it matter what type of null we're looking at?
Review comment: Depends on whether we want to support multiple nulls in masked arrays. Currently it does not:
[code snippet not captured]
Both match pd.NA. Same for Float64.
Review comment: None and np.nan we usually treat as interchangeable with pd.NA. Here I'm thinking more of pd.NaT.
Review comment: Ah, got you. This is interesting. Currently:
[code snippet not captured]
matches both NAs, but
[code snippet not captured]
raises. Any suggestions here?
Review comment: is_valid_na_for_dtype exists for this purpose. For other dtypes we do this in FooIndex.get_loc. It might require some gymnastics in this case.
Review comment: In general I would be OK with pushing responsibility to the caller. Have to do this for [remainder of sentence not captured]
Review comment: Definitely reasonable; please be sure to document this assumption.
Review comment: Added a comment. I'll probably add some documentation in general when we are through with a MaskedEngine. |
||
return -1 != self.na_position | ||
|
||
ckey = {{to_c_type}}(key) | ||
k = kh_get_{{dtype}}(self.table, ckey) | ||
return k != self.table.n_buckets | ||
|
@@ -435,30 +444,73 @@ cdef class {{name}}HashTable(HashTable): | |
} | ||
|
||
cpdef get_item(self, {{dtype}}_t val): | ||
"""Extracts the position of val from the hashtable. | ||
Review comment: Can the return be typed as int, Py_ssize_t, or something similar? Same question for get_na below. |
||
|
||
Parameters | ||
---------- | ||
val : Scalar | ||
The value that is looked up in the hashtable | ||
|
||
Returns | ||
------- | ||
The position of the requested integer. | ||
""" | ||
|
||
# Used in core.sorting, IndexEngine.get_loc | ||
# Caller is responsible for checking for pd.NA | ||
cdef: | ||
khiter_t k | ||
{{c_type}} cval | ||
|
||
cval = {{to_c_type}}(val) | ||
k = kh_get_{{dtype}}(self.table, cval) | ||
if k != self.table.n_buckets: | ||
return self.table.vals[k] | ||
else: | ||
raise KeyError(val) | ||
|
||
cpdef get_na(self):
    """Return the position recorded for the NA value.

    Returns
    -------
    The position of the last na value.

    Raises
    ------
    NotImplementedError
        If the table was created without mask support.
    KeyError
        If no NA position has been recorded.
    """
    if not self.uses_mask:
        raise NotImplementedError

    # -1 is the "no NA recorded" sentinel.
    if self.na_position != -1:
        return self.na_position
    raise KeyError("NA")
|
||
cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val):
    """Store ``val`` as the position associated with ``key``.

    Parameters
    ----------
    key : Scalar
        The value to insert into (or update in) the hashtable.
    val : Py_ssize_t
        The position to associate with ``key``.

    Raises
    ------
    KeyError
        If the slot returned by the insertion does not hold an entry.
    """
    # Used in libjoin
    # Caller is responsible for checking for pd.NA
    cdef:
        khiter_t slot
        int put_status = 0
        {{c_type}} c_key

    c_key = {{to_c_type}}(key)
    slot = kh_put_{{dtype}}(self.table, c_key, &put_status)
    if not kh_exist_{{dtype}}(self.table, slot):
        raise KeyError(key)
    self.table.vals[slot] = val
|
||
cpdef set_na(self, Py_ssize_t val):
    """Record ``val`` as the position of the NA value.

    Parameters
    ----------
    val : Py_ssize_t
        The position to store for NA.

    Raises
    ------
    NotImplementedError
        If the table was created without mask support.
    """
    # Caller is responsible for checking for pd.NA
    # The NA position lives in a dedicated attribute rather than in the
    # khash table, so no key conversion or table insertion is needed here.
    if not self.uses_mask:
        raise NotImplementedError

    self.na_position = val
|
||
{{if dtype == "int64" }} | ||
# We only use this for int64, can reduce build size and make .pyi | ||
# more accurate by only implementing it for int64 | ||
|
@@ -480,22 +532,36 @@ cdef class {{name}}HashTable(HashTable): | |
{{endif}} | ||
|
||
@cython.boundscheck(False)
def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None:
    """Map every value to the last position at which it occurs.

    Parameters
    ----------
    values : memoryview of {{dtype}}
        The values to insert; each is stored with its index as the value,
        so a later duplicate overwrites an earlier one.
    mask : memoryview of uint8, optional
        Only consulted when the table was created with mask support.
        Truthy entries mark NA positions: the corresponding value is not
        inserted and the index is remembered as the NA position instead.

    Raises
    ------
    NotImplementedError
        If the table uses a mask but no mask was passed.
    """
    # Used in libindex, safe_sort
    cdef:
        Py_ssize_t i, n = len(values)
        int ret = 0
        {{c_type}} val
        khiter_t k
        # Must be as wide as the loop index: an 8-bit integer would
        # truncate any NA position above 127.
        Py_ssize_t na_position = self.na_position

    if self.uses_mask and mask is None:
        raise NotImplementedError  # pragma: no cover

    with nogil:
        # The mask check is hoisted out of the loop so the unmasked path
        # does not pay for it on every element.
        if self.uses_mask:
            for i in range(n):
                if mask[i]:
                    na_position = i
                else:
                    val = {{to_c_type}}(values[i])
                    k = kh_put_{{dtype}}(self.table, val, &ret)
                    self.table.vals[k] = i
        else:
            for i in range(n):
                val = {{to_c_type}}(values[i])
                k = kh_put_{{dtype}}(self.table, val, &ret)
                self.table.vals[k] = i
    self.na_position = na_position
|
||
@cython.boundscheck(False) | ||
def lookup(self, const {{dtype}}_t[:] values) -> ndarray: | ||
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray: | ||
# -> np.ndarray[np.intp] | ||
# Used in safe_sort, IndexEngine.get_indexer | ||
cdef: | ||
|
@@ -504,15 +570,22 @@ cdef class {{name}}HashTable(HashTable): | |
{{c_type}} val | ||
khiter_t k | ||
intp_t[::1] locs = np.empty(n, dtype=np.intp) | ||
int8_t na_position = self.na_position | ||
|
||
if self.uses_mask and mask is None: | ||
raise NotImplementedError # pragma: no cover | ||
|
||
with nogil: | ||
for i in range(n): | ||
val = {{to_c_type}}(values[i]) | ||
k = kh_get_{{dtype}}(self.table, val) | ||
if k != self.table.n_buckets: | ||
locs[i] = self.table.vals[k] | ||
if self.uses_mask and mask[i]: | ||
locs[i] = na_position | ||
else: | ||
locs[i] = -1 | ||
val = {{to_c_type}}(values[i]) | ||
k = kh_get_{{dtype}}(self.table, val) | ||
if k != self.table.n_buckets: | ||
locs[i] = self.table.vals[k] | ||
else: | ||
locs[i] = -1 | ||
|
||
return np.asarray(locs) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this also be reflected in the pyi file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, added