Skip to content

Commit f28a443

Browse files
Combined build_count_table_int64 and build_count_table_float64 into a
single function using fused types.
1 parent 009f4df commit f28a443

File tree

1 file changed

+27
-34
lines changed

1 file changed

+27
-34
lines changed

pandas/hashtable.pyx

+27 -34
Original file line number | Diff line number | Diff line change
@@ -867,25 +867,38 @@ cdef class Int64Factorizer:
867867
return labels
868868

869869
@cython.boundscheck(False)
870-
cdef build_count_table_float64(float64_t[:] values, kh_float64_t *table, bint dropna):
870+
cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values, void *table, bint dropna):
871871
cdef:
872872
khiter_t k
873873
Py_ssize_t i, n = len(values)
874-
float64_t val
874+
sixty_four_bit_scalar val
875875
int ret = 0
876876

877877
with nogil:
878-
kh_resize_float64(table, n)
878+
if sixty_four_bit_scalar is float64_t:
879+
kh_resize_float64(<kh_float64_t*>table, n)
879880

880-
for i in range(n):
881-
val = values[i]
882-
if val == val or not dropna:
883-
k = kh_get_float64(table, val)
884-
if k != table.n_buckets:
885-
table.vals[k] += 1
881+
for i in range(n):
882+
val = values[i]
883+
if val == val or not dropna:
884+
k = kh_get_float64(<kh_float64_t*>table, val)
885+
if k != (<kh_float64_t*>table).n_buckets:
886+
(<kh_float64_t*>table).vals[k] += 1
887+
else:
888+
k = kh_put_float64(<kh_float64_t*>table, val, &ret)
889+
(<kh_float64_t*>table).vals[k] = 1
890+
elif sixty_four_bit_scalar is int64_t:
891+
kh_resize_int64(<kh_int64_t*>table, n)
892+
893+
for i in range(n):
894+
val = values[i]
895+
k = kh_get_int64(<kh_int64_t*>table, val)
896+
if k != (<kh_int64_t*>table).n_buckets:
897+
(<kh_int64_t*>table).vals[k] += 1
886898
else:
887-
k = kh_put_float64(table, val, &ret)
888-
table.vals[k] = 1
899+
k = kh_put_int64(<kh_int64_t*>table, val, &ret)
900+
(<kh_int64_t*>table).vals[k] = 1
901+
889902

890903

891904
@cython.boundscheck(False)
@@ -902,7 +915,7 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
902915

903916
if sixty_four_bit_scalar is float64_t:
904917
ftable = kh_init_float64()
905-
build_count_table_float64(values, ftable, dropna)
918+
build_count_table_scalar64(values, ftable, dropna)
906919

907920
result_keys = np.empty(ftable.n_occupied, dtype=np.float64)
908921
result_counts = np.zeros(ftable.n_occupied, dtype=np.int64)
@@ -917,7 +930,7 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
917930

918931
elif sixty_four_bit_scalar is int64_t:
919932
itable = kh_init_int64()
920-
build_count_table_int64(values, itable)
933+
build_count_table_scalar64(values, itable, dropna)
921934

922935
result_keys = np.empty(itable.n_occupied, dtype=np.int64)
923936
result_counts = np.zeros(itable.n_occupied, dtype=np.int64)
@@ -932,26 +945,6 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
932945

933946
return np.asarray(result_keys), np.asarray(result_counts)
934947

935-
@cython.boundscheck(False)
936-
cdef build_count_table_int64(int64_t[:] values, kh_int64_t *table):
937-
cdef:
938-
khiter_t k
939-
Py_ssize_t i, n = len(values)
940-
int64_t val
941-
int ret = 0
942-
943-
with nogil:
944-
kh_resize_int64(table, n)
945-
946-
for i in range(n):
947-
val = values[i]
948-
k = kh_get_int64(table, val)
949-
if k != table.n_buckets:
950-
table.vals[k] += 1
951-
else:
952-
k = kh_put_int64(table, val, &ret)
953-
table.vals[k] = 1
954-
955948

956949
cdef build_count_table_object(ndarray[object] values,
957950
ndarray[uint8_t, cast=True] mask,
@@ -1040,7 +1033,7 @@ def mode_int64(int64_t[:] values):
10401033

10411034
table = kh_init_int64()
10421035

1043-
build_count_table_int64(values, table)
1036+
build_count_table_scalar64(values, table, 0)
10441037

10451038
modes = np.empty(table.n_buckets, dtype=np.int64)
10461039

0 commit comments

Comments
 (0)