Skip to content

Commit 0801b8c

Browse files
authored
ENH: Allow keep='all' for nlargest/nsmallest (#21650)
Closes gh-16818. Closes gh-18656.
1 parent e0f978d commit 0801b8c

File tree

7 files changed

+120
-23
lines changed

7 files changed

+120
-23
lines changed

asv_bench/benchmarks/frame_methods.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def time_info(self):
501501
class NSort(object):
502502

503503
goal_time = 0.2
504-
params = ['first', 'last']
504+
params = ['first', 'last', 'all']
505505
param_names = ['keep']
506506

507507
def setup(self, keep):

asv_bench/benchmarks/series_methods.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def time_isin(self, dtypes):
4141
class NSort(object):
4242

4343
goal_time = 0.2
44-
params = ['last', 'first']
44+
params = ['first', 'last', 'all']
4545
param_names = ['keep']
4646

4747
def setup(self, keep):

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Other Enhancements
2424
<https://pandas-gbq.readthedocs.io/en/latest/changelog.html#changelog-0-5-0>`__.
2525
(:issue:`21627`)
2626
- New method :meth:`HDFStore.walk` will recursively walk the group hierarchy of an HDF5 file (:issue:`10932`)
27+
- :meth:`Series.nlargest`, :meth:`Series.nsmallest`, :meth:`DataFrame.nlargest`, and :meth:`DataFrame.nsmallest` now accept the value ``"all"`` for the ``keep`` argument. This keeps all ties for the nth largest/smallest value (:issue:`16818`)
2728
-
2829

2930
.. _whatsnew_0240.api_breaking:

pandas/core/algorithms.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,8 @@ def __init__(self, obj, n, keep):
10761076
self.n = n
10771077
self.keep = keep
10781078

1079-
if self.keep not in ('first', 'last'):
1080-
raise ValueError('keep must be either "first", "last"')
1079+
if self.keep not in ('first', 'last', 'all'):
1080+
raise ValueError('keep must be either "first", "last" or "all"')
10811081

10821082
def nlargest(self):
10831083
return self.compute('nlargest')
@@ -1148,7 +1148,11 @@ def compute(self, method):
11481148

11491149
kth_val = algos.kth_smallest(arr.copy(), n - 1)
11501150
ns, = np.nonzero(arr <= kth_val)
1151-
inds = ns[arr[ns].argsort(kind='mergesort')][:n]
1151+
inds = ns[arr[ns].argsort(kind='mergesort')]
1152+
1153+
if self.keep != 'all':
1154+
inds = inds[:n]
1155+
11521156
if self.keep == 'last':
11531157
# reverse indices
11541158
inds = narr - 1 - inds

pandas/core/frame.py

+83-18
Original file line numberDiff line numberDiff line change
@@ -4559,11 +4559,15 @@ def nlargest(self, n, columns, keep='first'):
45594559
Number of rows to return.
45604560
columns : label or list of labels
45614561
Column label(s) to order by.
4562-
keep : {'first', 'last'}, default 'first'
4562+
keep : {'first', 'last', 'all'}, default 'first'
45634563
Where there are duplicate values:
45644564
45654565
- `first` : prioritize the first occurrence(s)
45664566
- `last` : prioritize the last occurrence(s)
4567+
- ``all`` : do not drop any duplicates, even if it means
4568+
selecting more than `n` items.
4569+
4570+
.. versionadded:: 0.24.0
45674571
45684572
Returns
45694573
-------
@@ -4586,47 +4590,58 @@ def nlargest(self, n, columns, keep='first'):
45864590
45874591
Examples
45884592
--------
4589-
>>> df = pd.DataFrame({'a': [1, 10, 8, 10, -1],
4590-
... 'b': list('abdce'),
4591-
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]})
4593+
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2],
4594+
... 'b': list('abdcef'),
4595+
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]})
45924596
>>> df
45934597
a b c
45944598
0 1 a 1.0
45954599
1 10 b 2.0
45964600
2 8 d NaN
4597-
3 10 c 3.0
4598-
4 -1 e 4.0
4601+
3 11 c 3.0
4602+
4 8 e 4.0
4603+
5 2 f 9.0
45994604
46004605
In the following example, we will use ``nlargest`` to select the three
46014606
rows having the largest values in column "a".
46024607
46034608
>>> df.nlargest(3, 'a')
46044609
a b c
4610+
3 11 c 3.0
46054611
1 10 b 2.0
4606-
3 10 c 3.0
46074612
2 8 d NaN
46084613
46094614
When using ``keep='last'``, ties are resolved in reverse order:
46104615
46114616
>>> df.nlargest(3, 'a', keep='last')
46124617
a b c
4613-
3 10 c 3.0
4618+
3 11 c 3.0
4619+
1 10 b 2.0
4620+
4 8 e 4.0
4621+
4622+
When using ``keep='all'``, all duplicate items are maintained:
4623+
4624+
>>> df.nlargest(3, 'a', keep='all')
4625+
a b c
4626+
3 11 c 3.0
46144627
1 10 b 2.0
46154628
2 8 d NaN
4629+
4 8 e 4.0
46164630
46174631
To order by the largest values in column "a" and then "c", we can
46184632
specify multiple columns like in the next example.
46194633
46204634
>>> df.nlargest(3, ['a', 'c'])
46214635
a b c
4622-
3 10 c 3.0
4636+
4 8 e 4.0
4637+
3 11 c 3.0
46234638
1 10 b 2.0
4624-
2 8 d NaN
46254639
46264640
Attempting to use ``nlargest`` on non-numeric dtypes will raise a
46274641
``TypeError``:
46284642
46294643
>>> df.nlargest(3, 'b')
4644+
46304645
Traceback (most recent call last):
46314646
TypeError: Column 'b' has dtype object, cannot use method 'nlargest'
46324647
"""
@@ -4645,25 +4660,75 @@ def nsmallest(self, n, columns, keep='first'):
46454660
Number of items to retrieve
46464661
columns : list or str
46474662
Column name or names to order by
4648-
keep : {'first', 'last'}, default 'first'
4663+
keep : {'first', 'last', 'all'}, default 'first'
46494664
Where there are duplicate values:
46504665
- ``first`` : take the first occurrence.
46514666
- ``last`` : take the last occurrence.
4667+
- ``all`` : do not drop any duplicates, even if it means
4668+
selecting more than `n` items.
4669+
4670+
.. versionadded:: 0.24.0
46524671
46534672
Returns
46544673
-------
46554674
DataFrame
46564675
46574676
Examples
46584677
--------
4659-
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, -1],
4660-
... 'b': list('abdce'),
4661-
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]})
4678+
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2],
4679+
... 'b': list('abdcef'),
4680+
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]})
4681+
>>> df
4682+
a b c
4683+
0 1 a 1.0
4684+
1 10 b 2.0
4685+
2 8 d NaN
4686+
3 11 c 3.0
4687+
4 8 e 4.0
4688+
5 2 f 9.0
4689+
4690+
In the following example, we will use ``nsmallest`` to select the
4691+
three rows having the smallest values in column "a".
4692+
46624693
>>> df.nsmallest(3, 'a')
4663-
a b c
4664-
4 -1 e 4
4665-
0 1 a 1
4666-
2 8 d NaN
4694+
a b c
4695+
0 1 a 1.0
4696+
5 2 f 9.0
4697+
2 8 d NaN
4698+
4699+
When using ``keep='last'``, ties are resolved in reverse order:
4700+
4701+
>>> df.nsmallest(3, 'a', keep='last')
4702+
a b c
4703+
0 1 a 1.0
4704+
5 2 f 9.0
4705+
4 8 e 4.0
4706+
4707+
When using ``keep='all'``, all duplicate items are maintained:
4708+
4709+
>>> df.nsmallest(3, 'a', keep='all')
4710+
a b c
4711+
0 1 a 1.0
4712+
5 2 f 9.0
4713+
2 8 d NaN
4714+
4 8 e 4.0
4715+
4716+
To order by the smallest values in column "a" and then "c", we can
4717+
specify multiple columns like in the next example.
4718+
4719+
>>> df.nsmallest(3, ['a', 'c'])
4720+
a b c
4721+
0 1 a 1.0
4722+
5 2 f 9.0
4723+
4 8 e 4.0
4724+
4725+
Attempting to use ``nsmallest`` on non-numeric dtypes will raise a
4726+
``TypeError``:
4727+
4728+
>>> df.nsmallest(3, 'b')
4729+
4730+
Traceback (most recent call last):
4731+
TypeError: Column 'b' has dtype object, cannot use method 'nsmallest'
46674732
"""
46684733
return algorithms.SelectNFrame(self,
46694734
n=n,

pandas/tests/frame/test_analytics.py

+16
Original file line numberDiff line numberDiff line change
@@ -2461,6 +2461,22 @@ def test_n_duplicate_index(self, df_duplicates, n, order):
24612461
expected = df.sort_values(order, ascending=False).head(n)
24622462
tm.assert_frame_equal(result, expected)
24632463

2464+
def test_duplicate_keep_all_ties(self):
2465+
# see gh-16818
2466+
df = pd.DataFrame({'a': [5, 4, 4, 2, 3, 3, 3, 3],
2467+
'b': [10, 9, 8, 7, 5, 50, 10, 20]})
2468+
result = df.nlargest(4, 'a', keep='all')
2469+
expected = pd.DataFrame({'a': {0: 5, 1: 4, 2: 4, 4: 3,
2470+
5: 3, 6: 3, 7: 3},
2471+
'b': {0: 10, 1: 9, 2: 8, 4: 5,
2472+
5: 50, 6: 10, 7: 20}})
2473+
tm.assert_frame_equal(result, expected)
2474+
2475+
result = df.nsmallest(2, 'a', keep='all')
2476+
expected = pd.DataFrame({'a': {3: 2, 4: 3, 5: 3, 6: 3, 7: 3},
2477+
'b': {3: 7, 4: 5, 5: 50, 6: 10, 7: 20}})
2478+
tm.assert_frame_equal(result, expected)
2479+
24642480
def test_series_broadcasting(self):
24652481
# smoke test for numpy warnings
24662482
# GH 16378, GH 16306

pandas/tests/series/test_analytics.py

+11
Original file line numberDiff line numberDiff line change
@@ -2082,6 +2082,17 @@ def test_boundary_datetimelike(self, nselect_method, dtype):
20822082
vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
20832083
assert_check_nselect_boundary(vals, dtype, nselect_method)
20842084

2085+
def test_duplicate_keep_all_ties(self):
2086+
# see gh-16818
2087+
s = Series([10, 9, 8, 7, 7, 7, 7, 6])
2088+
result = s.nlargest(4, keep='all')
2089+
expected = Series([10, 9, 8, 7, 7, 7, 7])
2090+
assert_series_equal(result, expected)
2091+
2092+
result = s.nsmallest(2, keep='all')
2093+
expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
2094+
assert_series_equal(result, expected)
2095+
20852096

20862097
class TestCategoricalSeriesAnalytics(object):
20872098

0 commit comments

Comments
 (0)