Skip to content

ENH: mask support for hashtable functions for indexing #48396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pandas/_libs/hashtable.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,75 +41,123 @@ cdef class HashTable:

cdef class UInt64HashTable(HashTable):
    # Declaration of the hash table keyed on uint64 values.
    cdef:
        # Pointer to the backing kh_uint64_t table.
        kh_uint64_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, uint64_t val)
    cpdef set_item(self, uint64_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Int64HashTable(HashTable):
    # Declaration of the hash table keyed on int64 values.
    cdef:
        # Pointer to the backing kh_int64_t table.
        kh_int64_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, int64_t val)
    cpdef set_item(self, int64_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class UInt32HashTable(HashTable):
    # Declaration of the hash table keyed on uint32 values.
    cdef:
        # Pointer to the backing kh_uint32_t table.
        kh_uint32_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, uint32_t val)
    cpdef set_item(self, uint32_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Int32HashTable(HashTable):
    # Declaration of the hash table keyed on int32 values.
    cdef:
        # Pointer to the backing kh_int32_t table.
        kh_int32_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, int32_t val)
    cpdef set_item(self, int32_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class UInt16HashTable(HashTable):
    # Declaration of the hash table keyed on uint16 values.
    cdef:
        # Pointer to the backing kh_uint16_t table.
        kh_uint16_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, uint16_t val)
    cpdef set_item(self, uint16_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Int16HashTable(HashTable):
    # Declaration of the hash table keyed on int16 values.
    cdef:
        # Pointer to the backing kh_int16_t table.
        kh_int16_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, int16_t val)
    cpdef set_item(self, int16_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class UInt8HashTable(HashTable):
    # Declaration of the hash table keyed on uint8 values.
    cdef:
        # Pointer to the backing kh_uint8_t table.
        kh_uint8_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, uint8_t val)
    cpdef set_item(self, uint8_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Int8HashTable(HashTable):
    # Declaration of the hash table keyed on int8 values.
    cdef:
        # Pointer to the backing kh_int8_t table.
        kh_int8_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, int8_t val)
    cpdef set_item(self, int8_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Float64HashTable(HashTable):
    # Declaration of the hash table keyed on float64 values.
    cdef:
        # Pointer to the backing kh_float64_t table.
        kh_float64_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, float64_t val)
    cpdef set_item(self, float64_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Float32HashTable(HashTable):
    # Declaration of the hash table keyed on float32 values.
    cdef:
        # Pointer to the backing kh_float32_t table.
        kh_float32_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, float32_t val)
    cpdef set_item(self, float32_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Complex64HashTable(HashTable):
    # Declaration of the hash table keyed on complex64 values.
    cdef:
        # Pointer to the backing kh_complex64_t table.
        kh_complex64_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, complex64_t val)
    cpdef set_item(self, complex64_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class Complex128HashTable(HashTable):
    # Declaration of the hash table keyed on complex128 values.
    cdef:
        # Pointer to the backing kh_complex128_t table.
        kh_complex128_t *table
        # Position stored for the NA entry (read/written via get_na/set_na).
        int64_t na_position
        # Flag recording whether the table is used together with a mask.
        bint uses_mask

    cpdef get_item(self, complex128_t val)
    cpdef set_item(self, complex128_t key, Py_ssize_t val)
    cpdef get_na(self)
    cpdef set_na(self, Py_ssize_t val)

cdef class PyObjectHashTable(HashTable):
cdef kh_pymap_t *table
Expand Down
8 changes: 5 additions & 3 deletions pandas/_libs/hashtable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,18 @@ class ObjectVector:

class HashTable:
# NB: The base HashTable class does _not_ actually have these methods;
# we are putting the here for the sake of mypy to avoid
# we are putting them here for the sake of mypy to avoid
# reproducing them in each subclass below.
def __init__(self, size_hint: int = ...) -> None: ...
def __init__(self, size_hint: int = ..., uses_mask: bool = ...) -> None: ...
def __len__(self) -> int: ...
def __contains__(self, key: Hashable) -> bool: ...
def sizeof(self, deep: bool = ...) -> int: ...
def get_state(self) -> dict[str, int]: ...
# TODO: `item` type is subclass-specific
def get_item(self, item): ... # TODO: return type?
def set_item(self, item) -> None: ...
def set_item(self, item, val) -> None: ...
def get_na(self): ... # TODO: return type?
def set_na(self, val) -> None: ...
def map_locations(
self,
values: np.ndarray, # np.ndarray[subclass-specific]
Expand Down
99 changes: 86 additions & 13 deletions pandas/_libs/hashtable_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'),

cdef class {{name}}HashTable(HashTable):

def __cinit__(self, int64_t size_hint=1):
def __cinit__(self, int64_t size_hint=1, bint uses_mask=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be reflected in the pyi file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, added

self.table = kh_init_{{dtype}}()
size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT)
kh_resize_{{dtype}}(self.table, size_hint)

self.uses_mask = uses_mask
self.na_position = -1

def __len__(self) -> int:
    # Number of keys held in the khash table, plus one when an NA position
    # has been recorded (-1 is the "no NA stored" sentinel).
    return self.table.size + (0 if self.na_position == -1 else 1)

def __dealloc__(self):
    # Nothing to release if the table pointer was never set or was
    # already cleared.
    if self.table is NULL:
        return
    kh_destroy_{{dtype}}(self.table)
    # Clear the pointer so the destroyed table cannot be referenced again.
    self.table = NULL

def __contains__(self, object key) -> bool:
# The caller is responsible to check for compatible NA values in case
# of masked arrays.
cdef:
khiter_t k
{{c_type}} ckey

if self.uses_mask and checknull(key):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it matter what type of null we're looking at?

Copy link
Member Author

@phofl phofl Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends if we want to support multiple nulls in masked arrays. Currently it does not:

idx = Index([1, 2, pd.NA, pd.NA], dtype="Int64")
idx.get_loc(None)
idx.get_loc(np.nan)

Both match pd.NA. Same for Float64

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None and np.nan we usually treat as interchangeable with pd.NA. Here I'm thinking more of pd.NaT.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got you. This is interesting currently:

idx.get_loc(pd.NaT)

matches both NAs, but

Index([1, 2, pd.NA, pd.NA, pd.NaT], dtype="Int64")

raises. Any suggestions here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_valid_na_for_dtype exists for this purpose. For other dtypes we do this in FooIndex.get_loc. Might require some gymnastics in this case

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I would be ok with pushing responsibility to the caller. Have to do this for set_item and get_item anyway

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely reasonable; please be sure to document this assumption.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment. I'll probably add some documentation in general when we are through with a MaskedEngine

return -1 != self.na_position

ckey = {{to_c_type}}(key)
k = kh_get_{{dtype}}(self.table, ckey)
return k != self.table.n_buckets
Expand All @@ -435,30 +444,73 @@ cdef class {{name}}HashTable(HashTable):
}

cpdef get_item(self, {{dtype}}_t val):
"""Extracts the position of val from the hashtable.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the return be typed as int, Py_ssize_t, or something similar? Same question for get_na below.


Parameters
----------
val : Scalar
The value that is looked up in the hashtable

Returns
-------
The position of the requested integer.
"""

# Used in core.sorting, IndexEngine.get_loc
# Caller is responsible for checking for pd.NA
cdef:
khiter_t k
{{c_type}} cval

cval = {{to_c_type}}(val)
k = kh_get_{{dtype}}(self.table, cval)
if k != self.table.n_buckets:
return self.table.vals[k]
else:
raise KeyError(val)

cpdef get_na(self):
    """Return the position stored for the NA entry.

    Returns
    -------
    The value currently held in ``na_position``.

    Raises
    ------
    NotImplementedError
        If the table does not use a mask.
    KeyError
        If no NA position has been recorded (``na_position`` is -1).
    """
    if self.uses_mask:
        if self.na_position != -1:
            return self.na_position
        raise KeyError("NA")
    raise NotImplementedError

cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val):
    """Store ``val`` as the position associated with ``key``.

    Raises
    ------
    KeyError
        If the slot returned by the insert is not marked as occupied.
    """
    # Used in libjoin
    # Caller is responsible for checking for pd.NA
    cdef:
        khiter_t slot
        int put_status = 0
        {{c_type}} c_key

    c_key = {{to_c_type}}(key)
    slot = kh_put_{{dtype}}(self.table, c_key, &put_status)
    if not kh_exist_{{dtype}}(self.table, slot):
        raise KeyError(key)
    self.table.vals[slot] = val

cpdef set_na(self, Py_ssize_t val):
    """Record ``val`` as the position of the NA entry.

    Parameters
    ----------
    val : Py_ssize_t
        Position to store in ``na_position``.

    Raises
    ------
    NotImplementedError
        If the table does not use a mask.
    """
    # Caller is responsible for checking for pd.NA
    # No khash access is needed here: NA is tracked only through
    # ``na_position``, so no iterator/key locals are declared.
    if not self.uses_mask:
        raise NotImplementedError

    self.na_position = val

{{if dtype == "int64" }}
# We only use this for int64, can reduce build size and make .pyi
# more accurate by only implementing it for int64
Expand All @@ -480,22 +532,36 @@ cdef class {{name}}HashTable(HashTable):
{{endif}}

@cython.boundscheck(False)
def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None:
    """Record the position of each element of ``values`` in the table.

    Parameters
    ----------
    values : memoryview of {{dtype}}
        Keys to insert; each key is stored together with its index
        in ``values``.
    mask : memoryview of uint8, optional
        Must be passed when the table uses a mask. Positions where
        ``mask`` is truthy are not inserted as keys; instead their index
        becomes the stored NA position (the last such index wins).
        Ignored when the table does not use a mask.

    Raises
    ------
    NotImplementedError
        If the table uses a mask but ``mask`` is None.
    """
    # Used in libindex, safe_sort
    cdef:
        Py_ssize_t i, n = len(values)
        int ret = 0
        {{c_type}} val
        khiter_t k
        # Must be as wide as the ``na_position`` attribute (int64_t):
        # an int8_t local wraps around for NA positions above 127.
        int64_t na_position = self.na_position

    if self.uses_mask and mask is None:
        raise NotImplementedError  # pragma: no cover

    with nogil:
        if self.uses_mask:
            for i in range(n):
                if mask[i]:
                    na_position = i
                else:
                    val = {{to_c_type}}(values[i])
                    k = kh_put_{{dtype}}(self.table, val, &ret)
                    self.table.vals[k] = i
        else:
            for i in range(n):
                val = {{to_c_type}}(values[i])
                k = kh_put_{{dtype}}(self.table, val, &ret)
                self.table.vals[k] = i
    # Written back outside the nogil section, after the loop has finished.
    self.na_position = na_position

@cython.boundscheck(False)
def lookup(self, const {{dtype}}_t[:] values) -> ndarray:
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray:
# -> np.ndarray[np.intp]
# Used in safe_sort, IndexEngine.get_indexer
cdef:
Expand All @@ -504,15 +570,22 @@ cdef class {{name}}HashTable(HashTable):
{{c_type}} val
khiter_t k
intp_t[::1] locs = np.empty(n, dtype=np.intp)
int8_t na_position = self.na_position

if self.uses_mask and mask is None:
raise NotImplementedError # pragma: no cover

with nogil:
for i in range(n):
val = {{to_c_type}}(values[i])
k = kh_get_{{dtype}}(self.table, val)
if k != self.table.n_buckets:
locs[i] = self.table.vals[k]
if self.uses_mask and mask[i]:
locs[i] = na_position
else:
locs[i] = -1
val = {{to_c_type}}(values[i])
k = kh_get_{{dtype}}(self.table, val)
if k != self.table.n_buckets:
locs[i] = self.table.vals[k]
else:
locs[i] = -1

return np.asarray(locs)

Expand Down
Loading