Skip to content

Commit 9a2cb69

Browse files
committed
check for hash collisions
1 parent 8fbb323 commit 9a2cb69

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

pandas/indexes/multi.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2181,7 +2181,9 @@ def equals(self, other):
21812181
return True
21822182

21832183
if not isinstance(other, Index):
2184-
return False
2184+
if not isinstance(other, tuple):
2185+
return False
2186+
other = Index([other])
21852187

21862188
if not isinstance(other, MultiIndex):
21872189
return array_equivalent(self._values,

pandas/src/hashtable_class_helper.pxi.in

+29-2
Original file line numberDiff line numberDiff line change
@@ -859,24 +859,50 @@ cdef class MultiIndexHashTable(HashTable):
859859
sizeof(size_t) + # vals
860860
sizeof(uint32_t)) # flags
861861

862+
def _check_for_collisions(self, int64_t[:] locs, object mi):
863+
# validate that the locs map to the actual values
864+
# provided in the mi
865+
# we can only check if we *don't* have any missing values
866+
# :<
867+
cdef:
868+
ndarray[int64_t] alocs
869+
870+
alocs = np.asarray(locs)
871+
if (alocs!=-1).all():
872+
873+
result = self.mi.take(locs)
874+
if not result.equals(mi):
875+
raise ValueError("hash collision alert")
876+
862877
def __contains__(self, object key):
863878
cdef:
864879
khiter_t k
865880
uint64_t value
866881

867882
value = self.mi._hashed_indexing_key(key)
868883
k = kh_get_uint64(self.table, value)
869-
return k != self.table.n_buckets
884+
if k != self.table.n_buckets:
885+
loc = self.table.vals[k]
886+
locs = np.array([loc], dtype=np.int64)
887+
self._check_for_collisions(locs, key)
888+
return True
889+
890+
return False
870891

871892
cpdef get_item(self, object key):
872893
cdef:
873894
khiter_t k
874895
uint64_t value
896+
int64_t[:] locs
897+
Py_ssize_t loc
875898

876899
value = self.mi._hashed_indexing_key(key)
877900
k = kh_get_uint64(self.table, value)
878901
if k != self.table.n_buckets:
879-
return self.table.vals[k]
902+
loc = self.table.vals[k]
903+
locs = np.array([loc], dtype=np.int64)
904+
self._check_for_collisions(locs, key)
905+
return loc
880906
else:
881907
raise KeyError(key)
882908

@@ -927,6 +953,7 @@ cdef class MultiIndexHashTable(HashTable):
927953
else:
928954
locs[i] = -1
929955

956+
self._check_for_collisions(locs, mi)
930957
return np.asarray(locs)
931958

932959
def unique(self, object mi):

0 commit comments

Comments
 (0)