-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: mask support for hastable functions for indexing #48396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
a7f1abf
b216f6c
f276f6f
82f8a40
f2c271b
b4b263e
69c6407
d08a7ae
44d0797
7bb6914
4670624
3861add
4642852
0a0ad44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'), | |
|
||
cdef class {{name}}HashTable(HashTable): | ||
|
||
def __cinit__(self, int64_t size_hint=1): | ||
def __cinit__(self, int64_t size_hint=1, bint uses_mask=False): | ||
self.table = kh_init_{{dtype}}() | ||
size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT) | ||
kh_resize_{{dtype}}(self.table, size_hint) | ||
|
||
self.uses_mask = uses_mask | ||
self.na_position = -1 | ||
|
||
def __len__(self) -> int: | ||
return self.table.size | ||
return self.table.size + (0 if self.na_position == -1 else 1) | ||
|
||
def __dealloc__(self): | ||
if self.table is not NULL: | ||
kh_destroy_{{dtype}}(self.table) | ||
self.table = NULL | ||
|
||
def __contains__(self, object key) -> bool: | ||
# The caller is responsible to check for compatible NA values in case | ||
# of masked arrays. | ||
cdef: | ||
khiter_t k | ||
{{c_type}} ckey | ||
|
||
if self.uses_mask and checknull(key): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does it matter what type of null we're looking at? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Depends if we want to support multiple nulls in masked arrays. Currently it does not:
Both match pd.NA. Same for Float64 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. None and np.nan we usually treat as interchangeable with pd.NA. here im thinking more of pd.NaT There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah got you. This is interesting currently:
matches both NAs, but
raises. Any suggestions here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is_valid_na_for_dtype exists for this purpose. For other dtypes we do this in FooIndex.get_loc. Might require some gymnastics in this case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general I would be ok with pushing responsibility to the caller. Have to do this for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. definitely reasonable, pls be sure to document this assumption There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a comment. I'll probably add some documentation in general when we are through with a MaskedEngine |
||
return -1 != self.na_position | ||
|
||
ckey = {{to_c_type}}(key) | ||
k = kh_get_{{dtype}}(self.table, ckey) | ||
return k != self.table.n_buckets | ||
|
@@ -434,30 +443,49 @@ cdef class {{name}}HashTable(HashTable): | |
'upper_bound' : self.table.upper_bound, | ||
} | ||
|
||
cpdef get_item(self, {{dtype}}_t val): | ||
cpdef get_item(self, {{dtype}}_t val, bint na_value = False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would something like (while you are at it, a brief docstring could also help to describe the keywords) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. imo val_is_na is a bit confusing, because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although with But this is also pointing to that it's not a very clean API, with basically two separate behaviours ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactored into own methods, should be significantly easier to understand now |
||
# Used in core.sorting, IndexEngine.get_loc | ||
# Caller is responsible for checking for pd.NA | ||
cdef: | ||
khiter_t k | ||
{{c_type}} cval | ||
|
||
if not self.uses_mask and na_value: | ||
raise NotImplementedError | ||
|
||
if na_value: | ||
if self.na_position == -1: | ||
raise KeyError("NA") | ||
return self.na_position | ||
|
||
cval = {{to_c_type}}(val) | ||
k = kh_get_{{dtype}}(self.table, cval) | ||
if k != self.table.n_buckets: | ||
return self.table.vals[k] | ||
else: | ||
raise KeyError(val) | ||
|
||
cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val): | ||
cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val, bint na_value = False): | ||
# Used in libjoin | ||
# Caller is responsible for checking for pd.NA | ||
cdef: | ||
khiter_t k | ||
int ret = 0 | ||
{{c_type}} ckey | ||
ckey = {{to_c_type}}(key) | ||
k = kh_put_{{dtype}}(self.table, ckey, &ret) | ||
if kh_exist_{{dtype}}(self.table, k): | ||
self.table.vals[k] = val | ||
|
||
if na_value and not self.uses_mask: | ||
raise NotImplementedError | ||
|
||
if na_value: | ||
self.na_position = val | ||
|
||
else: | ||
raise KeyError(key) | ||
ckey = {{to_c_type}}(key) | ||
k = kh_put_{{dtype}}(self.table, ckey, &ret) | ||
if kh_exist_{{dtype}}(self.table, k): | ||
self.table.vals[k] = val | ||
else: | ||
raise KeyError(key) | ||
|
||
{{if dtype == "int64" }} | ||
# We only use this for int64, can reduce build size and make .pyi | ||
|
@@ -480,22 +508,36 @@ cdef class {{name}}HashTable(HashTable): | |
{{endif}} | ||
|
||
@cython.boundscheck(False) | ||
def map_locations(self, const {{dtype}}_t[:] values) -> None: | ||
def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None: | ||
# Used in libindex, safe_sort | ||
cdef: | ||
Py_ssize_t i, n = len(values) | ||
int ret = 0 | ||
{{c_type}} val | ||
khiter_t k | ||
int8_t na_position = self.na_position | ||
|
||
if self.uses_mask and mask is None: | ||
raise NotImplementedError # pragma: no cover | ||
|
||
with nogil: | ||
for i in range(n): | ||
val= {{to_c_type}}(values[i]) | ||
k = kh_put_{{dtype}}(self.table, val, &ret) | ||
self.table.vals[k] = i | ||
if self.uses_mask: | ||
for i in range(n): | ||
if mask[i]: | ||
na_position = i | ||
else: | ||
val= {{to_c_type}}(values[i]) | ||
k = kh_put_{{dtype}}(self.table, val, &ret) | ||
self.table.vals[k] = i | ||
else: | ||
for i in range(n): | ||
val= {{to_c_type}}(values[i]) | ||
k = kh_put_{{dtype}}(self.table, val, &ret) | ||
self.table.vals[k] = i | ||
self.na_position = na_position | ||
|
||
@cython.boundscheck(False) | ||
def lookup(self, const {{dtype}}_t[:] values) -> ndarray: | ||
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray: | ||
# -> np.ndarray[np.intp] | ||
# Used in safe_sort, IndexEngine.get_indexer | ||
cdef: | ||
|
@@ -504,15 +546,22 @@ cdef class {{name}}HashTable(HashTable): | |
{{c_type}} val | ||
khiter_t k | ||
intp_t[::1] locs = np.empty(n, dtype=np.intp) | ||
int8_t na_position = self.na_position | ||
|
||
if self.uses_mask and mask is None: | ||
raise NotImplementedError # pragma: no cover | ||
|
||
with nogil: | ||
for i in range(n): | ||
val = {{to_c_type}}(values[i]) | ||
k = kh_get_{{dtype}}(self.table, val) | ||
if k != self.table.n_buckets: | ||
locs[i] = self.table.vals[k] | ||
if self.uses_mask and mask[i]: | ||
locs[i] = na_position | ||
else: | ||
locs[i] = -1 | ||
val = {{to_c_type}}(values[i]) | ||
k = kh_get_{{dtype}}(self.table, val) | ||
if k != self.table.n_buckets: | ||
locs[i] = self.table.vals[k] | ||
else: | ||
locs[i] = -1 | ||
|
||
return np.asarray(locs) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this also be reflected in the pyi file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, added