Skip to content

Commit 8d3b883

Browse files
committed
ENH: Allow keep='all' for nlargest/nsmallest
Closes pandas-devgh-16818. Closes pandas-devgh-18656.
1 parent 45e55af commit 8d3b883

File tree

5 files changed

+115
-21
lines changed

5 files changed

+115
-21
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Other Enhancements
2424
<https://pandas-gbq.readthedocs.io/en/latest/changelog.html#changelog-0-5-0>`__.
2525
(:issue:`21627`)
2626
- New method :meth:`HDFStore.walk` will recursively walk the group hierarchy of an HDF5 file (:issue:`10932`)
27+
- :func:`Series` / :func:`DataFrame` methods :func:`nlargest` / :func:`nsmallest` now accept the value 'all' for the `keep` argument. This keeps all ties for the nth largest/smallest value (:issue:`16818`)
2728
-
2829

2930
.. _whatsnew_0240.api_breaking:

pandas/core/algorithms.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,8 @@ def __init__(self, obj, n, keep):
10761076
self.n = n
10771077
self.keep = keep
10781078

1079-
if self.keep not in ('first', 'last'):
1080-
raise ValueError('keep must be either "first", "last"')
1079+
if self.keep not in ('first', 'last', 'all'):
1080+
raise ValueError('keep must be either "first", "last" or "all"')
10811081

10821082
def nlargest(self):
10831083
return self.compute('nlargest')
@@ -1148,7 +1148,11 @@ def compute(self, method):
11481148

11491149
kth_val = algos.kth_smallest(arr.copy(), n - 1)
11501150
ns, = np.nonzero(arr <= kth_val)
1151-
inds = ns[arr[ns].argsort(kind='mergesort')][:n]
1151+
inds = ns[arr[ns].argsort(kind='mergesort')]
1152+
1153+
if self.keep != 'all':
1154+
inds = inds[:n]
1155+
11521156
if self.keep == 'last':
11531157
# reverse indices
11541158
inds = narr - 1 - inds

pandas/core/frame.py

+79-18
Original file line numberDiff line numberDiff line change
@@ -4558,11 +4558,15 @@ def nlargest(self, n, columns, keep='first'):
45584558
Number of rows to return.
45594559
columns : label or list of labels
45604560
Column label(s) to order by.
4561-
keep : {'first', 'last'}, default 'first'
4561+
keep : {'first', 'last', 'all'}, default 'first'
45624562
Where there are duplicate values:
45634563
45644564
- `first` : prioritize the first occurrence(s)
45654565
- `last` : prioritize the last occurrence(s)
4566+
- ``all`` : do not drop any duplicates, even it means
4567+
selecting more than `n` items.
4568+
4569+
.. versionadded:: 0.24.0
45664570
45674571
Returns
45684572
-------
@@ -4585,42 +4589,51 @@ def nlargest(self, n, columns, keep='first'):
45854589
45864590
Examples
45874591
--------
4588-
>>> df = pd.DataFrame({'a': [1, 10, 8, 10, -1],
4589-
... 'b': list('abdce'),
4590-
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]})
4592+
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2],
4593+
... 'b': list('abdcef'),
4594+
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]})
45914595
>>> df
45924596
a b c
45934597
0 1 a 1.0
45944598
1 10 b 2.0
45954599
2 8 d NaN
4596-
3 10 c 3.0
4597-
4 -1 e 4.0
4600+
3 11 c 3.0
4601+
4 8 e 4.0
4602+
5 2 f 9.0
45984603
45994604
In the following example, we will use ``nlargest`` to select the three
46004605
rows having the largest values in column "a".
46014606
46024607
>>> df.nlargest(3, 'a')
46034608
a b c
4609+
3 11 c 3.0
46044610
1 10 b 2.0
4605-
3 10 c 3.0
46064611
2 8 d NaN
46074612
46084613
When using ``keep='last'``, ties are resolved in reverse order:
46094614
46104615
>>> df.nlargest(3, 'a', keep='last')
46114616
a b c
4612-
3 10 c 3.0
4617+
3 11 c 3.0
4618+
1 10 b 2.0
4619+
4 8 e 4.0
4620+
4621+
When using ``keep='all'``, all duplicate items are maintained
4622+
>>> df.nlargest(3, 'a', keep='all')
4623+
a b c
4624+
3 11 c 3.0
46134625
1 10 b 2.0
46144626
2 8 d NaN
4627+
4 8 e 4.0
46154628
46164629
To order by the largest values in column "a" and then "c", we can
46174630
specify multiple columns like in the next example.
46184631
46194632
>>> df.nlargest(3, ['a', 'c'])
46204633
a b c
4621-
3 10 c 3.0
4634+
4 8 e 4.0
4635+
3 11 c 3.0
46224636
1 10 b 2.0
4623-
2 8 d NaN
46244637
46254638
Attempting to use ``nlargest`` on non-numeric dtypes will raise a
46264639
``TypeError``:
@@ -4644,25 +4657,73 @@ def nsmallest(self, n, columns, keep='first'):
46444657
Number of items to retrieve
46454658
columns : list or str
46464659
Column name or names to order by
4647-
keep : {'first', 'last'}, default 'first'
4660+
keep : {'first', 'last', 'all'}, default 'first'
46484661
Where there are duplicate values:
46494662
- ``first`` : take the first occurrence.
46504663
- ``last`` : take the last occurrence.
4664+
- ``all`` : do not drop any duplicates, even it means
4665+
selecting more than `n` items.
4666+
4667+
.. versionadded:: 0.24.0
46514668
46524669
Returns
46534670
-------
46544671
DataFrame
46554672
46564673
Examples
46574674
--------
4658-
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, -1],
4659-
... 'b': list('abdce'),
4660-
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]})
4675+
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2],
4676+
... 'b': list('abdcef'),
4677+
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]})
4678+
>>> df
4679+
a b c
4680+
0 1 a 1.0
4681+
1 10 b 2.0
4682+
2 8 d NaN
4683+
3 11 c 3.0
4684+
4 8 e 4.0
4685+
5 2 f 9.0
4686+
4687+
In the following example, we will use ``nsmallest`` to select the
4688+
three rows having the smallest values in column "a".
4689+
46614690
>>> df.nsmallest(3, 'a')
4662-
a b c
4663-
4 -1 e 4
4664-
0 1 a 1
4665-
2 8 d NaN
4691+
a b c
4692+
0 1 a 1.0
4693+
5 2 f 9.0
4694+
2 8 d NaN
4695+
4696+
When using ``keep='last'``, ties are resolved in reverse order:
4697+
4698+
>>> df.nsmallest(3, 'a', keep='last')
4699+
a b c
4700+
0 1 a 1.0
4701+
5 2 f 9.0
4702+
4 8 e 4.0
4703+
4704+
When using ``keep='all'``, all duplicate items are maintained
4705+
>>> df.nsmallest(3, 'a', keep='all')
4706+
a b c
4707+
0 1 a 1.0
4708+
5 2 f 9.0
4709+
2 8 d NaN
4710+
4 8 e 4.0
4711+
4712+
To order by the largest values in column "a" and then "c", we can
4713+
specify multiple columns like in the next example.
4714+
4715+
>>> df.nsmallest(3, ['a', 'c'])
4716+
a b c
4717+
0 1 a 1.0
4718+
5 2 f 9.0
4719+
4 8 e 4.0
4720+
4721+
Attempting to use ``nsmallest`` on non-numeric dtypes will raise a
4722+
``TypeError``:
4723+
4724+
>>> df.nsmallest(3, 'b')
4725+
Traceback (most recent call last):
4726+
TypeError: Column 'b' has dtype object, cannot use method 'nsmallest'
46664727
"""
46674728
return algorithms.SelectNFrame(self,
46684729
n=n,

pandas/tests/frame/test_analytics.py

+16
Original file line numberDiff line numberDiff line change
@@ -2461,6 +2461,22 @@ def test_n_duplicate_index(self, df_duplicates, n, order):
24612461
expected = df.sort_values(order, ascending=False).head(n)
24622462
tm.assert_frame_equal(result, expected)
24632463

2464+
def test_keep_all_ties(self):
2465+
# see gh-16818
2466+
df = pd.DataFrame({'a': [5, 4, 4, 2, 3, 3, 3, 3],
2467+
'b': [10, 9, 8, 7, 5, 50, 10, 20]})
2468+
result = df.nlargest(4, 'a', keep='all')
2469+
expected = pd.DataFrame({'a': {0: 5, 1: 4, 2: 4, 4: 3,
2470+
5: 3, 6: 3, 7: 3},
2471+
'b': {0: 10, 1: 9, 2: 8, 4: 5,
2472+
5: 50, 6: 10, 7: 20}})
2473+
tm.assert_frame_equal(result, expected)
2474+
2475+
result = df.nsmallest(2, 'a', keep='all')
2476+
expected = pd.DataFrame({'a': {3: 2, 4: 3, 5: 3, 6: 3, 7: 3},
2477+
'b': {3: 7, 4: 5, 5: 50, 6: 10, 7: 20}})
2478+
tm.assert_frame_equal(result, expected)
2479+
24642480
def test_series_broadcasting(self):
24652481
# smoke test for numpy warnings
24662482
# GH 16378, GH 16306

pandas/tests/series/test_analytics.py

+12
Original file line numberDiff line numberDiff line change
@@ -2082,6 +2082,18 @@ def test_boundary_datetimelike(self, nselect_method, dtype):
20822082
vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
20832083
assert_check_nselect_boundary(vals, dtype, nselect_method)
20842084

2085+
def test_keep_all_ties(self):
2086+
# see gh-16818
2087+
s = Series([10, 9, 8, 7, 7, 7, 7, 6])
2088+
result = s.nlargest(4, keep='all')
2089+
expected = Series([10, 9, 8, 7, 7, 7, 7])
2090+
print(result, expected)
2091+
assert_series_equal(result, expected)
2092+
2093+
result = s.nsmallest(2, keep='all')
2094+
expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
2095+
assert_series_equal(result, expected)
2096+
20852097

20862098
class TestCategoricalSeriesAnalytics(object):
20872099

0 commit comments

Comments
 (0)