Skip to content

Commit ba85304

Browse files
CLN: Further simplified build_count_table_scalar64.
1 parent b9b54ae commit ba85304

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

pandas/hashtable.pyx

+24-16
Original file line numberDiff line numberDiff line change
@@ -866,38 +866,46 @@ cdef class Int64Factorizer:
866866
self.count = len(self.uniques)
867867
return labels
868868

869+
ctypedef fused kh_scalar64:
870+
kh_int64_t
871+
kh_float64_t
872+
869873
@cython.boundscheck(False)
870-
cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values, void *table, bint dropna):
874+
cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values,
875+
kh_scalar64 *table, bint dropna):
871876
cdef:
872877
khiter_t k
873878
Py_ssize_t i, n = len(values)
874879
sixty_four_bit_scalar val
875880
int ret = 0
876881

877-
with nogil:
878-
if sixty_four_bit_scalar is float64_t:
879-
kh_resize_float64(<kh_float64_t*>table, n)
882+
if sixty_four_bit_scalar is float64_t and kh_scalar64 is kh_float64_t:
883+
with nogil:
884+
kh_resize_float64(table, n)
880885

881886
for i in range(n):
882887
val = values[i]
883888
if val == val or not dropna:
884-
k = kh_get_float64(<kh_float64_t*>table, val)
885-
if k != (<kh_float64_t*>table).n_buckets:
886-
(<kh_float64_t*>table).vals[k] += 1
889+
k = kh_get_float64(table, val)
890+
if k != table.n_buckets:
891+
table.vals[k] += 1
887892
else:
888-
k = kh_put_float64(<kh_float64_t*>table, val, &ret)
889-
(<kh_float64_t*>table).vals[k] = 1
890-
elif sixty_four_bit_scalar is int64_t:
891-
kh_resize_int64(<kh_int64_t*>table, n)
893+
k = kh_put_float64(table, val, &ret)
894+
table.vals[k] = 1
895+
elif sixty_four_bit_scalar is int64_t and kh_scalar64 is kh_int64_t:
896+
with nogil:
897+
kh_resize_int64(table, n)
892898

893899
for i in range(n):
894900
val = values[i]
895-
k = kh_get_int64(<kh_int64_t*>table, val)
896-
if k != (<kh_int64_t*>table).n_buckets:
897-
(<kh_int64_t*>table).vals[k] += 1
901+
k = kh_get_int64(table, val)
902+
if k != table.n_buckets:
903+
table.vals[k] += 1
898904
else:
899-
k = kh_put_int64(<kh_int64_t*>table, val, &ret)
900-
(<kh_int64_t*>table).vals[k] = 1
905+
k = kh_put_int64(table, val, &ret)
906+
table.vals[k] = 1
907+
else:
908+
raise ValueError("Table type must match scalar type.")
901909

902910

903911

0 commit comments

Comments
 (0)