Skip to content

ENH: mask support for hastable functions for indexing #48396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions pandas/_libs/hashtable.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,75 +41,99 @@ cdef class HashTable:

cdef class UInt64HashTable(HashTable):
cdef kh_uint64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint64_t val)
cpdef set_item(self, uint64_t key, Py_ssize_t val)
cpdef get_item(self, uint64_t val, bint na_value=*)
cpdef set_item(self, uint64_t key, Py_ssize_t val, bint na_value=*)

cdef class Int64HashTable(HashTable):
cdef kh_int64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int64_t val)
cpdef set_item(self, int64_t key, Py_ssize_t val)
cpdef get_item(self, int64_t val, bint na_value=*)
cpdef set_item(self, int64_t key, Py_ssize_t val, bint na_value=*)

cdef class UInt32HashTable(HashTable):
cdef kh_uint32_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint32_t val)
cpdef set_item(self, uint32_t key, Py_ssize_t val)
cpdef get_item(self, uint32_t val, bint na_value=*)
cpdef set_item(self, uint32_t key, Py_ssize_t val, bint na_value=*)

cdef class Int32HashTable(HashTable):
cdef kh_int32_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int32_t val)
cpdef set_item(self, int32_t key, Py_ssize_t val)
cpdef get_item(self, int32_t val, bint na_value=*)
cpdef set_item(self, int32_t key, Py_ssize_t val, bint na_value=*)

cdef class UInt16HashTable(HashTable):
cdef kh_uint16_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint16_t val)
cpdef set_item(self, uint16_t key, Py_ssize_t val)
cpdef get_item(self, uint16_t val, bint na_value=*)
cpdef set_item(self, uint16_t key, Py_ssize_t val, bint na_value=*)

cdef class Int16HashTable(HashTable):
cdef kh_int16_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int16_t val)
cpdef set_item(self, int16_t key, Py_ssize_t val)
cpdef get_item(self, int16_t val, bint na_value=*)
cpdef set_item(self, int16_t key, Py_ssize_t val, bint na_value=*)

cdef class UInt8HashTable(HashTable):
cdef kh_uint8_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint8_t val)
cpdef set_item(self, uint8_t key, Py_ssize_t val)
cpdef get_item(self, uint8_t val, bint na_value=*)
cpdef set_item(self, uint8_t key, Py_ssize_t val, bint na_value=*)

cdef class Int8HashTable(HashTable):
cdef kh_int8_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int8_t val)
cpdef set_item(self, int8_t key, Py_ssize_t val)
cpdef get_item(self, int8_t val, bint na_value=*)
cpdef set_item(self, int8_t key, Py_ssize_t val, bint na_value=*)

cdef class Float64HashTable(HashTable):
cdef kh_float64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, float64_t val)
cpdef set_item(self, float64_t key, Py_ssize_t val)
cpdef get_item(self, float64_t val, bint na_value=*)
cpdef set_item(self, float64_t key, Py_ssize_t val, bint na_value=*)

cdef class Float32HashTable(HashTable):
cdef kh_float32_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, float32_t val)
cpdef set_item(self, float32_t key, Py_ssize_t val)
cpdef get_item(self, float32_t val, bint na_value=*)
cpdef set_item(self, float32_t key, Py_ssize_t val, bint na_value=*)

cdef class Complex64HashTable(HashTable):
cdef kh_complex64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, complex64_t val)
cpdef set_item(self, complex64_t key, Py_ssize_t val)
cpdef get_item(self, complex64_t val, bint na_value=*)
cpdef set_item(self, complex64_t key, Py_ssize_t val, bint na_value=*)

cdef class Complex128HashTable(HashTable):
cdef kh_complex128_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, complex128_t val)
cpdef set_item(self, complex128_t key, Py_ssize_t val)
cpdef get_item(self, complex128_t val, bint na_value=*)
cpdef set_item(self, complex128_t key, Py_ssize_t val, bint na_value=*)

cdef class PyObjectHashTable(HashTable):
cdef kh_pymap_t *table
Expand Down
6 changes: 3 additions & 3 deletions pandas/_libs/hashtable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ class ObjectVector:

class HashTable:
# NB: The base HashTable class does _not_ actually have these methods;
# we are putting the here for the sake of mypy to avoid
# we are putting them here for the sake of mypy to avoid
# reproducing them in each subclass below.
def __init__(self, size_hint: int = ...) -> None: ...
def __len__(self) -> int: ...
def __contains__(self, key: Hashable) -> bool: ...
def sizeof(self, deep: bool = ...) -> int: ...
def get_state(self) -> dict[str, int]: ...
# TODO: `item` type is subclass-specific
def get_item(self, item): ... # TODO: return type?
def set_item(self, item) -> None: ...
def get_item(self, item, na_value: bool = ...): ... # TODO: return type?
def set_item(self, item, val, na_value: bool = ...) -> None: ...
def map_locations(
self,
values: np.ndarray, # np.ndarray[subclass-specific]
Expand Down
89 changes: 69 additions & 20 deletions pandas/_libs/hashtable_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'),

cdef class {{name}}HashTable(HashTable):

def __cinit__(self, int64_t size_hint=1):
def __cinit__(self, int64_t size_hint=1, bint uses_mask=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be reflected in the pyi file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, added

self.table = kh_init_{{dtype}}()
size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT)
kh_resize_{{dtype}}(self.table, size_hint)

self.uses_mask = uses_mask
self.na_position = -1

def __len__(self) -> int:
return self.table.size
return self.table.size + (0 if self.na_position == -1 else 1)

def __dealloc__(self):
if self.table is not NULL:
kh_destroy_{{dtype}}(self.table)
self.table = NULL

def __contains__(self, object key) -> bool:
# The caller is responsible to check for compatible NA values in case
# of masked arrays.
cdef:
khiter_t k
{{c_type}} ckey

if self.uses_mask and checknull(key):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it matter what type of null we're looking at?

Copy link
Member Author

@phofl phofl Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends if we want to support multiple nulls in masked arrays. Currently it does not:

idx = Index([1, 2, pd.NA, pd.NA], dtype="Int64")
idx.get_loc(None)
idx.get_loc(np.nan)

Both match pd.NA. Same for Float64

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None and np.nan we usually treat as interchangeable with pd.NA. here im thinking more of pd.NaT

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got you. This is interesting currently:

idx.get_loc(pd.NaT)

matches both NAs, but

Index([1, 2, pd.NA, pd.NA, pd.NaT], dtype="Int64")

raises. Any suggestions here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_valid_na_for_dtype exists for this purpose. For other dtypes we do this in FooIndex.get_loc. Might require some gymnastics in this case

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I would be ok with pushing responsibility to the caller. Have to do this for set_item and get_item anyway

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

definitely reasonable, pls be sure to document this assumption

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment. I'll probably add some documentation in general when we are through with a MaskedEngine

return -1 != self.na_position

ckey = {{to_c_type}}(key)
k = kh_get_{{dtype}}(self.table, ckey)
return k != self.table.n_buckets
Expand All @@ -434,30 +443,49 @@ cdef class {{name}}HashTable(HashTable):
'upper_bound' : self.table.upper_bound,
}

cpdef get_item(self, {{dtype}}_t val):
cpdef get_item(self, {{dtype}}_t val, bint na_value = False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would something like val_is_na be a more descriptive keyword name? (typically in other places where we have na_value, it is the value itself, not a bool flag)

(while you are at it, a brief docstring could also help to describe the keywords)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo val_is_na is a bit confusing, because val is actually a placeholder and therefore garbage at this place. It is more like get_item of na_value

Copy link
Member

@jorisvandenbossche jorisvandenbossche Oct 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is_na_value or get_na_value?
(I mostly would like to avoid using the same keyword name that is used elsewhere for a different meaning)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although with get_na_value, it is giving some duplication (get_item(.., get_na_value=True)

But this is also pointing to that it's not a very clean API, with basically two separate behaviours (value being garbarge if the other keyword is set). It's not too important given this is deep into the internals, but maybe a clearer interface is to just have two separate methods instead of the keyword in the single method: get_item(value) and get_na().
Since the caller needs to handle this specifically in the level up anyway to check for NA and set the keyword, you might as well just call a different method instead of passing a keyword.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored into own methods, should be significantly easier to understand now

# Used in core.sorting, IndexEngine.get_loc
# Caller is responsible for checking for pd.NA
cdef:
khiter_t k
{{c_type}} cval

if not self.uses_mask and na_value:
raise NotImplementedError

if na_value:
if self.na_position == -1:
raise KeyError("NA")
return self.na_position

cval = {{to_c_type}}(val)
k = kh_get_{{dtype}}(self.table, cval)
if k != self.table.n_buckets:
return self.table.vals[k]
else:
raise KeyError(val)

cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val):
cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val, bint na_value = False):
# Used in libjoin
# Caller is responsible for checking for pd.NA
cdef:
khiter_t k
int ret = 0
{{c_type}} ckey
ckey = {{to_c_type}}(key)
k = kh_put_{{dtype}}(self.table, ckey, &ret)
if kh_exist_{{dtype}}(self.table, k):
self.table.vals[k] = val

if na_value and not self.uses_mask:
raise NotImplementedError

if na_value:
self.na_position = val

else:
raise KeyError(key)
ckey = {{to_c_type}}(key)
k = kh_put_{{dtype}}(self.table, ckey, &ret)
if kh_exist_{{dtype}}(self.table, k):
self.table.vals[k] = val
else:
raise KeyError(key)

{{if dtype == "int64" }}
# We only use this for int64, can reduce build size and make .pyi
Expand All @@ -480,22 +508,36 @@ cdef class {{name}}HashTable(HashTable):
{{endif}}

@cython.boundscheck(False)
def map_locations(self, const {{dtype}}_t[:] values) -> None:
def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None:
# Used in libindex, safe_sort
cdef:
Py_ssize_t i, n = len(values)
int ret = 0
{{c_type}} val
khiter_t k
int8_t na_position = self.na_position

if self.uses_mask and mask is None:
raise NotImplementedError # pragma: no cover

with nogil:
for i in range(n):
val= {{to_c_type}}(values[i])
k = kh_put_{{dtype}}(self.table, val, &ret)
self.table.vals[k] = i
if self.uses_mask:
for i in range(n):
if mask[i]:
na_position = i
else:
val= {{to_c_type}}(values[i])
k = kh_put_{{dtype}}(self.table, val, &ret)
self.table.vals[k] = i
else:
for i in range(n):
val= {{to_c_type}}(values[i])
k = kh_put_{{dtype}}(self.table, val, &ret)
self.table.vals[k] = i
self.na_position = na_position

@cython.boundscheck(False)
def lookup(self, const {{dtype}}_t[:] values) -> ndarray:
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray:
# -> np.ndarray[np.intp]
# Used in safe_sort, IndexEngine.get_indexer
cdef:
Expand All @@ -504,15 +546,22 @@ cdef class {{name}}HashTable(HashTable):
{{c_type}} val
khiter_t k
intp_t[::1] locs = np.empty(n, dtype=np.intp)
int8_t na_position = self.na_position

if self.uses_mask and mask is None:
raise NotImplementedError # pragma: no cover

with nogil:
for i in range(n):
val = {{to_c_type}}(values[i])
k = kh_get_{{dtype}}(self.table, val)
if k != self.table.n_buckets:
locs[i] = self.table.vals[k]
if self.uses_mask and mask[i]:
locs[i] = na_position
else:
locs[i] = -1
val = {{to_c_type}}(values[i])
k = kh_get_{{dtype}}(self.table, val)
if k != self.table.n_buckets:
locs[i] = self.table.vals[k]
else:
locs[i] = -1

return np.asarray(locs)

Expand Down
Loading