Skip to content

Commit 39cb4ca

Browse files
committed
ENH: #59237
1 parent 642d244 commit 39cb4ca

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

pandas/core/frame.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Iterator,
2121
Mapping,
2222
Sequence,
23+
Set as AbstractSet,
2324
)
2425
import functools
2526
from io import StringIO
@@ -6534,7 +6535,7 @@ def dropna(
65346535
@overload
65356536
def drop_duplicates(
65366537
self,
6537-
subset: Hashable | Sequence[Hashable] | None = ...,
6538+
subset: Hashable | Sequence[Hashable] | AbstractSet | None = ...,
65386539
*,
65396540
keep: DropKeep = ...,
65406541
inplace: Literal[True],
@@ -6544,7 +6545,7 @@ def drop_duplicates(
65446545
@overload
65456546
def drop_duplicates(
65466547
self,
6547-
subset: Hashable | Sequence[Hashable] | None = ...,
6548+
subset: Hashable | Sequence[Hashable] | AbstractSet | None = ...,
65486549
*,
65496550
keep: DropKeep = ...,
65506551
inplace: Literal[False] = ...,
@@ -6554,7 +6555,7 @@ def drop_duplicates(
65546555
@overload
65556556
def drop_duplicates(
65566557
self,
6557-
subset: Hashable | Sequence[Hashable] | None = ...,
6558+
subset: Hashable | Sequence[Hashable] | AbstractSet | None = ...,
65586559
*,
65596560
keep: DropKeep = ...,
65606561
inplace: bool = ...,
@@ -6563,7 +6564,7 @@ def drop_duplicates(
65636564

65646565
def drop_duplicates(
65656566
self,
6566-
subset: Hashable | Sequence[Hashable] | None = None,
6567+
subset: Hashable | Sequence[Hashable] | AbstractSet | None = None,
65676568
*,
65686569
keep: DropKeep = "first",
65696570
inplace: bool = False,
@@ -6667,7 +6668,7 @@ def drop_duplicates(
66676668

66686669
def duplicated(
66696670
self,
6670-
subset: Hashable | Sequence[Hashable] | None = None,
6671+
subset: Hashable | Sequence[Hashable] | AbstractSet | None = None,
66716672
keep: DropKeep = "first",
66726673
) -> Series:
66736674
"""
@@ -6792,8 +6793,13 @@ def f(vals) -> tuple[np.ndarray, int]:
67926793
raise KeyError(Index(diff))
67936794

67946795
if len(subset) == 1 and self.columns.is_unique:
6796+
# GH#59237 adding support for single element sets
6797+
if isinstance(subset, set):
6798+
elem = subset.pop()
6799+
else:
6800+
elem = subset[0]
67956801
# GH#45236 This is faster than get_group_index below
6796-
result = self[subset[0]].duplicated(keep)
6802+
result = self[elem].duplicated(keep)
67976803
result.name = None
67986804
else:
67996805
vals = (col.values for name, col in self.items() if name in subset)

pandas/tests/frame/methods/test_drop_duplicates.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,3 +476,41 @@ def test_drop_duplicates_non_boolean_ignore_index(arg):
476476
msg = '^For argument "ignore_index" expected type bool, received type .*.$'
477477
with pytest.raises(ValueError, match=msg):
478478
df.drop_duplicates(ignore_index=arg)
479+
480+
481+
def test_drop_duplicates_set():
482+
# GH#59237
483+
df = DataFrame(
484+
{
485+
"AAA": ["foo", "bar", "foo", "bar", "foo", "bar", "bar", "foo"],
486+
"B": ["one", "one", "two", "two", "two", "two", "one", "two"],
487+
"C": [1, 1, 2, 2, 2, 2, 1, 2],
488+
"D": range(8),
489+
}
490+
)
491+
# single column
492+
result = df.drop_duplicates({"AAA"})
493+
expected = df[:2]
494+
tm.assert_frame_equal(result, expected)
495+
496+
result = df.drop_duplicates({"AAA"}, keep="last")
497+
expected = df.loc[[6, 7]]
498+
tm.assert_frame_equal(result, expected)
499+
500+
result = df.drop_duplicates({"AAA"}, keep=False)
501+
expected = df.loc[[]]
502+
tm.assert_frame_equal(result, expected)
503+
assert len(result) == 0
504+
505+
# multi column
506+
expected = df.loc[[0, 1, 2, 3]]
507+
result = df.drop_duplicates({"AAA", "B"})
508+
tm.assert_frame_equal(result, expected)
509+
510+
result = df.drop_duplicates({"AAA", "B"}, keep="last")
511+
expected = df.loc[[0, 5, 6, 7]]
512+
tm.assert_frame_equal(result, expected)
513+
514+
result = df.drop_duplicates({"AAA", "B"}, keep=False)
515+
expected = df.loc[[0]]
516+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)