@@ -866,38 +866,46 @@ cdef class Int64Factorizer:
866
866
self .count = len (self .uniques)
867
867
return labels
868
868
869
+ ctypedef fused kh_scalar64:
870
+ kh_int64_t
871
+ kh_float64_t
872
+
869
873
@ cython.boundscheck (False )
870
- cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values, void * table, bint dropna):
874
+ cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values,
875
+ kh_scalar64 * table, bint dropna):
871
876
cdef:
872
877
khiter_t k
873
878
Py_ssize_t i, n = len (values)
874
879
sixty_four_bit_scalar val
875
880
int ret = 0
876
881
877
- with nogil :
878
- if sixty_four_bit_scalar is float64_t :
879
- kh_resize_float64(< kh_float64_t * > table, n)
882
+ if sixty_four_bit_scalar is float64_t and kh_scalar64 is kh_float64_t :
883
+ with nogil :
884
+ kh_resize_float64(table, n)
880
885
881
886
for i in range (n):
882
887
val = values[i]
883
888
if val == val or not dropna:
884
- k = kh_get_float64(< kh_float64_t * > table, val)
885
- if k != ( < kh_float64_t * > table) .n_buckets:
886
- ( < kh_float64_t * > table) .vals[k] += 1
889
+ k = kh_get_float64(table, val)
890
+ if k != table.n_buckets:
891
+ table.vals[k] += 1
887
892
else :
888
- k = kh_put_float64(< kh_float64_t* > table, val, & ret)
889
- (< kh_float64_t* > table).vals[k] = 1
890
- elif sixty_four_bit_scalar is int64_t:
891
- kh_resize_int64(< kh_int64_t* > table, n)
893
+ k = kh_put_float64(table, val, & ret)
894
+ table.vals[k] = 1
895
+ elif sixty_four_bit_scalar is int64_t and kh_scalar64 is kh_int64_t:
896
+ with nogil:
897
+ kh_resize_int64(table, n)
892
898
893
899
for i in range (n):
894
900
val = values[i]
895
- k = kh_get_int64(< kh_int64_t * > table, val)
896
- if k != ( < kh_int64_t * > table) .n_buckets:
897
- ( < kh_int64_t * > table) .vals[k] += 1
901
+ k = kh_get_int64(table, val)
902
+ if k != table.n_buckets:
903
+ table.vals[k] += 1
898
904
else :
899
- k = kh_put_int64(< kh_int64_t* > table, val, & ret)
900
- (< kh_int64_t* > table).vals[k] = 1
905
+ k = kh_put_int64(table, val, & ret)
906
+ table.vals[k] = 1
907
+ else :
908
+ raise ValueError (" Table type must match scalar type." )
901
909
902
910
903
911
0 commit comments