Skip to content

Commit 6afe9ff

Browse files
mattipstangirala
authored andcommitted
COMPAT/TEST test, fix for unsafe Vector.resize(), which allows refche… (pandas-dev#16258)
* COMPAT/TEST test, fix for unsafe Vector.resize(), which allows refcheck=False * COMPAT/TEST improve error msg, document test as per review * COMPAT/TEST unify interfaces as per review
1 parent f1b03f6 commit 6afe9ff

File tree

4 files changed

+80
-24
lines changed

4 files changed

+80
-24
lines changed

pandas/_libs/hashtable.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ cdef struct Int64VectorData:
5252
cdef class Int64Vector:
5353
cdef Int64VectorData *data
5454
cdef ndarray ao
55+
cdef bint external_view_exists
5556

5657
cdef resize(self)
5758
cpdef to_array(self)

pandas/_libs/hashtable.pyx

+13
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ cdef class Factorizer:
6464
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
6565
array([ 0, 1, 20])
6666
"""
67+
if self.uniques.external_view_exists:
68+
uniques = ObjectVector()
69+
uniques.extend(self.uniques.to_array())
70+
self.uniques = uniques
6771
labels = self.table.get_labels(values, self.uniques,
6872
self.count, na_sentinel, check_null)
6973
mask = (labels == na_sentinel)
@@ -99,6 +103,15 @@ cdef class Int64Factorizer:
99103

100104
def factorize(self, int64_t[:] values, sort=False,
101105
na_sentinel=-1, check_null=True):
106+
"""
107+
Factorize values with nans replaced by na_sentinel
108+
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
109+
array([ 0, 1, 20])
110+
"""
111+
if self.uniques.external_view_exists:
112+
uniques = Int64Vector()
113+
uniques.extend(self.uniques.to_array())
114+
self.uniques = uniques
102115
labels = self.table.get_labels(values, self.uniques,
103116
self.count, na_sentinel,
104117
check_null)

pandas/_libs/hashtable_class_helper.pxi.in

+39-7
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ cdef class {{name}}Vector:
7171

7272
{{if dtype != 'int64'}}
7373
cdef:
74+
bint external_view_exists
7475
{{name}}VectorData *data
7576
ndarray ao
7677
{{endif}}
@@ -80,14 +81,15 @@ cdef class {{name}}Vector:
8081
sizeof({{name}}VectorData))
8182
if not self.data:
8283
raise MemoryError()
84+
self.external_view_exists = False
8385
self.data.n = 0
8486
self.data.m = _INIT_VEC_CAP
8587
self.ao = np.empty(self.data.m, dtype={{idtype}})
8688
self.data.data = <{{arg}}*> self.ao.data
8789

8890
cdef resize(self):
8991
self.data.m = max(self.data.m * 4, _INIT_VEC_CAP)
90-
self.ao.resize(self.data.m)
92+
self.ao.resize(self.data.m, refcheck=False)
9193
self.data.data = <{{arg}}*> self.ao.data
9294

9395
def __dealloc__(self):
@@ -99,13 +101,20 @@ cdef class {{name}}Vector:
99101
return self.data.n
100102

101103
cpdef to_array(self):
102-
self.ao.resize(self.data.n)
103-
self.data.m = self.data.n
104+
if self.data.m != self.data.n:
105+
if self.external_view_exists:
106+
# should never happen
107+
raise ValueError("should have raised on append()")
108+
self.ao.resize(self.data.n, refcheck=False)
109+
self.data.m = self.data.n
110+
self.external_view_exists = True
104111
return self.ao
105112

106113
cdef inline void append(self, {{arg}} x):
107114

108115
if needs_resize(self.data):
116+
if self.external_view_exists:
117+
raise ValueError("external reference but Vector.resize() needed")
109118
self.resize()
110119

111120
append_data_{{dtype}}(self.data, x)
@@ -120,15 +129,19 @@ cdef class StringVector:
120129

121130
cdef:
122131
StringVectorData *data
132+
bint external_view_exists
123133

124134
def __cinit__(self):
125135
self.data = <StringVectorData *>PyMem_Malloc(
126136
sizeof(StringVectorData))
127137
if not self.data:
128138
raise MemoryError()
139+
self.external_view_exists = False
129140
self.data.n = 0
130141
self.data.m = _INIT_VEC_CAP
131142
self.data.data = <char **> malloc(self.data.m * sizeof(char *))
143+
if not self.data.data:
144+
raise MemoryError()
132145

133146
cdef resize(self):
134147
cdef:
@@ -138,9 +151,10 @@ cdef class StringVector:
138151
m = self.data.m
139152
self.data.m = max(self.data.m * 4, _INIT_VEC_CAP)
140153

141-
# TODO: can resize?
142154
orig_data = self.data.data
143155
self.data.data = <char **> malloc(self.data.m * sizeof(char *))
156+
if not self.data.data:
157+
raise MemoryError()
144158
for i in range(m):
145159
self.data.data[i] = orig_data[i]
146160

@@ -164,6 +178,7 @@ cdef class StringVector:
164178
for i in range(self.data.n):
165179
val = self.data.data[i]
166180
ao[i] = val
181+
self.external_view_exists = True
167182
self.data.m = self.data.n
168183
return ao
169184

@@ -174,15 +189,20 @@ cdef class StringVector:
174189

175190
append_data_string(self.data, x)
176191

192+
cdef extend(self, ndarray[:] x):
193+
for i in range(len(x)):
194+
self.append(x[i])
177195

178196
cdef class ObjectVector:
179197

180198
cdef:
181199
PyObject **data
182200
size_t n, m
183201
ndarray ao
202+
bint external_view_exists
184203

185204
def __cinit__(self):
205+
self.external_view_exists = False
186206
self.n = 0
187207
self.m = _INIT_VEC_CAP
188208
self.ao = np.empty(_INIT_VEC_CAP, dtype=object)
@@ -193,19 +213,28 @@ cdef class ObjectVector:
193213

194214
cdef inline append(self, object o):
195215
if self.n == self.m:
216+
if self.external_view_exists:
217+
raise ValueError("external reference but Vector.resize() needed")
196218
self.m = max(self.m * 2, _INIT_VEC_CAP)
197-
self.ao.resize(self.m)
219+
self.ao.resize(self.m, refcheck=False)
198220
self.data = <PyObject**> self.ao.data
199221

200222
Py_INCREF(o)
201223
self.data[self.n] = <PyObject*> o
202224
self.n += 1
203225

204226
def to_array(self):
205-
self.ao.resize(self.n)
206-
self.m = self.n
227+
if self.m != self.n:
228+
if self.external_view_exists:
229+
raise ValueError("should have raised on append()")
230+
self.ao.resize(self.n, refcheck=False)
231+
self.m = self.n
232+
self.external_view_exists = True
207233
return self.ao
208234

235+
cdef extend(self, ndarray[:] x):
236+
for i in range(len(x)):
237+
self.append(x[i])
209238

210239
#----------------------------------------------------------------------
211240
# HashTable
@@ -362,6 +391,9 @@ cdef class {{name}}HashTable(HashTable):
362391

363392
if needs_resize(ud):
364393
with gil:
394+
if uniques.external_view_exists:
395+
raise ValueError("external reference to uniques held, "
396+
"but Vector.resize() needed")
365397
uniques.resize()
366398
append_data_{{dtype}}(ud, val)
367399
labels[i] = count

pandas/tests/test_algos.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from pandas import compat
1616
from pandas._libs import (groupby as libgroupby, algos as libalgos,
17-
hashtable)
17+
hashtable as ht)
1818
from pandas._libs.hashtable import unique_label_indices
1919
from pandas.compat import lrange, range
2020
import pandas.core.algorithms as algos
@@ -259,7 +259,7 @@ def test_factorize_nan(self):
259259
# rizer.factorize should not raise an exception if na_sentinel indexes
260260
# outside of reverse_indexer
261261
key = np.array([1, 2, 1, np.nan], dtype='O')
262-
rizer = hashtable.Factorizer(len(key))
262+
rizer = ht.Factorizer(len(key))
263263
for na_sentinel in (-1, 20):
264264
ids = rizer.factorize(key, sort=True, na_sentinel=na_sentinel)
265265
expected = np.array([0, 1, 0, na_sentinel], dtype='int32')
@@ -1049,14 +1049,14 @@ class TestHashTable(object):
10491049

10501050
def test_lookup_nan(self):
10511051
xs = np.array([2.718, 3.14, np.nan, -7, 5, 2, 3])
1052-
m = hashtable.Float64HashTable()
1052+
m = ht.Float64HashTable()
10531053
m.map_locations(xs)
10541054
tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs),
10551055
dtype=np.int64))
10561056

10571057
def test_lookup_overflow(self):
10581058
xs = np.array([1, 2, 2**63], dtype=np.uint64)
1059-
m = hashtable.UInt64HashTable()
1059+
m = ht.UInt64HashTable()
10601060
m.map_locations(xs)
10611061
tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs),
10621062
dtype=np.int64))
@@ -1070,25 +1070,35 @@ def test_vector_resize(self):
10701070
# Test for memory errors after internal vector
10711071
# reallocations (pull request #7157)
10721072

1073-
def _test_vector_resize(htable, uniques, dtype, nvals):
1073+
def _test_vector_resize(htable, uniques, dtype, nvals, safely_resizes):
10741074
vals = np.array(np.random.randn(1000), dtype=dtype)
1075-
# get_labels appends to the vector
1075+
# get_labels may append to uniques
10761076
htable.get_labels(vals[:nvals], uniques, 0, -1)
1077-
# to_array resizes the vector
1078-
uniques.to_array()
1079-
htable.get_labels(vals, uniques, 0, -1)
1077+
# to_array() set an external_view_exists flag on uniques.
1078+
tmp = uniques.to_array()
1079+
oldshape = tmp.shape
1080+
# subsequent get_labels() calls can no longer append to it
1081+
# (for all but StringHashTables + ObjectVector)
1082+
if safely_resizes:
1083+
htable.get_labels(vals, uniques, 0, -1)
1084+
else:
1085+
with pytest.raises(ValueError) as excinfo:
1086+
htable.get_labels(vals, uniques, 0, -1)
1087+
assert str(excinfo.value).startswith('external reference')
1088+
uniques.to_array() # should not raise here
1089+
assert tmp.shape == oldshape
10801090

10811091
test_cases = [
1082-
(hashtable.PyObjectHashTable, hashtable.ObjectVector, 'object'),
1083-
(hashtable.StringHashTable, hashtable.ObjectVector, 'object'),
1084-
(hashtable.Float64HashTable, hashtable.Float64Vector, 'float64'),
1085-
(hashtable.Int64HashTable, hashtable.Int64Vector, 'int64'),
1086-
(hashtable.UInt64HashTable, hashtable.UInt64Vector, 'uint64')]
1092+
(ht.PyObjectHashTable, ht.ObjectVector, 'object', False),
1093+
(ht.StringHashTable, ht.ObjectVector, 'object', True),
1094+
(ht.Float64HashTable, ht.Float64Vector, 'float64', False),
1095+
(ht.Int64HashTable, ht.Int64Vector, 'int64', False),
1096+
(ht.UInt64HashTable, ht.UInt64Vector, 'uint64', False)]
10871097

1088-
for (tbl, vect, dtype) in test_cases:
1098+
for (tbl, vect, dtype, safely_resizes) in test_cases:
10891099
# resizing to empty is a special case
1090-
_test_vector_resize(tbl(), vect(), dtype, 0)
1091-
_test_vector_resize(tbl(), vect(), dtype, 10)
1100+
_test_vector_resize(tbl(), vect(), dtype, 0, safely_resizes)
1101+
_test_vector_resize(tbl(), vect(), dtype, 10, safely_resizes)
10921102

10931103

10941104
def test_quantile():

0 commit comments

Comments
 (0)