Skip to content

Commit 94044c8

Browse files
authored
ENH: Support mask in duplicated algorithm (#48150)
* ENH: Support mask in duplicated algorithm * Fix mypy * Add tests * Improve test * Fix bug
1 parent 8f21b97 commit 94044c8

File tree

5 files changed

+114
-31
lines changed

5 files changed

+114
-31
lines changed

asv_bench/benchmarks/algorithms.py

+23
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ def time_duplicated(self, unique, keep, dtype):
9595
self.idx.duplicated(keep=keep)
9696

9797

98+
class DuplicatedMaskedArray:
99+
100+
params = [
101+
[True, False],
102+
["first", "last", False],
103+
["Int64", "Float64"],
104+
]
105+
param_names = ["unique", "keep", "dtype"]
106+
107+
def setup(self, unique, keep, dtype):
108+
N = 10**5
109+
data = pd.Series(np.arange(N), dtype=dtype)
110+
data[list(range(1, N, 100))] = pd.NA
111+
if not unique:
112+
data = data.repeat(5)
113+
self.ser = data
114+
# cache is_unique
115+
self.ser.is_unique
116+
117+
def time_duplicated(self, unique, keep, dtype):
118+
self.ser.duplicated(keep=keep)
119+
120+
98121
class Hashing:
99122
def setup_cache(self):
100123
N = 10**5

pandas/_libs/hashtable.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ class IntpHashTable(HashTable): ...
183183
def duplicated(
184184
values: np.ndarray,
185185
keep: Literal["last", "first", False] = ...,
186+
mask: npt.NDArray[np.bool_] | None = ...,
186187
) -> npt.NDArray[np.bool_]: ...
187188
def mode(
188189
values: np.ndarray, dropna: bool, mask: npt.NDArray[np.bool_] | None = ...

pandas/_libs/hashtable_func_helper.pxi.in

+60-31
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8
118118
@cython.wraparound(False)
119119
@cython.boundscheck(False)
120120
{{if dtype == 'object'}}
121-
cdef duplicated_{{dtype}}(ndarray[{{dtype}}] values, object keep='first'):
121+
cdef duplicated_{{dtype}}(ndarray[{{dtype}}] values, object keep='first', const uint8_t[:] mask=None):
122122
{{else}}
123-
cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
123+
cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first', const uint8_t[:] mask=None):
124124
{{endif}}
125125
cdef:
126126
int ret = 0
@@ -129,10 +129,12 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
129129
{{else}}
130130
PyObject* value
131131
{{endif}}
132-
Py_ssize_t i, n = len(values)
132+
Py_ssize_t i, n = len(values), first_na = -1
133133
khiter_t k
134134
kh_{{ttype}}_t *table = kh_init_{{ttype}}()
135135
ndarray[uint8_t, ndim=1, cast=True] out = np.empty(n, dtype='bool')
136+
bint seen_na = False, uses_mask = mask is not None
137+
bint seen_multiple_na = False
136138

137139
kh_resize_{{ttype}}(table, min(kh_needed_n_buckets(n), SIZE_HINT_LIMIT))
138140

@@ -147,9 +149,16 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
147149
{{endif}}
148150
for i in range(n - 1, -1, -1):
149151
# equivalent: range(n)[::-1], which cython doesn't like in nogil
150-
value = {{to_c_type}}(values[i])
151-
kh_put_{{ttype}}(table, value, &ret)
152-
out[i] = ret == 0
152+
if uses_mask and mask[i]:
153+
if seen_na:
154+
out[i] = True
155+
else:
156+
out[i] = False
157+
seen_na = True
158+
else:
159+
value = {{to_c_type}}(values[i])
160+
kh_put_{{ttype}}(table, value, &ret)
161+
out[i] = ret == 0
153162

154163
elif keep == 'first':
155164
{{if dtype == 'object'}}
@@ -158,9 +167,16 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
158167
with nogil:
159168
{{endif}}
160169
for i in range(n):
161-
value = {{to_c_type}}(values[i])
162-
kh_put_{{ttype}}(table, value, &ret)
163-
out[i] = ret == 0
170+
if uses_mask and mask[i]:
171+
if seen_na:
172+
out[i] = True
173+
else:
174+
out[i] = False
175+
seen_na = True
176+
else:
177+
value = {{to_c_type}}(values[i])
178+
kh_put_{{ttype}}(table, value, &ret)
179+
out[i] = ret == 0
164180

165181
else:
166182
{{if dtype == 'object'}}
@@ -169,15 +185,28 @@ cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first'):
169185
with nogil:
170186
{{endif}}
171187
for i in range(n):
172-
value = {{to_c_type}}(values[i])
173-
k = kh_get_{{ttype}}(table, value)
174-
if k != table.n_buckets:
175-
out[table.vals[k]] = 1
176-
out[i] = 1
188+
if uses_mask and mask[i]:
189+
if not seen_na:
190+
first_na = i
191+
seen_na = True
192+
out[i] = 0
193+
elif not seen_multiple_na:
194+
out[i] = 1
195+
out[first_na] = 1
196+
seen_multiple_na = True
197+
else:
198+
out[i] = 1
199+
177200
else:
178-
k = kh_put_{{ttype}}(table, value, &ret)
179-
table.vals[k] = i
180-
out[i] = 0
201+
value = {{to_c_type}}(values[i])
202+
k = kh_get_{{ttype}}(table, value)
203+
if k != table.n_buckets:
204+
out[table.vals[k]] = 1
205+
out[i] = 1
206+
else:
207+
k = kh_put_{{ttype}}(table, value, &ret)
208+
table.vals[k] = i
209+
out[i] = 0
181210

182211
kh_destroy_{{ttype}}(table)
183212
return out
@@ -301,37 +330,37 @@ cpdef value_count(ndarray[htfunc_t] values, bint dropna, const uint8_t[:] mask=N
301330
raise TypeError(values.dtype)
302331

303332

304-
cpdef duplicated(ndarray[htfunc_t] values, object keep="first"):
333+
cpdef duplicated(ndarray[htfunc_t] values, object keep="first", const uint8_t[:] mask=None):
305334
if htfunc_t is object:
306-
return duplicated_object(values, keep)
335+
return duplicated_object(values, keep, mask=mask)
307336

308337
elif htfunc_t is int8_t:
309-
return duplicated_int8(values, keep)
338+
return duplicated_int8(values, keep, mask=mask)
310339
elif htfunc_t is int16_t:
311-
return duplicated_int16(values, keep)
340+
return duplicated_int16(values, keep, mask=mask)
312341
elif htfunc_t is int32_t:
313-
return duplicated_int32(values, keep)
342+
return duplicated_int32(values, keep, mask=mask)
314343
elif htfunc_t is int64_t:
315-
return duplicated_int64(values, keep)
344+
return duplicated_int64(values, keep, mask=mask)
316345

317346
elif htfunc_t is uint8_t:
318-
return duplicated_uint8(values, keep)
347+
return duplicated_uint8(values, keep, mask=mask)
319348
elif htfunc_t is uint16_t:
320-
return duplicated_uint16(values, keep)
349+
return duplicated_uint16(values, keep, mask=mask)
321350
elif htfunc_t is uint32_t:
322-
return duplicated_uint32(values, keep)
351+
return duplicated_uint32(values, keep, mask=mask)
323352
elif htfunc_t is uint64_t:
324-
return duplicated_uint64(values, keep)
353+
return duplicated_uint64(values, keep, mask=mask)
325354

326355
elif htfunc_t is float64_t:
327-
return duplicated_float64(values, keep)
356+
return duplicated_float64(values, keep, mask=mask)
328357
elif htfunc_t is float32_t:
329-
return duplicated_float32(values, keep)
358+
return duplicated_float32(values, keep, mask=mask)
330359

331360
elif htfunc_t is complex128_t:
332-
return duplicated_complex128(values, keep)
361+
return duplicated_complex128(values, keep, mask=mask)
333362
elif htfunc_t is complex64_t:
334-
return duplicated_complex64(values, keep)
363+
return duplicated_complex64(values, keep, mask=mask)
335364

336365
else:
337366
raise TypeError(values.dtype)

pandas/core/algorithms.py

+4
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,10 @@ def duplicated(
10431043
-------
10441044
duplicated : ndarray[bool]
10451045
"""
1046+
if hasattr(values, "dtype") and isinstance(values.dtype, BaseMaskedDtype):
1047+
values = cast("BaseMaskedArray", values)
1048+
return htable.duplicated(values._data, keep=keep, mask=values._mask)
1049+
10461050
values = _ensure_data(values)
10471051
return htable.duplicated(values, keep=keep)
10481052

pandas/tests/series/methods/test_duplicated.py

+26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from pandas import (
5+
NA,
56
Categorical,
67
Series,
78
)
@@ -50,3 +51,28 @@ def test_duplicated_categorical_bool_na(nulls_fixture):
5051
result = ser.duplicated()
5152
expected = Series([False, False, True, True, False])
5253
tm.assert_series_equal(result, expected)
54+
55+
56+
@pytest.mark.parametrize(
57+
"keep, vals",
58+
[
59+
("last", [True, True, False]),
60+
("first", [False, True, True]),
61+
(False, [True, True, True]),
62+
],
63+
)
64+
def test_duplicated_mask(keep, vals):
65+
# GH#48150
66+
ser = Series([1, 2, NA, NA, NA], dtype="Int64")
67+
result = ser.duplicated(keep=keep)
68+
expected = Series([False, False] + vals)
69+
tm.assert_series_equal(result, expected)
70+
71+
72+
@pytest.mark.parametrize("keep", ["last", "first", False])
73+
def test_duplicated_mask_no_duplicated_na(keep):
74+
# GH#48150
75+
ser = Series([1, 2, NA], dtype="Int64")
76+
result = ser.duplicated(keep=keep)
77+
expected = Series([False, False, False])
78+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)