diff --git a/pandas/_libs/hashtable.pxd b/pandas/_libs/hashtable.pxd index 51ec4ba43159c..0499eabf708af 100644 --- a/pandas/_libs/hashtable.pxd +++ b/pandas/_libs/hashtable.pxd @@ -36,8 +36,8 @@ cdef class PyObjectHashTable(HashTable): cdef class StringHashTable(HashTable): cdef kh_str_t *table - cpdef get_item(self, object val) - cpdef set_item(self, object key, Py_ssize_t val) + cpdef get_item(self, str val) + cpdef set_item(self, str key, Py_ssize_t val) cdef struct Int64VectorData: int64_t *data diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index b207fcb66948d..7d57c67e70b58 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -599,7 +599,7 @@ cdef class StringHashTable(HashTable): sizeof(Py_ssize_t) + # vals sizeof(uint32_t)) # flags - cpdef get_item(self, object val): + cpdef get_item(self, str val): cdef: khiter_t k const char *v @@ -611,16 +611,16 @@ cdef class StringHashTable(HashTable): else: raise KeyError(val) - cpdef set_item(self, object key, Py_ssize_t val): + cpdef set_item(self, str key, Py_ssize_t val): cdef: khiter_t k int ret = 0 const char *v - v = get_c_string(val) + v = get_c_string(key) k = kh_put_str(self.table, v, &ret) - self.table.keys[k] = key + self.table.keys[k] = v if kh_exist_str(self.table, k): self.table.vals[k] = val else: @@ -784,7 +784,7 @@ cdef class StringHashTable(HashTable): labels[i] = na_sentinel else: # if ignore_na is False, we also stringify NaN/None/etc. - v = get_c_string(val) + v = get_c_string(val) vecs[i] = v # compute diff --git a/pandas/_libs/tslibs/util.pxd b/pandas/_libs/tslibs/util.pxd index 63cbd36f9cd1d..936532a81c6d6 100644 --- a/pandas/_libs/tslibs/util.pxd +++ b/pandas/_libs/tslibs/util.pxd @@ -218,7 +218,7 @@ cdef inline bint is_nan(object val): return is_complex_object(val) and val != val -cdef inline const char* get_c_string_buf_and_size(object py_string, +cdef inline const char* get_c_string_buf_and_size(str py_string, Py_ssize_t *length): """ Extract internal char* buffer of unicode or bytes object `py_string` with @@ -231,7 +231,7 @@ cdef inline const char* get_c_string_buf_and_size(object py_string, Parameters ---------- - py_string : object + py_string : str length : Py_ssize_t* Returns @@ -241,12 +241,9 @@ cdef inline const char* get_c_string_buf_and_size(object py_string, cdef: const char *buf - if PyUnicode_Check(py_string): - buf = PyUnicode_AsUTF8AndSize(py_string, length) - else: - PyBytes_AsStringAndSize(py_string, &buf, length) + buf = PyUnicode_AsUTF8AndSize(py_string, length) return buf -cdef inline const char* get_c_string(object py_string): +cdef inline const char* get_c_string(str py_string): return get_c_string_buf_and_size(py_string, NULL) diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index e0e4beffe113a..82f647c9385b2 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -1402,6 +1402,19 @@ class TestGroupVarFloat32(GroupVarTestMixin): class TestHashTable: + def test_string_hashtable_set_item_signature(self): + # GH#30419 fix typing in StringHashTable.set_item to prevent segfault + tbl = ht.StringHashTable() + + tbl.set_item("key", 1) + assert tbl.get_item("key") == 1 + + with pytest.raises(TypeError, match="'key' has incorrect type"): + # key arg typed as string, not object + tbl.set_item(4, 6) + with pytest.raises(TypeError, match="'val' has incorrect type"): + tbl.get_item(4) + def test_lookup_nan(self, writable): xs = np.array([2.718, 3.14, np.nan, -7, 5, 2, 3]) # GH 21688 ensure we can deal with readonly memory views