Skip to content

Commit 8464750

Browse files
committed
Unify hashtable.factorize and .unique
1 parent 085a8eb commit 8464750

File tree

1 file changed

+91
-162
lines changed

1 file changed

+91
-162
lines changed

pandas/_libs/hashtable_class_helper.pxi.in

+91-162
Original file line numberDiff line numberDiff line change
@@ -356,64 +356,21 @@ cdef class {{name}}HashTable(HashTable):
356356
return np.asarray(locs)
357357

358358
@cython.boundscheck(False)
359-
def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False):
359+
def _unique(self, const {{dtype}}_t[:] values, {{name}}Vector uniques,
360+
bint ignore_na=False, bint return_inverse=False,
361+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
362+
object na_value=None):
360363
cdef:
361-
Py_ssize_t i, idx, count = 0, n = len(values)
364+
Py_ssize_t i, idx, count = count_prior, n = len(values)
362365
int64_t[:] labels
363366
int ret = 0
364-
{{dtype}}_t val
365-
khiter_t k
366-
{{name}}Vector uniques = {{name}}Vector()
367-
{{name}}VectorData *ud
368-
369-
ud = uniques.data
370-
if return_inverse:
371-
labels = np.empty(n, dtype=np.int64)
372-
373-
with nogil:
374-
for i in range(n):
375-
val = values[i]
376-
k = kh_get_{{dtype}}(self.table, val)
377-
if return_inverse and k != self.table.n_buckets:
378-
# k falls into a previous bucket
379-
idx = self.table.vals[k]
380-
labels[i] = idx
381-
elif k == self.table.n_buckets:
382-
# k hasn't been seen yet
383-
k = kh_put_{{dtype}}(self.table, val, &ret)
384-
if needs_resize(ud):
385-
with gil:
386-
uniques.resize()
387-
append_data_{{dtype}}(ud, val)
388-
if return_inverse:
389-
self.table.vals[k] = count
390-
labels[i] = count
391-
count += 1
392-
393-
if return_inverse:
394-
return uniques.to_array(), np.asarray(labels)
395-
return uniques.to_array()
396-
397-
def factorize(self, {{dtype}}_t[:] values):
398-
uniques = {{name}}Vector()
399-
labels = self.get_labels(values, uniques, 0)
400-
return uniques.to_array(), labels
401-
402-
@cython.boundscheck(False)
403-
def get_labels(self, const {{dtype}}_t[:] values, {{name}}Vector uniques,
404-
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
405-
object na_value=None):
406-
cdef:
407-
Py_ssize_t i, n = len(values)
408-
int64_t[:] labels
409-
Py_ssize_t idx, count = count_prior
410-
int ret = 0
411367
{{dtype}}_t val, na_value2
412368
khiter_t k
413369
{{name}}VectorData *ud
414370
bint use_na_value
415371

416-
labels = np.empty(n, dtype=np.int64)
372+
if return_inverse:
373+
labels = np.empty(n, dtype=np.int64)
417374
ud = uniques.data
418375
use_na_value = na_value is not None
419376

@@ -431,21 +388,19 @@ cdef class {{name}}HashTable(HashTable):
431388
for i in range(n):
432389
val = values[i]
433390

434-
if val != val or (use_na_value and val == na_value2):
391+
if ignore_na and (val != val
392+
or (use_na_value and val == na_value2)):
435393
labels[i] = na_sentinel
436394
continue
437395

438396
k = kh_get_{{dtype}}(self.table, val)
439-
440-
if k != self.table.n_buckets:
397+
if return_inverse and k != self.table.n_buckets:
441398
# k falls into a previous bucket
442399
idx = self.table.vals[k]
443400
labels[i] = idx
444-
else:
401+
elif k == self.table.n_buckets:
445402
# k hasn't been seen yet
446403
k = kh_put_{{dtype}}(self.table, val, &ret)
447-
self.table.vals[k] = count
448-
449404
if needs_resize(ud):
450405
with gil:
451406
if uniques.external_view_exists:
@@ -454,10 +409,30 @@ cdef class {{name}}HashTable(HashTable):
454409
"Vector.resize() needed")
455410
uniques.resize()
456411
append_data_{{dtype}}(ud, val)
457-
labels[i] = count
412+
if return_inverse:
413+
self.table.vals[k] = count
414+
labels[i] = count
458415
count += 1
459416

460-
return np.asarray(labels)
417+
if return_inverse:
418+
return uniques.to_array(), np.asarray(labels)
419+
return uniques.to_array()
420+
421+
def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False):
422+
return self._unique(values, uniques={{name}}Vector(), ignore_na=False,
423+
return_inverse=return_inverse)
424+
425+
def factorize(self, {{dtype}}_t[:] values):
426+
return self._unique(values, uniques={{name}}Vector(), ignore_na=True,
427+
return_inverse=True)
428+
429+
def get_labels(self, const {{dtype}}_t[:] values, {{name}}Vector uniques,
430+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
431+
object na_value=None):
432+
_, labels = self._unique(values, uniques, ignore_na=True,
433+
return_inverse=True, count_prior=count_prior,
434+
na_sentinel=na_sentinel, na_value=na_value)
435+
return labels
461436

462437
@cython.boundscheck(False)
463438
def get_labels_groupby(self, const {{dtype}}_t[:] values):
@@ -645,33 +620,45 @@ cdef class StringHashTable(HashTable):
645620
free(vecs)
646621

647622
@cython.boundscheck(False)
648-
def unique(self, ndarray[object] values, bint return_inverse=False):
623+
def _unique(self, ndarray[object] values, ObjectVector uniques,
624+
bint ignore_na=False, bint return_inverse=False,
625+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
626+
object na_value=None):
649627
cdef:
650-
Py_ssize_t i, idx, count = 0, n = len(values)
628+
Py_ssize_t i, idx, count = count_prior, n = len(values)
651629
int64_t[:] labels
652630
int64_t[:] uindexer
653631
int ret = 0
654632
object val
655-
ObjectVector uniques = ObjectVector()
656-
khiter_t k
657633
const char *v
658634
const char **vecs
635+
khiter_t k
636+
bint use_na_value
659637

660638
if return_inverse:
661639
labels = np.zeros(n, dtype=np.int64)
662640
uindexer = np.empty(n, dtype=np.int64)
641+
use_na_value = na_value is not None
663642

664-
# assign pointers
643+
# assign pointers and pre-filter out missing (if ignore_na)
665644
vecs = <const char **> malloc(n * sizeof(char *))
666645
for i in range(n):
667646
val = values[i]
668-
v = util.get_c_string(val)
669-
vecs[i] = v
670647

648+
if not ignore_na or ((PyUnicode_Check(val) or PyString_Check(val))
649+
and not (use_na_value and val == na_value)):
650+
# if ignore_na is False, we also stringify NaN/None/etc.
651+
v = util.get_c_string(val)
652+
vecs[i] = v
653+
else:
654+
labels[i] = na_sentinel
671655

672656
# compute
673657
with nogil:
674658
for i in range(n):
659+
if ignore_na and labels[i] == na_sentinel:
660+
continue
661+
675662
v = vecs[i]
676663
k = kh_get_str(self.table, v)
677664
if return_inverse and k != self.table.n_buckets:
@@ -697,65 +684,21 @@ cdef class StringHashTable(HashTable):
697684
return uniques.to_array(), np.asarray(labels)
698685
return uniques.to_array()
699686

700-
@cython.boundscheck(False)
701-
def get_labels(self, ndarray[object] values, ObjectVector uniques,
702-
Py_ssize_t count_prior=0, int64_t na_sentinel=-1,
703-
object na_value=None):
704-
cdef:
705-
Py_ssize_t i, n = len(values)
706-
int64_t[:] labels
707-
int64_t[:] uindexer
708-
Py_ssize_t idx, count = count_prior
709-
int ret = 0
710-
object val
711-
const char *v
712-
const char **vecs
713-
khiter_t k
714-
bint use_na_value
715-
716-
# these by-definition *must* be strings
717-
labels = np.zeros(n, dtype=np.int64)
718-
uindexer = np.empty(n, dtype=np.int64)
719-
use_na_value = na_value is not None
720-
721-
# pre-filter out missing
722-
# and assign pointers
723-
vecs = <const char **> malloc(n * sizeof(char *))
724-
for i in range(n):
725-
val = values[i]
726-
727-
if ((PyUnicode_Check(val) or PyString_Check(val)) and
728-
not (use_na_value and val == na_value)):
729-
v = util.get_c_string(val)
730-
vecs[i] = v
731-
else:
732-
labels[i] = na_sentinel
733-
734-
# compute
735-
with nogil:
736-
for i in range(n):
737-
if labels[i] == na_sentinel:
738-
continue
739-
740-
v = vecs[i]
741-
k = kh_get_str(self.table, v)
742-
if k != self.table.n_buckets:
743-
idx = self.table.vals[k]
744-
labels[i] = <int64_t>idx
745-
else:
746-
k = kh_put_str(self.table, v, &ret)
747-
self.table.vals[k] = count
748-
uindexer[count] = i
749-
labels[i] = <int64_t>count
750-
count += 1
751-
752-
free(vecs)
687+
def unique(self, ndarray[object] values, bint return_inverse=False):
688+
return self._unique(values, uniques=ObjectVector(), ignore_na=False,
689+
return_inverse=return_inverse)
753690

754-
# uniques
755-
for i in range(count):
756-
uniques.append(values[uindexer[i]])
691+
def factorize(self, ndarray[object] values):
692+
return self._unique(values, uniques=ObjectVector(), ignore_na=True,
693+
return_inverse=True)
757694

758-
return np.asarray(labels)
695+
def get_labels(self, ndarray[object] values, ObjectVector uniques,
696+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
697+
object na_value=None):
698+
_, labels = self._unique(values, uniques, ignore_na=True,
699+
return_inverse=True, count_prior=count_prior,
700+
na_sentinel=na_sentinel, na_value=na_value)
701+
return labels
759702

760703

761704
cdef class PyObjectHashTable(HashTable):
@@ -844,21 +787,31 @@ cdef class PyObjectHashTable(HashTable):
844787
return np.asarray(locs)
845788

846789
@cython.boundscheck(False)
847-
def unique(self, ndarray[object] values, bint return_inverse=False):
790+
def _unique(self, ndarray[object] values, ObjectVector uniques,
791+
bint ignore_na=False, bint return_inverse=False,
792+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
793+
object na_value=None):
848794
cdef:
849-
Py_ssize_t i, idx, count = 0, n = len(values)
795+
Py_ssize_t i, idx, count = count_prior, n = len(values)
850796
int64_t[:] labels
851797
int ret = 0
852798
object val
853799
khiter_t k
854-
ObjectVector uniques = ObjectVector()
800+
bint use_na_value
855801

856802
if return_inverse:
857803
labels = np.empty(n, dtype=np.int64)
804+
use_na_value = na_value is not None
858805

859806
for i in range(n):
860807
val = values[i]
861808
hash(val)
809+
810+
if ignore_na and ((val != val or val is None)
811+
or (use_na_value and val == na_value)):
812+
labels[i] = na_sentinel
813+
continue
814+
862815
k = kh_get_pymap(self.table, <PyObject*>val)
863816
if return_inverse and k != self.table.n_buckets:
864817
# k falls into a previous bucket
@@ -877,42 +830,18 @@ cdef class PyObjectHashTable(HashTable):
877830
return uniques.to_array(), np.asarray(labels)
878831
return uniques.to_array()
879832

880-
@cython.boundscheck(False)
881-
def get_labels(self, ndarray[object] values, ObjectVector uniques,
882-
Py_ssize_t count_prior=0, int64_t na_sentinel=-1,
883-
object na_value=None):
884-
cdef:
885-
Py_ssize_t i, n = len(values)
886-
int64_t[:] labels
887-
Py_ssize_t idx, count = count_prior
888-
int ret = 0
889-
object val
890-
khiter_t k
891-
bint use_na_value
892-
893-
labels = np.empty(n, dtype=np.int64)
894-
use_na_value = na_value is not None
895-
896-
for i in range(n):
897-
val = values[i]
898-
hash(val)
899-
900-
if ((val != val or val is None) or
901-
(use_na_value and val == na_value)):
902-
labels[i] = na_sentinel
903-
continue
833+
def unique(self, ndarray[object] values, bint return_inverse=False):
834+
return self._unique(values, uniques=ObjectVector(), ignore_na=False,
835+
return_inverse=return_inverse)
904836

905-
k = kh_get_pymap(self.table, <PyObject*>val)
906-
if k != self.table.n_buckets:
907-
# k falls into a previous bucket
908-
idx = self.table.vals[k]
909-
labels[i] = idx
910-
else:
911-
# k hasn't been seen yet
912-
k = kh_put_pymap(self.table, <PyObject*>val, &ret)
913-
self.table.vals[k] = count
914-
uniques.append(val)
915-
labels[i] = count
916-
count += 1
837+
def factorize(self, ndarray[object] values):
838+
return self._unique(values, uniques=ObjectVector(), ignore_na=True,
839+
return_inverse=True)
917840

918-
return np.asarray(labels)
841+
def get_labels(self, ndarray[object] values, ObjectVector uniques,
842+
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
843+
object na_value=None):
844+
_, labels = self._unique(values, uniques, ignore_na=True,
845+
return_inverse=True, count_prior=count_prior,
846+
na_sentinel=na_sentinel, na_value=na_value)
847+
return labels

0 commit comments

Comments
 (0)