Skip to content

Commit 655d9f4

Browse files
authored
ENH: Support mask in unique (#48109)
1 parent b3f8ab4 commit 655d9f4

File tree

5 files changed

+111
-9
lines changed

5 files changed

+111
-9
lines changed

asv_bench/benchmarks/hash_functions.py

+15
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ def time_unique(self, exponent):
3939
pd.unique(self.a2)
4040

4141

42+
class Unique:
    """
    Benchmark ``pd.unique`` on nullable (masked) numeric Series.

    asv runs every ``time_*`` method once per entry in ``params`` and passes
    the current parameter positionally, so each method accepts it even though
    only ``setup`` actually uses it.
    """

    params = ["Int64", "Float64"]
    param_names = ["dtype"]

    def setup(self, dtype):
        """Build the two input Series for the given masked ``dtype``."""
        # Heavy duplication: every value (including pd.NA) occurs at least
        # three times.
        self.ser = pd.Series(([1, pd.NA, 2] + list(range(100_000))) * 3, dtype=dtype)
        # No duplication: 300_000 distinct values followed by a single pd.NA.
        self.ser_unique = pd.Series(list(range(300_000)) + [pd.NA], dtype=dtype)

    # The parameter is named ``dtype`` (not ``exponent``) so the signatures
    # agree with ``param_names`` and ``setup``.
    def time_unique_with_duplicates(self, dtype):
        """Time ``pd.unique`` when most of the input is duplicated."""
        pd.unique(self.ser)

    def time_unique(self, dtype):
        """Time ``pd.unique`` when the input is already unique."""
        pd.unique(self.ser_unique)
55+
56+
4257
class NumericSeriesIndexing:
4358

4459
params = [

pandas/_libs/hashtable_class_helper.pxi.in

+63-6
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ cdef class {{name}}HashTable(HashTable):
521521
def _unique(self, const {{dtype}}_t[:] values, {{name}}Vector uniques,
522522
Py_ssize_t count_prior=0, Py_ssize_t na_sentinel=-1,
523523
object na_value=None, bint ignore_na=False,
524-
object mask=None, bint return_inverse=False):
524+
object mask=None, bint return_inverse=False, bint use_result_mask=False):
525525
"""
526526
Calculate unique values and labels (no sorting!)
527527

@@ -551,13 +551,18 @@ cdef class {{name}}HashTable(HashTable):
551551
return_inverse : bool, default False
552552
Whether the mapping of the original array values to their location
553553
in the vector of uniques should be returned.
554+
use_result_mask : bool, default False
555+
Whether to create a result mask for the unique values. Not supported
556+
with return_inverse=True.
554557

555558
Returns
556559
-------
557560
uniques : ndarray[{{dtype}}]
558561
Unique values of input, not sorted
559562
labels : ndarray[intp_t] (if return_inverse=True)
560563
The labels from values to uniques
564+
result_mask : ndarray[bool] (if use_result_mask=True)
565+
The mask for the result values.
561566
"""
562567
cdef:
563568
Py_ssize_t i, idx, count = count_prior, n = len(values)
@@ -566,14 +571,24 @@ cdef class {{name}}HashTable(HashTable):
566571
{{c_type}} val, na_value2
567572
khiter_t k
568573
{{name}}VectorData *ud
569-
bint use_na_value, use_mask
574+
UInt8Vector result_mask
575+
UInt8VectorData *rmd
576+
bint use_na_value, use_mask, seen_na = False
570577
uint8_t[:] mask_values
571578

572579
if return_inverse:
573580
labels = np.empty(n, dtype=np.intp)
574581
ud = uniques.data
575582
use_na_value = na_value is not None
576583
use_mask = mask is not None
584+
if not use_mask and use_result_mask:
585+
raise NotImplementedError # pragma: no cover
586+
587+
if use_result_mask and return_inverse:
588+
raise NotImplementedError # pragma: no cover
589+
590+
result_mask = UInt8Vector()
591+
rmd = result_mask.data
577592

578593
if use_mask:
579594
mask_values = mask.view("uint8")
@@ -605,6 +620,27 @@ cdef class {{name}}HashTable(HashTable):
605620
# and replace the corresponding label with na_sentinel
606621
labels[i] = na_sentinel
607622
continue
623+
elif not ignore_na and use_result_mask:
624+
if mask_values[i]:
625+
if seen_na:
626+
continue
627+
628+
seen_na = True
629+
if needs_resize(ud):
630+
with gil:
631+
if uniques.external_view_exists:
632+
raise ValueError("external reference to "
633+
"uniques held, but "
634+
"Vector.resize() needed")
635+
uniques.resize()
636+
if result_mask.external_view_exists:
637+
raise ValueError("external reference to "
638+
"result_mask held, but "
639+
"Vector.resize() needed")
640+
result_mask.resize()
641+
append_data_{{dtype}}(ud, val)
642+
append_data_uint8(rmd, 1)
643+
continue
608644

609645
k = kh_get_{{dtype}}(self.table, val)
610646

@@ -619,7 +655,16 @@ cdef class {{name}}HashTable(HashTable):
619655
"uniques held, but "
620656
"Vector.resize() needed")
621657
uniques.resize()
658+
if use_result_mask:
659+
if result_mask.external_view_exists:
660+
raise ValueError("external reference to "
661+
"result_mask held, but "
662+
"Vector.resize() needed")
663+
result_mask.resize()
622664
append_data_{{dtype}}(ud, val)
665+
if use_result_mask:
666+
append_data_uint8(rmd, 0)
667+
623668
if return_inverse:
624669
self.table.vals[k] = count
625670
labels[i] = count
@@ -632,9 +677,11 @@ cdef class {{name}}HashTable(HashTable):
632677

633678
if return_inverse:
634679
return uniques.to_array(), labels.base # .base -> underlying ndarray
680+
if use_result_mask:
681+
return uniques.to_array(), result_mask.to_array()
635682
return uniques.to_array()
636683

637-
def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False):
684+
def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False, object mask=None):
638685
"""
639686
Calculate unique values and labels (no sorting!)
640687

@@ -645,17 +692,23 @@ cdef class {{name}}HashTable(HashTable):
645692
return_inverse : bool, default False
646693
Whether the mapping of the original array values to their location
647694
in the vector of uniques should be returned.
695+
mask : ndarray[bool], optional
696+
If not None, the mask is used as indicator for missing values
697+
(True = missing, False = valid) instead of `na_value` or the condition "val != val".
648698

649699
Returns
650700
-------
651701
uniques : ndarray[{{dtype}}]
652702
Unique values of input, not sorted
653703
labels : ndarray[intp_t] (if return_inverse)
654704
The labels from values to uniques
705+
result_mask : ndarray[bool] (if mask is given)
706+
The mask for the result values.
655707
"""
656708
uniques = {{name}}Vector()
709+
use_result_mask = True if mask is not None else False
657710
return self._unique(values, uniques, ignore_na=False,
658-
return_inverse=return_inverse)
711+
return_inverse=return_inverse, mask=mask, use_result_mask=use_result_mask)
659712

660713
def factorize(self, const {{dtype}}_t[:] values, Py_ssize_t na_sentinel=-1,
661714
object na_value=None, object mask=None):
@@ -1013,7 +1066,7 @@ cdef class StringHashTable(HashTable):
10131066
return uniques.to_array(), labels.base # .base -> underlying ndarray
10141067
return uniques.to_array()
10151068

1016-
def unique(self, ndarray[object] values, bint return_inverse=False):
1069+
def unique(self, ndarray[object] values, bint return_inverse=False, object mask=None):
10171070
"""
10181071
Calculate unique values and labels (no sorting!)
10191072

@@ -1024,6 +1077,8 @@ cdef class StringHashTable(HashTable):
10241077
return_inverse : bool, default False
10251078
Whether the mapping of the original array values to their location
10261079
in the vector of uniques should be returned.
1080+
mask : ndarray[bool], optional
1081+
Not yet implemented for StringHashTable
10271082

10281083
Returns
10291084
-------
@@ -1266,7 +1321,7 @@ cdef class PyObjectHashTable(HashTable):
12661321
return uniques.to_array(), labels.base # .base -> underlying ndarray
12671322
return uniques.to_array()
12681323

1269-
def unique(self, ndarray[object] values, bint return_inverse=False):
1324+
def unique(self, ndarray[object] values, bint return_inverse=False, object mask=None):
12701325
"""
12711326
Calculate unique values and labels (no sorting!)
12721327

@@ -1277,6 +1332,8 @@ cdef class PyObjectHashTable(HashTable):
12771332
return_inverse : bool, default False
12781333
Whether the mapping of the original array values to their location
12791334
in the vector of uniques should be returned.
1335+
mask : ndarray[bool], optional
1336+
Not yet implemented for PyObjectHashTable
12801337

12811338
Returns
12821339
-------

pandas/core/algorithms.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,11 @@ def unique(values):
404404
>>> pd.unique([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")])
405405
array([('a', 'b'), ('b', 'a'), ('a', 'c')], dtype=object)
406406
"""
407+
return unique_with_mask(values)
408+
409+
410+
def unique_with_mask(values, mask: npt.NDArray[np.bool_] | None = None):
411+
"""See algorithms.unique for docs. Takes a mask for masked arrays."""
407412
values = _ensure_arraylike(values)
408413

409414
if is_extension_array_dtype(values.dtype):
@@ -414,9 +419,16 @@ def unique(values):
414419
htable, values = _get_hashtable_algo(values)
415420

416421
table = htable(len(values))
417-
uniques = table.unique(values)
418-
uniques = _reconstruct_data(uniques, original.dtype, original)
419-
return uniques
422+
if mask is None:
423+
uniques = table.unique(values)
424+
uniques = _reconstruct_data(uniques, original.dtype, original)
425+
return uniques
426+
427+
else:
428+
uniques, mask = table.unique(values, mask=mask)
429+
uniques = _reconstruct_data(uniques, original.dtype, original)
430+
assert mask is not None # for mypy
431+
return uniques, mask.astype("bool")
420432

421433

422434
unique1d = unique

pandas/core/arrays/masked.py

+11
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,17 @@ def copy(self: BaseMaskedArrayT) -> BaseMaskedArrayT:
851851
mask = mask.copy()
852852
return type(self)(data, mask, copy=False)
853853

854+
def unique(self: BaseMaskedArrayT) -> BaseMaskedArrayT:
855+
"""
856+
Compute the BaseMaskedArray of unique values.
857+
Delegates to ``algos.unique_with_mask``, passing this array's ``_data``
and ``_mask``, and wraps the returned values and mask in a new array of
the same type (constructed with ``copy=False``).
858+
Returns
859+
-------
860+
uniques : BaseMaskedArray
861+
"""
862+
uniques, mask = algos.unique_with_mask(self._data, self._mask)
863+
return type(self)(uniques, mask, copy=False)
864+
854865
@doc(ExtensionArray.searchsorted)
855866
def searchsorted(
856867
self,

pandas/tests/test_algos.py

+7
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,13 @@ def test_do_not_mangle_na_values(self, unique_nulls_fixture, unique_nulls_fixtur
834834
assert a[0] is unique_nulls_fixture
835835
assert a[1] is unique_nulls_fixture2
836836

837+
def test_unique_masked(self, any_numeric_ea_dtype):
838+
# GH#48019
# pd.unique on a masked numeric Series should return an extension array of
# the same dtype, keep a single pd.NA, and preserve first-seen order.
839+
ser = Series([1, pd.NA, 2] * 3, dtype=any_numeric_ea_dtype)
840+
result = pd.unique(ser)
841+
expected = pd.array([1, pd.NA, 2], dtype=any_numeric_ea_dtype)
842+
tm.assert_extension_array_equal(result, expected)
843+
837844

838845
class TestIsin:
839846
def test_invalid(self):

0 commit comments

Comments
 (0)