Skip to content

Commit 9f346e7

Browse files
realead authored and luckyvs1 committed
PERF: Introducing hash tables for complex64 and complex128 (pandas-dev#38179)
1 parent 32435f2 commit 9f346e7

8 files changed

+276
-84
lines changed

pandas/_libs/hashtable.pxd

+18
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,16 @@
11
from numpy cimport intp_t, ndarray
22

33
from pandas._libs.khash cimport (
4+
complex64_t,
5+
complex128_t,
46
float32_t,
57
float64_t,
68
int8_t,
79
int16_t,
810
int32_t,
911
int64_t,
12+
kh_complex64_t,
13+
kh_complex128_t,
1014
kh_float32_t,
1115
kh_float64_t,
1216
kh_int8_t,
@@ -19,6 +23,8 @@ from pandas._libs.khash cimport (
1923
kh_uint16_t,
2024
kh_uint32_t,
2125
kh_uint64_t,
26+
khcomplex64_t,
27+
khcomplex128_t,
2228
uint8_t,
2329
uint16_t,
2430
uint32_t,
@@ -90,6 +96,18 @@ cdef class Float32HashTable(HashTable):
9096
cpdef get_item(self, float32_t val)
9197
cpdef set_item(self, float32_t key, Py_ssize_t val)
9298

99+
cdef class Complex64HashTable(HashTable):
100+
cdef kh_complex64_t *table
101+
102+
cpdef get_item(self, complex64_t val)
103+
cpdef set_item(self, complex64_t key, Py_ssize_t val)
104+
105+
cdef class Complex128HashTable(HashTable):
106+
cdef kh_complex128_t *table
107+
108+
cpdef get_item(self, complex128_t val)
109+
cpdef set_item(self, complex128_t key, Py_ssize_t val)
110+
93111
cdef class PyObjectHashTable(HashTable):
94112
cdef kh_pymap_t *table
95113

pandas/_libs/hashtable.pyx

+11-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,17 @@ cnp.import_array()
1313

1414

1515
from pandas._libs cimport util
16-
from pandas._libs.khash cimport KHASH_TRACE_DOMAIN, kh_str_t, khiter_t
16+
from pandas._libs.khash cimport (
17+
KHASH_TRACE_DOMAIN,
18+
are_equivalent_float32_t,
19+
are_equivalent_float64_t,
20+
are_equivalent_khcomplex64_t,
21+
are_equivalent_khcomplex128_t,
22+
kh_str_t,
23+
khcomplex64_t,
24+
khcomplex128_t,
25+
khiter_t,
26+
)
1727
from pandas._libs.missing cimport checknull
1828

1929

pandas/_libs/hashtable_class_helper.pxi.in

+115-40
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,73 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
88
{{py:
99

1010
# name
11-
cimported_types = ['float32',
11+
complex_types = ['complex64',
12+
'complex128']
13+
}}
14+
15+
{{for name in complex_types}}
16+
cdef kh{{name}}_t to_kh{{name}}_t({{name}}_t val) nogil:
17+
cdef kh{{name}}_t res
18+
res.real = val.real
19+
res.imag = val.imag
20+
return res
21+
22+
23+
cdef {{name}}_t to_{{name}}(kh{{name}}_t val) nogil:
24+
cdef {{name}}_t res
25+
res.real = val.real
26+
res.imag = val.imag
27+
return res
28+
29+
{{endfor}}
30+
31+
32+
{{py:
33+
34+
35+
# name
36+
c_types = ['khcomplex128_t',
37+
'khcomplex64_t',
38+
'float64_t',
39+
'float32_t',
40+
'int64_t',
41+
'int32_t',
42+
'int16_t',
43+
'int8_t',
44+
'uint64_t',
45+
'uint32_t',
46+
'uint16_t',
47+
'uint8_t']
48+
}}
49+
50+
{{for c_type in c_types}}
51+
52+
cdef bint is_nan_{{c_type}}({{c_type}} val) nogil:
53+
{{if c_type in {'khcomplex128_t', 'khcomplex64_t'} }}
54+
return val.real != val.real or val.imag != val.imag
55+
{{elif c_type in {'float64_t', 'float32_t'} }}
56+
return val != val
57+
{{else}}
58+
return False
59+
{{endif}}
60+
61+
62+
{{if c_type in {'khcomplex128_t', 'khcomplex64_t', 'float64_t', 'float32_t'} }}
63+
# are_equivalent_{{c_type}} is cimported via khash.pxd
64+
{{else}}
65+
cdef bint are_equivalent_{{c_type}}({{c_type}} val1, {{c_type}} val2) nogil:
66+
return val1 == val2
67+
{{endif}}
68+
69+
{{endfor}}
70+
71+
72+
{{py:
73+
74+
# name
75+
cimported_types = ['complex64',
76+
'complex128',
77+
'float32',
1278
'float64',
1379
'int8',
1480
'int16',
@@ -32,6 +98,7 @@ from pandas._libs.khash cimport (
3298
kh_put_{{name}},
3399
kh_resize_{{name}},
34100
)
101+
35102
{{endfor}}
36103

37104
# ----------------------------------------------------------------------
@@ -48,7 +115,9 @@ from pandas._libs.missing cimport C_NA
48115
# but is included for completeness (rather ObjectVector is used
49116
# for uniques in hashtables)
50117

51-
dtypes = [('Float64', 'float64', 'float64_t'),
118+
dtypes = [('Complex128', 'complex128', 'khcomplex128_t'),
119+
('Complex64', 'complex64', 'khcomplex64_t'),
120+
('Float64', 'float64', 'float64_t'),
52121
('Float32', 'float32', 'float32_t'),
53122
('Int64', 'int64', 'int64_t'),
54123
('Int32', 'int32', 'int32_t'),
@@ -94,6 +163,8 @@ ctypedef fused vector_data:
94163
UInt8VectorData
95164
Float64VectorData
96165
Float32VectorData
166+
Complex128VectorData
167+
Complex64VectorData
97168
StringVectorData
98169

99170
cdef inline bint needs_resize(vector_data *data) nogil:
@@ -106,7 +177,9 @@ cdef inline bint needs_resize(vector_data *data) nogil:
106177
{{py:
107178

108179
# name, dtype, c_type
109-
dtypes = [('Float64', 'float64', 'float64_t'),
180+
dtypes = [('Complex128', 'complex128', 'khcomplex128_t'),
181+
('Complex64', 'complex64', 'khcomplex64_t'),
182+
('Float64', 'float64', 'float64_t'),
110183
('UInt64', 'uint64', 'uint64_t'),
111184
('Int64', 'int64', 'int64_t'),
112185
('Float32', 'float32', 'float32_t'),
@@ -303,22 +376,24 @@ cdef class HashTable:
303376

304377
{{py:
305378

306-
# name, dtype, float_group
307-
dtypes = [('Float64', 'float64', True),
308-
('UInt64', 'uint64', False),
309-
('Int64', 'int64', False),
310-
('Float32', 'float32', True),
311-
('UInt32', 'uint32', False),
312-
('Int32', 'int32', False),
313-
('UInt16', 'uint16', False),
314-
('Int16', 'int16', False),
315-
('UInt8', 'uint8', False),
316-
('Int8', 'int8', False)]
379+
# name, dtype, c_type, to_c_type
380+
dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'),
381+
('Float64', 'float64', 'float64_t', ''),
382+
('UInt64', 'uint64', 'uint64_t', ''),
383+
('Int64', 'int64', 'int64_t', ''),
384+
('Complex64', 'complex64', 'khcomplex64_t', 'to_khcomplex64_t'),
385+
('Float32', 'float32', 'float32_t', ''),
386+
('UInt32', 'uint32', 'uint32_t', ''),
387+
('Int32', 'int32', 'int32_t', ''),
388+
('UInt16', 'uint16', 'uint16_t', ''),
389+
('Int16', 'int16', 'int16_t', ''),
390+
('UInt8', 'uint8', 'uint8_t', ''),
391+
('Int8', 'int8', 'int8_t', '')]
317392

318393
}}
319394

320395

321-
{{for name, dtype, float_group in dtypes}}
396+
{{for name, dtype, c_type, to_c_type in dtypes}}
322397

323398
cdef class {{name}}HashTable(HashTable):
324399

@@ -339,7 +414,9 @@ cdef class {{name}}HashTable(HashTable):
339414
def __contains__(self, object key):
340415
cdef:
341416
khiter_t k
342-
k = kh_get_{{dtype}}(self.table, key)
417+
{{c_type}} ckey
418+
ckey = {{to_c_type}}(key)
419+
k = kh_get_{{dtype}}(self.table, ckey)
343420
return k != self.table.n_buckets
344421

345422
def sizeof(self, deep=False):
@@ -353,7 +430,9 @@ cdef class {{name}}HashTable(HashTable):
353430
cpdef get_item(self, {{dtype}}_t val):
354431
cdef:
355432
khiter_t k
356-
k = kh_get_{{dtype}}(self.table, val)
433+
{{c_type}} cval
434+
cval = {{to_c_type}}(val)
435+
k = kh_get_{{dtype}}(self.table, cval)
357436
if k != self.table.n_buckets:
358437
return self.table.vals[k]
359438
else:
@@ -363,9 +442,9 @@ cdef class {{name}}HashTable(HashTable):
363442
cdef:
364443
khiter_t k
365444
int ret = 0
366-
367-
k = kh_put_{{dtype}}(self.table, key, &ret)
368-
self.table.keys[k] = key
445+
{{c_type}} ckey
446+
ckey = {{to_c_type}}(key)
447+
k = kh_put_{{dtype}}(self.table, ckey, &ret)
369448
if kh_exist_{{dtype}}(self.table, k):
370449
self.table.vals[k] = val
371450
else:
@@ -376,12 +455,12 @@ cdef class {{name}}HashTable(HashTable):
376455
cdef:
377456
Py_ssize_t i, n = len(values)
378457
int ret = 0
379-
{{dtype}}_t key
458+
{{c_type}} key
380459
khiter_t k
381460

382461
with nogil:
383462
for i in range(n):
384-
key = keys[i]
463+
key = {{to_c_type}}(keys[i])
385464
k = kh_put_{{dtype}}(self.table, key, &ret)
386465
self.table.vals[k] = <Py_ssize_t>values[i]
387466

@@ -390,12 +469,12 @@ cdef class {{name}}HashTable(HashTable):
390469
cdef:
391470
Py_ssize_t i, n = len(values)
392471
int ret = 0
393-
{{dtype}}_t val
472+
{{c_type}} val
394473
khiter_t k
395474

396475
with nogil:
397476
for i in range(n):
398-
val = values[i]
477+
val= {{to_c_type}}(values[i])
399478
k = kh_put_{{dtype}}(self.table, val, &ret)
400479
self.table.vals[k] = i
401480

@@ -404,13 +483,13 @@ cdef class {{name}}HashTable(HashTable):
404483
cdef:
405484
Py_ssize_t i, n = len(values)
406485
int ret = 0
407-
{{dtype}}_t val
486+
{{c_type}} val
408487
khiter_t k
409488
intp_t[:] locs = np.empty(n, dtype=np.intp)
410489

411490
with nogil:
412491
for i in range(n):
413-
val = values[i]
492+
val = {{to_c_type}}(values[i])
414493
k = kh_get_{{dtype}}(self.table, val)
415494
if k != self.table.n_buckets:
416495
locs[i] = self.table.vals[k]
@@ -466,7 +545,7 @@ cdef class {{name}}HashTable(HashTable):
466545
Py_ssize_t i, idx, count = count_prior, n = len(values)
467546
int64_t[:] labels
468547
int ret = 0
469-
{{dtype}}_t val, na_value2
548+
{{c_type}} val, na_value2
470549
khiter_t k
471550
{{name}}VectorData *ud
472551
bint use_na_value, use_mask
@@ -487,23 +566,21 @@ cdef class {{name}}HashTable(HashTable):
487566
# We use None, to make it optional, which requires `object` type
488567
# for the parameter. To please the compiler, we use na_value2,
489568
# which is only used if it's *specified*.
490-
na_value2 = <{{dtype}}_t>na_value
569+
na_value2 = {{to_c_type}}(na_value)
491570
else:
492-
na_value2 = 0
571+
na_value2 = {{to_c_type}}(0)
493572

494573
with nogil:
495574
for i in range(n):
496-
val = values[i]
575+
val = {{to_c_type}}(values[i])
497576

498577
if ignore_na and use_mask:
499578
if mask_values[i]:
500579
labels[i] = na_sentinel
501580
continue
502581
elif ignore_na and (
503-
{{if not name.lower().startswith(("uint", "int"))}}
504-
val != val or
505-
{{endif}}
506-
(use_na_value and val == na_value2)
582+
is_nan_{{c_type}}(val) or
583+
(use_na_value and are_equivalent_{{c_type}}(val, na_value2))
507584
):
508585
# if missing values do not count as unique values (i.e. if
509586
# ignore_na is True), skip the hashtable entry for them,
@@ -606,14 +683,15 @@ cdef class {{name}}HashTable(HashTable):
606683
ignore_na=True, return_inverse=True)
607684
return labels
608685

686+
{{if dtype == 'int64'}}
609687
@cython.boundscheck(False)
610688
def get_labels_groupby(self, const {{dtype}}_t[:] values):
611689
cdef:
612690
Py_ssize_t i, n = len(values)
613691
intp_t[:] labels
614692
Py_ssize_t idx, count = 0
615693
int ret = 0
616-
{{dtype}}_t val
694+
{{c_type}} val
617695
khiter_t k
618696
{{name}}Vector uniques = {{name}}Vector()
619697
{{name}}VectorData *ud
@@ -623,14 +701,12 @@ cdef class {{name}}HashTable(HashTable):
623701

624702
with nogil:
625703
for i in range(n):
626-
val = values[i]
704+
val = {{to_c_type}}(values[i])
627705

628706
# specific for groupby
629-
{{if dtype != 'uint64'}}
630707
if val < 0:
631708
labels[i] = -1
632709
continue
633-
{{endif}}
634710

635711
k = kh_get_{{dtype}}(self.table, val)
636712
if k != self.table.n_buckets:
@@ -650,6 +726,7 @@ cdef class {{name}}HashTable(HashTable):
650726
arr_uniques = uniques.to_array()
651727

652728
return np.asarray(labels), arr_uniques
729+
{{endif}}
653730

654731
{{endfor}}
655732

@@ -698,7 +775,6 @@ cdef class StringHashTable(HashTable):
698775
v = get_c_string(key)
699776

700777
k = kh_put_str(self.table, v, &ret)
701-
self.table.keys[k] = v
702778
if kh_exist_str(self.table, k):
703779
self.table.vals[k] = val
704780
else:
@@ -1022,7 +1098,6 @@ cdef class PyObjectHashTable(HashTable):
10221098
hash(key)
10231099

10241100
k = kh_put_pymap(self.table, <PyObject*>key, &ret)
1025-
# self.table.keys[k] = key
10261101
if kh_exist_pymap(self.table, k):
10271102
self.table.vals[k] = val
10281103
else:

0 commit comments

Comments
 (0)