Skip to content

Commit e745be0

Browse files
jbrockmendel authored and jreback committed
BUG: strengthen typing in get_c_string, fix StringHashTable segfault (#30419)
1 parent ccbe7be commit e745be0

File tree

4 files changed

+24
-14
lines changed

4 files changed

+24
-14
lines changed

pandas/_libs/hashtable.pxd

+2 -2
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,8 @@ cdef class PyObjectHashTable(HashTable):
3636
cdef class StringHashTable(HashTable):
3737
cdef kh_str_t *table
3838

39-
cpdef get_item(self, object val)
40-
cpdef set_item(self, object key, Py_ssize_t val)
39+
cpdef get_item(self, str val)
40+
cpdef set_item(self, str key, Py_ssize_t val)
4141

4242
cdef struct Int64VectorData:
4343
int64_t *data

pandas/_libs/hashtable_class_helper.pxi.in

+5 -5
Original file line number | Diff line number | Diff line change
@@ -599,7 +599,7 @@ cdef class StringHashTable(HashTable):
599599
sizeof(Py_ssize_t) + # vals
600600
sizeof(uint32_t)) # flags
601601

602-
cpdef get_item(self, object val):
602+
cpdef get_item(self, str val):
603603
cdef:
604604
khiter_t k
605605
const char *v
@@ -611,16 +611,16 @@ cdef class StringHashTable(HashTable):
611611
else:
612612
raise KeyError(val)
613613

614-
cpdef set_item(self, object key, Py_ssize_t val):
614+
cpdef set_item(self, str key, Py_ssize_t val):
615615
cdef:
616616
khiter_t k
617617
int ret = 0
618618
const char *v
619619

620-
v = get_c_string(val)
620+
v = get_c_string(key)
621621

622622
k = kh_put_str(self.table, v, &ret)
623-
self.table.keys[k] = key
623+
self.table.keys[k] = v
624624
if kh_exist_str(self.table, k):
625625
self.table.vals[k] = val
626626
else:
@@ -784,7 +784,7 @@ cdef class StringHashTable(HashTable):
784784
labels[i] = na_sentinel
785785
else:
786786
# if ignore_na is False, we also stringify NaN/None/etc.
787-
v = get_c_string(val)
787+
v = get_c_string(<str>val)
788788
vecs[i] = v
789789

790790
# compute

pandas/_libs/tslibs/util.pxd

+4 -7
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ cdef inline bint is_nan(object val):
218218
return is_complex_object(val) and val != val
219219

220220

221-
cdef inline const char* get_c_string_buf_and_size(object py_string,
221+
cdef inline const char* get_c_string_buf_and_size(str py_string,
222222
Py_ssize_t *length):
223223
"""
224224
Extract internal char* buffer of unicode or bytes object `py_string` with
@@ -231,7 +231,7 @@ cdef inline const char* get_c_string_buf_and_size(object py_string,
231231
232232
Parameters
233233
----------
234-
py_string : object
234+
py_string : str
235235
length : Py_ssize_t*
236236
237237
Returns
@@ -241,12 +241,9 @@ cdef inline const char* get_c_string_buf_and_size(object py_string,
241241
cdef:
242242
const char *buf
243243

244-
if PyUnicode_Check(py_string):
245-
buf = PyUnicode_AsUTF8AndSize(py_string, length)
246-
else:
247-
PyBytes_AsStringAndSize(py_string, <char**>&buf, length)
244+
buf = PyUnicode_AsUTF8AndSize(py_string, length)
248245
return buf
249246

250247

251-
cdef inline const char* get_c_string(object py_string):
248+
cdef inline const char* get_c_string(str py_string):
252249
return get_c_string_buf_and_size(py_string, NULL)

pandas/tests/test_algos.py

+13
Original file line number | Diff line number | Diff line change
@@ -1402,6 +1402,19 @@ class TestGroupVarFloat32(GroupVarTestMixin):
14021402

14031403

14041404
class TestHashTable:
1405+
def test_string_hashtable_set_item_signature(self):
1406+
# GH#30419 fix typing in StringHashTable.set_item to prevent segfault
1407+
tbl = ht.StringHashTable()
1408+
1409+
tbl.set_item("key", 1)
1410+
assert tbl.get_item("key") == 1
1411+
1412+
with pytest.raises(TypeError, match="'key' has incorrect type"):
1413+
# key arg typed as string, not object
1414+
tbl.set_item(4, 6)
1415+
with pytest.raises(TypeError, match="'val' has incorrect type"):
1416+
tbl.get_item(4)
1417+
14051418
def test_lookup_nan(self, writable):
14061419
xs = np.array([2.718, 3.14, np.nan, -7, 5, 2, 3])
14071420
# GH 21688 ensure we can deal with readonly memory views

0 commit comments

Comments
 (0)