From a7f1abf568f0362bdd47fb7bb53818c0b039345a Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 5 Sep 2022 13:14:23 +0200 Subject: [PATCH 1/7] ENH: mask support for hastable functions for indexing --- pandas/_libs/hashtable.pxd | 72 ++++++++++++------ pandas/_libs/hashtable.pyi | 4 +- pandas/_libs/hashtable_class_helper.pxi.in | 85 +++++++++++++++++----- pandas/tests/libs/test_hashtable.py | 77 ++++++++++++++++++++ 4 files changed, 192 insertions(+), 46 deletions(-) diff --git a/pandas/_libs/hashtable.pxd b/pandas/_libs/hashtable.pxd index 80d7ab58dc559..0d7f88d38a827 100644 --- a/pandas/_libs/hashtable.pxd +++ b/pandas/_libs/hashtable.pxd @@ -41,75 +41,99 @@ cdef class HashTable: cdef class UInt64HashTable(HashTable): cdef kh_uint64_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, uint64_t val) - cpdef set_item(self, uint64_t key, Py_ssize_t val) + cpdef get_item(self, uint64_t val, bint na_value=*) + cpdef set_item(self, uint64_t key, Py_ssize_t val, bint na_value=*) cdef class Int64HashTable(HashTable): cdef kh_int64_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, int64_t val) - cpdef set_item(self, int64_t key, Py_ssize_t val) + cpdef get_item(self, int64_t val, bint na_value=*) + cpdef set_item(self, int64_t key, Py_ssize_t val, bint na_value=*) cdef class UInt32HashTable(HashTable): cdef kh_uint32_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, uint32_t val) - cpdef set_item(self, uint32_t key, Py_ssize_t val) + cpdef get_item(self, uint32_t val, bint na_value=*) + cpdef set_item(self, uint32_t key, Py_ssize_t val, bint na_value=*) cdef class Int32HashTable(HashTable): cdef kh_int32_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, int32_t val) - cpdef set_item(self, int32_t key, Py_ssize_t val) + cpdef get_item(self, int32_t val, bint na_value=*) + cpdef set_item(self, int32_t key, Py_ssize_t val, bint na_value=*) cdef class UInt16HashTable(HashTable): cdef kh_uint16_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, uint16_t val) - cpdef set_item(self, uint16_t key, Py_ssize_t val) + cpdef get_item(self, uint16_t val, bint na_value=*) + cpdef set_item(self, uint16_t key, Py_ssize_t val, bint na_value=*) cdef class Int16HashTable(HashTable): cdef kh_int16_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, int16_t val) - cpdef set_item(self, int16_t key, Py_ssize_t val) + cpdef get_item(self, int16_t val, bint na_value=*) + cpdef set_item(self, int16_t key, Py_ssize_t val, bint na_value=*) cdef class UInt8HashTable(HashTable): cdef kh_uint8_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, uint8_t val) - cpdef set_item(self, uint8_t key, Py_ssize_t val) + cpdef get_item(self, uint8_t val, bint na_value=*) + cpdef set_item(self, uint8_t key, Py_ssize_t val, bint na_value=*) cdef class Int8HashTable(HashTable): cdef kh_int8_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, int8_t val) - cpdef set_item(self, int8_t key, Py_ssize_t val) + cpdef get_item(self, int8_t val, bint na_value=*) + cpdef set_item(self, int8_t key, Py_ssize_t val, bint na_value=*) cdef class Float64HashTable(HashTable): cdef kh_float64_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, float64_t val) - cpdef set_item(self, float64_t key, Py_ssize_t val) + cpdef get_item(self, float64_t val, bint na_value=*) + cpdef set_item(self, float64_t key, Py_ssize_t val, bint na_value=*) cdef class Float32HashTable(HashTable): cdef kh_float32_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, float32_t val) - cpdef set_item(self, float32_t key, Py_ssize_t val) + cpdef get_item(self, float32_t val, bint na_value=*) + cpdef set_item(self, float32_t key, Py_ssize_t val, bint na_value=*) cdef class Complex64HashTable(HashTable): cdef kh_complex64_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, complex64_t val) - cpdef set_item(self, complex64_t key, Py_ssize_t val) + cpdef get_item(self, complex64_t val, bint na_value=*) + cpdef set_item(self, complex64_t key, Py_ssize_t val, bint na_value=*) cdef class Complex128HashTable(HashTable): cdef kh_complex128_t *table + cdef int64_t na_position + cdef bint uses_mask - cpdef get_item(self, complex128_t val) - cpdef set_item(self, complex128_t key, Py_ssize_t val) + cpdef get_item(self, complex128_t val, bint na_value=*) + cpdef set_item(self, complex128_t key, Py_ssize_t val, bint na_value=*) cdef class PyObjectHashTable(HashTable): cdef kh_pymap_t *table diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi index 8500fdf2f602e..1e415edf1b9ad 100644 --- a/pandas/_libs/hashtable.pyi +++ b/pandas/_libs/hashtable.pyi @@ -118,8 +118,8 @@ class HashTable: def sizeof(self, deep: bool = ...) -> int: ... def get_state(self) -> dict[str, int]: ... # TODO: `item` type is subclass-specific - def get_item(self, item): ... # TODO: return type? - def set_item(self, item) -> None: ... + def get_item(self, item, na_value): ... # TODO: return type? + def set_item(self, item, val, na_value) -> None: ... def map_locations( self, values: np.ndarray, # np.ndarray[subclass-specific] diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 54260a9a90964..ec7b994b0b9e3 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -396,13 +396,16 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'), cdef class {{name}}HashTable(HashTable): - def __cinit__(self, int64_t size_hint=1): + def __cinit__(self, int64_t size_hint=1, bint uses_mask=False): self.table = kh_init_{{dtype}}() size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT) kh_resize_{{dtype}}(self.table, size_hint) + self.uses_mask = uses_mask + self.na_position = -1 + def __len__(self) -> int: - return self.table.size + return self.table.size + (0 if self.na_position == -1 else 1) def __dealloc__(self): if self.table is not NULL: @@ -413,6 +416,10 @@ cdef class {{name}}HashTable(HashTable): cdef: khiter_t k {{c_type}} ckey + + if self.uses_mask and checknull(key): + return -1 != self.na_position + ckey = {{to_c_type}}(key) k = kh_get_{{dtype}}(self.table, ckey) return k != self.table.n_buckets @@ -434,11 +441,20 @@ cdef class {{name}}HashTable(HashTable): 'upper_bound' : self.table.upper_bound, } - cpdef get_item(self, {{dtype}}_t val): + cpdef get_item(self, {{dtype}}_t val, bint na_value = False): # Used in core.sorting, IndexEngine.get_loc cdef: khiter_t k {{c_type}} cval + + if not self.uses_mask and na_value: + raise NotImplementedError + + if na_value: + if self.na_position == -1: + raise KeyError("NA") + return self.na_position + cval = {{to_c_type}}(val) k = kh_get_{{dtype}}(self.table, cval) if k != self.table.n_buckets: @@ -446,18 +462,26 @@ cdef class {{name}}HashTable(HashTable): else: raise KeyError(val) - cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val): + cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val, bint na_value = False): # Used in libjoin cdef: khiter_t k int ret = 0 {{c_type}} ckey - ckey = {{to_c_type}}(key) - k = kh_put_{{dtype}}(self.table, ckey, &ret) - if kh_exist_{{dtype}}(self.table, k): - self.table.vals[k] = val + + if na_value and not self.uses_mask: + raise NotImplementedError + + if na_value: + self.na_position = val + else: - raise KeyError(key) + ckey = {{to_c_type}}(key) + k = kh_put_{{dtype}}(self.table, ckey, &ret) + if kh_exist_{{dtype}}(self.table, k): + self.table.vals[k] = val + else: + raise KeyError(key) {{if dtype == "int64" }} # We only use this for int64, can reduce build size and make .pyi @@ -480,22 +504,36 @@ cdef class {{name}}HashTable(HashTable): {{endif}} @cython.boundscheck(False) - def map_locations(self, const {{dtype}}_t[:] values) -> None: + def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None: # Used in libindex, safe_sort cdef: Py_ssize_t i, n = len(values) int ret = 0 {{c_type}} val khiter_t k + int8_t na_position = self.na_position + + if self.uses_mask and mask is None: + raise NotImplementedError # pragma: no cover with nogil: - for i in range(n): - val= {{to_c_type}}(values[i]) - k = kh_put_{{dtype}}(self.table, val, &ret) - self.table.vals[k] = i + if self.uses_mask: + for i in range(n): + if mask[i]: + na_position = i + else: + val= {{to_c_type}}(values[i]) + k = kh_put_{{dtype}}(self.table, val, &ret) + self.table.vals[k] = i + else: + for i in range(n): + val= {{to_c_type}}(values[i]) + k = kh_put_{{dtype}}(self.table, val, &ret) + self.table.vals[k] = i + self.na_position = na_position @cython.boundscheck(False) - def lookup(self, const {{dtype}}_t[:] values) -> ndarray: + def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray: # -> np.ndarray[np.intp] # Used in safe_sort, IndexEngine.get_indexer cdef: @@ -504,15 +542,22 @@ cdef class {{name}}HashTable(HashTable): {{c_type}} val khiter_t k intp_t[::1] locs = np.empty(n, dtype=np.intp) + int8_t na_position = self.na_position + + if self.uses_mask and mask is None: + raise NotImplementedError # pragma: no cover with nogil: for i in range(n): - val = {{to_c_type}}(values[i]) - k = kh_get_{{dtype}}(self.table, val) - if k != self.table.n_buckets: - locs[i] = self.table.vals[k] + if self.uses_mask and mask[i]: + locs[i] = na_position else: - locs[i] = -1 + val = {{to_c_type}}(values[i]) + k = kh_get_{{dtype}}(self.table, val) + if k != self.table.n_buckets: + locs[i] = self.table.vals[k] + else: + locs[i] = -1 return np.asarray(locs) diff --git a/pandas/tests/libs/test_hashtable.py b/pandas/tests/libs/test_hashtable.py index 0a3a10315b5fd..c098a39b52ca4 100644 --- a/pandas/tests/libs/test_hashtable.py +++ b/pandas/tests/libs/test_hashtable.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +import re import struct import tracemalloc @@ -75,6 +76,48 @@ def test_get_set_contains_len(self, table_type, dtype): assert table.get_item(index + 1) == 41 assert index + 2 not in table + table.set_item(index + 1, 21) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 21 + assert table.get_item(index + 1) == 21 + + with pytest.raises(KeyError, match=str(index + 2)): + table.get_item(index + 2) + + def test_get_set_contains_len_mask(self, table_type, dtype): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supporter for object") + index = 5 + table = table_type(55, uses_mask=True) + assert len(table) == 0 + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + with pytest.raises(KeyError, match="NA"): + table.get_item(0, na_value=True) + + table.set_item(index + 1, 41) + table.set_item(0, 41, na_value=True) + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index) == 42 + assert table.get_item(index + 1) == 41 + assert table.get_item(1, na_value=True) == 41 + + table.set_item(0, 21, na_value=True) + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index + 1) == 41 + assert table.get_item(1, na_value=True) == 21 + assert index + 2 not in table + with pytest.raises(KeyError, match=str(index + 2)): table.get_item(index + 2) @@ -100,6 +143,22 @@ def test_map_locations(self, table_type, dtype, writable): for i in range(N): assert table.get_item(keys[i]) == i + def test_map_locations_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supporter for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys, np.array([False, False, True])) + for i in range(N - 1): + assert table.get_item(keys[i]) == i + + with pytest.raises(KeyError, match=re.escape(str(keys[N - 1]))): + table.get_item(keys[N - 1]) + + assert table.get_item(keys[N - 1], na_value=True) == 2 + def test_lookup(self, table_type, dtype, writable): N = 3 table = table_type() @@ -122,6 +181,24 @@ def test_lookup_wrong(self, table_type, dtype): result = table.lookup(wrong_keys) assert np.all(result == -1) + def test_lookup_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supporter for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + mask = np.array([False, True, False]) + keys.flags.writeable = writable + table.map_locations(keys, mask) + result = table.lookup(keys, mask) + expected = np.arange(N) + tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64)) + + result = table.lookup(np.array([1 + N]).astype(dtype), np.array([False])) + tm.assert_numpy_array_equal( + result.astype(np.int64), np.array([-1], dtype=np.int64) + ) + def test_unique(self, table_type, dtype, writable): if dtype in (np.int8, np.uint8): N = 88 From b216f6cb73c30071380a7fa0736155f815e4aa65 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 5 Sep 2022 16:27:21 +0200 Subject: [PATCH 2/7] Fix mypy --- pandas/_libs/hashtable.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi index 1e415edf1b9ad..fd61347f39345 100644 --- a/pandas/_libs/hashtable.pyi +++ b/pandas/_libs/hashtable.pyi @@ -110,7 +110,7 @@ class ObjectVector: class HashTable: # NB: The base HashTable class does _not_ actually have these methods; - # we are putting the here for the sake of mypy to avoid + # we are putting them here for the sake of mypy to avoid # reproducing them in each subclass below. def __init__(self, size_hint: int = ...): ... def __len__(self) -> int: ... @@ -118,8 +118,8 @@ class HashTable: def sizeof(self, deep: bool = ...) -> int: ... def get_state(self) -> dict[str, int]: ... # TODO: `item` type is subclass-specific - def get_item(self, item, na_value): ... # TODO: return type? - def set_item(self, item, val, na_value) -> None: ... + def get_item(self, item, na_value: bool = ...): ... # TODO: return type? + def set_item(self, item, val, na_value: bool = ...) -> None: ... def map_locations( self, values: np.ndarray, # np.ndarray[subclass-specific] From f276f6f9f8cede68ee83e75e821a1ae6d7604767 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 5 Sep 2022 16:30:16 +0200 Subject: [PATCH 3/7] Adjust test --- pandas/_libs/hashtable_class_helper.pxi.in | 2 ++ pandas/tests/libs/test_hashtable.py | 1 + 2 files changed, 3 insertions(+) diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index ec7b994b0b9e3..089a86258c30b 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -443,6 +443,7 @@ cdef class {{name}}HashTable(HashTable): cpdef get_item(self, {{dtype}}_t val, bint na_value = False): # Used in core.sorting, IndexEngine.get_loc + # Caller is responsible for checking for pd.NA cdef: khiter_t k {{c_type}} cval @@ -464,6 +465,7 @@ cdef class {{name}}HashTable(HashTable): cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val, bint na_value = False): # Used in libjoin + # Caller is responsible for checking for pd.NA cdef: khiter_t k int ret = 0 diff --git a/pandas/tests/libs/test_hashtable.py b/pandas/tests/libs/test_hashtable.py index c098a39b52ca4..a1e9497d82ff7 100644 --- a/pandas/tests/libs/test_hashtable.py +++ b/pandas/tests/libs/test_hashtable.py @@ -103,6 +103,7 @@ def test_get_set_contains_len_mask(self, table_type, dtype): table.set_item(index + 1, 41) table.set_item(0, 41, na_value=True) + assert pd.NA in table assert index in table assert index + 1 in table assert len(table) == 3 From f2c271bce1c8e93dcbb8140c89764c96c1abe5d2 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Fri, 9 Sep 2022 21:45:36 +0200 Subject: [PATCH 4/7] Add comment --- pandas/_libs/hashtable_class_helper.pxi.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 089a86258c30b..33b7ee9b30475 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -413,6 +413,8 @@ cdef class {{name}}HashTable(HashTable): self.table = NULL def __contains__(self, object key) -> bool: + # The caller is responsible to check for compatible NA values in case + # of masked arrays. cdef: khiter_t k {{c_type}} ckey From d08a7ae4362ad93866c8a7a7d4220ffe9ef2c902 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Thu, 29 Sep 2022 10:26:29 +0200 Subject: [PATCH 5/7] Add docstring --- pandas/_libs/hashtable.pyi | 2 +- pandas/_libs/hashtable_class_helper.pxi.in | 14 ++++++++++++++ pandas/tests/libs/test_hashtable.py | 6 +++--- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi index 23a557560f4bc..f40e9c93fd004 100644 --- a/pandas/_libs/hashtable.pyi +++ b/pandas/_libs/hashtable.pyi @@ -112,7 +112,7 @@ class HashTable: # NB: The base HashTable class does _not_ actually have these methods; # we are putting them here for the sake of mypy to avoid # reproducing them in each subclass below. - def __init__(self, size_hint: int = ...) -> None: ... + def __init__(self, size_hint: int = ..., uses_mask: bool = ...) -> None: ... def __len__(self) -> int: ... def __contains__(self, key: Hashable) -> bool: ... def sizeof(self, deep: bool = ...) -> int: ... diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 33b7ee9b30475..618417c4965c8 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -444,6 +444,20 @@ cdef class {{name}}HashTable(HashTable): } cpdef get_item(self, {{dtype}}_t val, bint na_value = False): + """Extracts the position of val or na_value from the hashtable. + + Parameters + ---------- + val : Scalar + The value that is looked up in the hashtable + na_value : bool, default False + Returns posiition of the NA value if using masked dtypes. + + Returns + ------- + The position of the requested integer. + """ + # Used in core.sorting, IndexEngine.get_loc # Caller is responsible for checking for pd.NA cdef: diff --git a/pandas/tests/libs/test_hashtable.py b/pandas/tests/libs/test_hashtable.py index a652b5add0ddf..c3a288e6fd040 100644 --- a/pandas/tests/libs/test_hashtable.py +++ b/pandas/tests/libs/test_hashtable.py @@ -89,7 +89,7 @@ def test_get_set_contains_len(self, table_type, dtype): def test_get_set_contains_len_mask(self, table_type, dtype): if table_type == ht.PyObjectHashTable: - pytest.skip("Mask not supporter for object") + pytest.skip("Mask not supported for object") index = 5 table = table_type(55, uses_mask=True) assert len(table) == 0 @@ -147,7 +147,7 @@ def test_map_locations(self, table_type, dtype, writable): def test_map_locations_mask(self, table_type, dtype, writable): if table_type == ht.PyObjectHashTable: - pytest.skip("Mask not supporter for object") + pytest.skip("Mask not supported for object") N = 3 table = table_type(uses_mask=True) keys = (np.arange(N) + N).astype(dtype) @@ -185,7 +185,7 @@ def test_lookup_wrong(self, table_type, dtype): def test_lookup_mask(self, table_type, dtype, writable): if table_type == ht.PyObjectHashTable: - pytest.skip("Mask not supporter for object") + pytest.skip("Mask not supported for object") N = 3 table = table_type(uses_mask=True) keys = (np.arange(N) + N).astype(dtype) From 4670624bd5f245c39f09637f9d65318c39abba5c Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Fri, 14 Oct 2022 16:07:43 +0200 Subject: [PATCH 6/7] Refactor into own functions --- pandas/_libs/hashtable.pxd | 72 ++++++++++++++-------- pandas/_libs/hashtable.pyi | 6 +- pandas/_libs/hashtable_class_helper.pxi.in | 58 +++++++++-------- pandas/tests/libs/test_hashtable.py | 12 ++-- 4 files changed, 92 insertions(+), 56 deletions(-) diff --git a/pandas/_libs/hashtable.pxd b/pandas/_libs/hashtable.pxd index 0d7f88d38a827..b32bd4880588d 100644 --- a/pandas/_libs/hashtable.pxd +++ b/pandas/_libs/hashtable.pxd @@ -44,96 +44,120 @@ cdef class UInt64HashTable(HashTable): cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, uint64_t val, bint na_value=*) - cpdef set_item(self, uint64_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, uint64_t val) + cpdef set_item(self, uint64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int64HashTable(HashTable): cdef kh_int64_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, int64_t val, bint na_value=*) - cpdef set_item(self, int64_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, int64_t val) + cpdef set_item(self, int64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class UInt32HashTable(HashTable): cdef kh_uint32_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, uint32_t val, bint na_value=*) - cpdef set_item(self, uint32_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, uint32_t val) + cpdef set_item(self, uint32_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int32HashTable(HashTable): cdef kh_int32_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, int32_t val, bint na_value=*) - cpdef set_item(self, int32_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, int32_t val) + cpdef set_item(self, int32_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class UInt16HashTable(HashTable): cdef kh_uint16_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, uint16_t val, bint na_value=*) - cpdef set_item(self, uint16_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, uint16_t val) + cpdef set_item(self, uint16_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int16HashTable(HashTable): cdef kh_int16_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, int16_t val, bint na_value=*) - cpdef set_item(self, int16_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, int16_t val) + cpdef set_item(self, int16_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class UInt8HashTable(HashTable): cdef kh_uint8_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, uint8_t val, bint na_value=*) - cpdef set_item(self, uint8_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, uint8_t val) + cpdef set_item(self, uint8_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int8HashTable(HashTable): cdef kh_int8_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, int8_t val, bint na_value=*) - cpdef set_item(self, int8_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, int8_t val) + cpdef set_item(self, int8_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Float64HashTable(HashTable): cdef kh_float64_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, float64_t val, bint na_value=*) - cpdef set_item(self, float64_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, float64_t val) + cpdef set_item(self, float64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Float32HashTable(HashTable): cdef kh_float32_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, float32_t val, bint na_value=*) - cpdef set_item(self, float32_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, float32_t val) + cpdef set_item(self, float32_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Complex64HashTable(HashTable): cdef kh_complex64_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, complex64_t val, bint na_value=*) - cpdef set_item(self, complex64_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, complex64_t val) + cpdef set_item(self, complex64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Complex128HashTable(HashTable): cdef kh_complex128_t *table cdef int64_t na_position cdef bint uses_mask - cpdef get_item(self, complex128_t val, bint na_value=*) - cpdef set_item(self, complex128_t key, Py_ssize_t val, bint na_value=*) + cpdef get_item(self, complex128_t val) + cpdef set_item(self, complex128_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class PyObjectHashTable(HashTable): cdef kh_pymap_t *table diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi index f40e9c93fd004..e60ccdb29c6b2 100644 --- a/pandas/_libs/hashtable.pyi +++ b/pandas/_libs/hashtable.pyi @@ -118,8 +118,10 @@ class HashTable: def sizeof(self, deep: bool = ...) -> int: ... def get_state(self) -> dict[str, int]: ... # TODO: `item` type is subclass-specific - def get_item(self, item, na_value: bool = ...): ... # TODO: return type? - def set_item(self, item, val, na_value: bool = ...) -> None: ... + def get_item(self, item): ... # TODO: return type? + def set_item(self, item, val) -> None: ... + def get_na(self): ... # TODO: return type? + def set_na(self, val) -> None: ... def map_locations( self, values: np.ndarray, # np.ndarray[subclass-specific] diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 618417c4965c8..c6d8783d6f115 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -443,15 +443,13 @@ cdef class {{name}}HashTable(HashTable): 'upper_bound' : self.table.upper_bound, } - cpdef get_item(self, {{dtype}}_t val, bint na_value = False): - """Extracts the position of val or na_value from the hashtable. + cpdef get_item(self, {{dtype}}_t val): + """Extracts the position of val from the hashtable. Parameters ---------- val : Scalar The value that is looked up in the hashtable - na_value : bool, default False - Returns posiition of the NA value if using masked dtypes. Returns ------- @@ -464,14 +462,6 @@ cdef class {{name}}HashTable(HashTable): khiter_t k {{c_type}} cval - if not self.uses_mask and na_value: - raise NotImplementedError - - if na_value: - if self.na_position == -1: - raise KeyError("NA") - return self.na_position - cval = {{to_c_type}}(val) k = kh_get_{{dtype}}(self.table, cval) if k != self.table.n_buckets: @@ -479,7 +469,22 @@ cdef class {{name}}HashTable(HashTable): else: raise KeyError(val) - cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val, bint na_value = False): + cpdef get_na(self): + """Extracts the position of na_value from the hashtable. + + Returns + ------- + The position of the last na value. + """ + + if not self.uses_mask: + raise NotImplementedError + + if self.na_position == -1: + raise KeyError("NA") + return self.na_position + + cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val): # Used in libjoin # Caller is responsible for checking for pd.NA cdef: @@ -487,19 +492,24 @@ cdef class {{name}}HashTable(HashTable): int ret = 0 {{c_type}} ckey - if na_value and not self.uses_mask: - raise NotImplementedError + ckey = {{to_c_type}}(key) + k = kh_put_{{dtype}}(self.table, ckey, &ret) + if kh_exist_{{dtype}}(self.table, k): + self.table.vals[k] = val + else: + raise KeyError(key) + + cpdef set_na(self, Py_ssize_t val): + # Caller is responsible for checking for pd.NA + cdef: + khiter_t k + int ret = 0 + {{c_type}} ckey - if na_value: - self.na_position = val + if not self.uses_mask: + raise NotImplementedError - else: - ckey = {{to_c_type}}(key) - k = kh_put_{{dtype}}(self.table, ckey, &ret) - if kh_exist_{{dtype}}(self.table, k): - self.table.vals[k] = val - else: - raise KeyError(key) + self.na_position = val {{if dtype == "int64" }} # We only use this for int64, can reduce build size and make .pyi diff --git a/pandas/tests/libs/test_hashtable.py b/pandas/tests/libs/test_hashtable.py index c3a288e6fd040..d9d281a0759da 100644 --- a/pandas/tests/libs/test_hashtable.py +++ b/pandas/tests/libs/test_hashtable.py @@ -100,24 +100,24 @@ def test_get_set_contains_len_mask(self, table_type, dtype): assert index in table assert table.get_item(index) == 42 with pytest.raises(KeyError, match="NA"): - table.get_item(0, na_value=True) + table.get_na() table.set_item(index + 1, 41) - table.set_item(0, 41, na_value=True) + table.set_na(41) assert pd.NA in table assert index in table assert index + 1 in table assert len(table) == 3 assert table.get_item(index) == 42 assert table.get_item(index + 1) == 41 - assert table.get_item(1, na_value=True) == 41 + assert table.get_na() == 41 - table.set_item(0, 21, na_value=True) + table.set_na(21) assert index in table assert index + 1 in table assert len(table) == 3 assert table.get_item(index + 1) == 41 - assert table.get_item(1, na_value=True) == 21 + assert table.get_na() == 21 assert index + 2 not in table with pytest.raises(KeyError, match=str(index + 2)): @@ -159,7 +159,7 @@ def test_map_locations_mask(self, table_type, dtype, writable): with pytest.raises(KeyError, match=re.escape(str(keys[N - 1]))): table.get_item(keys[N - 1]) - assert table.get_item(keys[N - 1], na_value=True) == 2 + assert table.get_na() == 2 def test_lookup(self, table_type, dtype, writable): N = 3 From 4642852a227a6f732de9189ce5b006aee43e5c58 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 17 Oct 2022 22:38:41 +0100 Subject: [PATCH 7/7] Fix typing --- scripts/run_stubtest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/run_stubtest.py b/scripts/run_stubtest.py index d90f8575234e8..db7a327f231b5 100644 --- a/scripts/run_stubtest.py +++ b/scripts/run_stubtest.py @@ -36,10 +36,12 @@ "pandas._libs.hashtable.HashTable.factorize", "pandas._libs.hashtable.HashTable.get_item", "pandas._libs.hashtable.HashTable.get_labels", + "pandas._libs.hashtable.HashTable.get_na", "pandas._libs.hashtable.HashTable.get_state", "pandas._libs.hashtable.HashTable.lookup", "pandas._libs.hashtable.HashTable.map_locations", "pandas._libs.hashtable.HashTable.set_item", + "pandas._libs.hashtable.HashTable.set_na", "pandas._libs.hashtable.HashTable.sizeof", "pandas._libs.hashtable.HashTable.unique", # stubtest might be too sensitive