Skip to content

Commit 03aa68d

Browse files
committed
BUG: nlargest/nsmallest gave wrong result (#22752)
When asking for the n largest/smallest rows in a DataFrame, nlargest/nsmallest sometimes returned the wrong rows: ties in the earlier columns were not always broken correctly using the later columns.
1 parent 1a12c41 commit 03aa68d

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ Other
804804
- :meth:`~pandas.io.formats.style.Styler.background_gradient` now takes a ``text_color_threshold`` parameter to automatically lighten the text color based on the luminance of the background color. This improves readability with dark background colors without the need to limit the background colormap range. (:issue:`21258`)
805805
- Require at least 0.28.2 version of ``cython`` to support read-only memoryviews (:issue:`21688`)
806806
- :meth:`~pandas.io.formats.style.Styler.background_gradient` now also supports tablewise application (in addition to rowwise and columnwise) with ``axis=None`` (:issue:`15204`)
807+
- :meth:`DataFrame.nlargest` and :meth:`DataFrame.nsmallest` now return the correct n values when ``keep != 'all'``, including when rows are tied on the first columns (:issue:`22752`)
807808
- :meth:`~pandas.io.formats.style.Styler.bar` now also supports tablewise application (in addition to rowwise and columnwise) with ``axis=None`` and setting clipping range with ``vmin`` and ``vmax`` (:issue:`21548` and :issue:`21526`). ``NaN`` values are also handled properly.
808809
- Logical operations ``&, |, ^`` between :class:`Series` and :class:`Index` will no longer raise ``ValueError`` (:issue:`22092`)
809810
-

pandas/core/algorithms.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -1221,27 +1221,36 @@ def get_indexer(current_indexer, other_indexer):
12211221
# and break
12221222
# Otherwise we must save the index of the non duplicated values
12231223
# and set the next cur_frame to cur_frame filtered on all
1224-
# duplcicated values (#GH15297)
1224+
# duplicated values (#GH15297)
12251225
series = cur_frame[column]
1226-
values = getattr(series, method)(cur_n, keep=self.keep)
12271226
is_last_column = len(columns) - 1 == i
1228-
if is_last_column or values.nunique() == series.isin(values).sum():
1227+
values = getattr(series, method)(
1228+
cur_n,
1229+
keep=self.keep if is_last_column else 'all')
12291230

1230-
# Last column in columns or values are unique in
1231-
# series => values
1232-
# is all that matters
1231+
# Stop if we are at the last column, or if values holds no
1232+
# more items than we still need.
1233+
if is_last_column or len(values) <= cur_n:
12331234
indexer = get_indexer(indexer, values.index)
12341235
break
12351236

1236-
duplicated_filter = series.duplicated(keep=False)
1237-
duplicated = values[duplicated_filter]
1238-
non_duplicated = values[~duplicated_filter]
1239-
indexer = get_indexer(indexer, non_duplicated.index)
1240-
1241-
# Must set cur frame to include all duplicated values
1242-
# to consider for the next column, we also can reduce
1243-
# cur_n by the current length of the indexer
1244-
cur_frame = cur_frame[series.isin(duplicated)]
1237+
# Find the border value. (Everything strictly above/below it is
1238+
# good to go.)
1239+
if method == 'nlargest':
1240+
border_value = values.min()
1241+
else:
1242+
border_value = values.max()
1243+
1244+
border_values = values[values == border_value]
1245+
# Every value other than the border value
1246+
# is strictly larger/smaller than it, so
1247+
# those rows are safe to keep.
1248+
safe_values = values[values != border_value]
1249+
indexer = get_indexer(indexer, safe_values.index)
1250+
1251+
# Now we have to break the tie among the
1252+
# border value rows.
1253+
cur_frame = cur_frame[border_values]
12451254
cur_n = n - len(indexer)
12461255

12471256
frame = frame.take(indexer)

pandas/tests/series/test_analytics.py

+23
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,29 @@ def test_duplicate_keep_all_ties(self):
19961996
expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
19971997
assert_series_equal(result, expected)
19981998

1999+
@pytest.mark.parametrize('method,expected', [
2000+
('nlargest',
2001+
pd.DataFrame({'a': [2, 2, 2, 1], 'b': [3, 2, 1, 3]},
2002+
index=[2, 1, 0, 3])),
2003+
('nsmallest',
2004+
pd.DataFrame({'a': [2, 1, 1, 1], 'b': [1, 3, 2, 1]},
2005+
index=[0, 3, 4, 5]))])
2006+
def test_duplicates_on_starter_columns(self, method, expected):
2007+
# regression test for #22752
2008+
2009+
df = pd.DataFrame({
2010+
'a': [2, 2, 2, 1, 1, 1],
2011+
'b': [1, 2, 3, 3, 2, 1]
2012+
})
2013+
2014+
result = getattr(df, method)(
2015+
4, columns=['a', 'b']
2016+
).sort_values(
2017+
['a', 'b'], ascending=False
2018+
)
2019+
2020+
assert_frame_equal(result, expected)
2021+
19992022

20002023
class TestCategoricalSeriesAnalytics(object):
20012024

0 commit comments

Comments
 (0)