@@ -859,24 +859,50 @@ cdef class MultiIndexHashTable(HashTable):
859
859
sizeof(size_t) + # vals
860
860
sizeof(uint32_t)) # flags
861
861
862
def _check_for_collisions(self, int64_t[:] locs, object mi):
    """
    Validate that the positions in ``locs`` map to the values in ``mi``.

    The table is keyed on hashes only, so two different labels with the
    same hash resolve to the same position. This check compares the
    entries of ``self.mi`` at ``locs`` against what was looked up.

    Parameters
    ----------
    locs : int64_t[:]
        Positions found by the lookup; ``-1`` marks a missing value.
    mi : object
        What was looked up: either an index-like object holding one
        entry per position, or a single key given as a tuple.

    Raises
    ------
    ValueError
        If no position is ``-1`` and the entries taken from ``self.mi``
        do not equal ``mi``.
    """
    cdef:
        ndarray[int64_t] alocs

    alocs = np.asarray(locs)

    # we can only check if we *don't* have any missing values
    if not (alocs != -1).all():
        return

    result = self.mi.take(locs)

    if isinstance(mi, tuple):
        # Single-key lookups hand over the raw key. ``equals`` reports
        # False for anything that is not an Index, so comparing against
        # the bare tuple would flag every successful lookup as a
        # collision. Wrap the key in a one-element index first.
        from pandas import Index
        mi = Index([mi])

    if not result.equals(mi):
        raise ValueError("hash collision alert")
876
+
862
877
def __contains__(self, object key):
    """
    Report whether ``key`` is present in the table.

    Returns False when the hash of ``key`` is not in the table. On a
    hit the stored position is passed to ``_check_for_collisions`` with
    ``key`` before True is returned; any exception raised by that check
    propagates to the caller.
    """
    cdef:
        khiter_t it
        uint64_t hashed

    hashed = self.mi._hashed_indexing_key(key)
    it = kh_get_uint64(self.table, hashed)

    # an iterator equal to n_buckets is treated as "hash not found"
    if it == self.table.n_buckets:
        return False

    # a matching hash is not enough; verify the stored position
    self._check_for_collisions(
        np.array([self.table.vals[it]], dtype=np.int64), key)
    return True
870
891
871
892
cpdef get_item(self, object key):
    """
    Return the position stored in the table for ``key``.

    Raises
    ------
    KeyError
        If the hash of ``key`` is not present in the table.

    Any exception raised by ``_check_for_collisions`` for the stored
    position also propagates to the caller.
    """
    cdef:
        khiter_t it
        uint64_t hashed
        int64_t[:] locs
        Py_ssize_t loc

    hashed = self.mi._hashed_indexing_key(key)
    it = kh_get_uint64(self.table, hashed)

    # an iterator equal to n_buckets is treated as "hash not found"
    if it == self.table.n_buckets:
        raise KeyError(key)

    loc = self.table.vals[it]

    # a matching hash is not enough; verify the stored position
    locs = np.array([loc], dtype=np.int64)
    self._check_for_collisions(locs, key)
    return loc
882
908
@@ -927,6 +953,7 @@ cdef class MultiIndexHashTable(HashTable):
927
953
else:
928
954
locs[i] = -1
929
955
956
+ self._check_for_collisions(locs, mi)
930
957
return np.asarray(locs)
931
958
932
959
def unique(self, object mi):
0 commit comments