Skip to content

Commit c09c5a2

Browse files
committed
BUG: nlargest/nsmallest gave wrong result (#22752)
When asking for the n largest/smallest rows in a DataFrame, nlargest/nsmallest sometimes returned the wrong rows when ties in the earlier columns had to be broken using the later columns.
1 parent 1a12c41 commit c09c5a2

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ Other
804804
- :meth:`~pandas.io.formats.style.Styler.background_gradient` now takes a ``text_color_threshold`` parameter to automatically lighten the text color based on the luminance of the background color. This improves readability with dark background colors without the need to limit the background colormap range. (:issue:`21258`)
805805
- Require at least 0.28.2 version of ``cython`` to support read-only memoryviews (:issue:`21688`)
806806
- :meth:`~pandas.io.formats.style.Styler.background_gradient` now also supports tablewise application (in addition to rowwise and columnwise) with ``axis=None`` (:issue:`15204`)
807+
- :meth:`DataFrame.nlargest` and :meth:`DataFrame.nsmallest` now return the correct n values when ``keep != 'all'``, even when rows are tied on the first columns (:issue:`22752`)
807808
- :meth:`~pandas.io.formats.style.Styler.bar` now also supports tablewise application (in addition to rowwise and columnwise) with ``axis=None`` and setting clipping range with ``vmin`` and ``vmax`` (:issue:`21548` and :issue:`21526`). ``NaN`` values are also handled properly.
808809
- Logical operations ``&, |, ^`` between :class:`Series` and :class:`Index` will no longer raise ``ValueError`` (:issue:`22092`)
809810
-

pandas/core/algorithms.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -1221,27 +1221,36 @@ def get_indexer(current_indexer, other_indexer):
12211221
# and break
12221222
# Otherwise we must save the index of the non duplicated values
12231223
# and set the next cur_frame to cur_frame filtered on all
1224-
# duplcicated values (#GH15297)
1224+
# duplicated values (#GH15297)
12251225
series = cur_frame[column]
12261226
values = getattr(series, method)(cur_n, keep=self.keep)
12271227
is_last_column = len(columns) - 1 == i
1228-
if is_last_column or values.nunique() == series.isin(values).sum():
12291228

1230-
# Last column in columns or values are unique in
1231-
# series => values
1232-
# is all that matters
1229+
if is_last_column:
1230+
# Last column reached: if two values are identical
1231+
# we can't separate them any further
12331232
indexer = get_indexer(indexer, values.index)
12341233
break
12351234

1236-
duplicated_filter = series.duplicated(keep=False)
1237-
duplicated = values[duplicated_filter]
1238-
non_duplicated = values[~duplicated_filter]
1239-
indexer = get_indexer(indexer, non_duplicated.index)
1235+
# Find the value at the border
1236+
if method == 'nlargest':
1237+
border_value = values.min()
1238+
else:
1239+
border_value = values.max()
1240+
1241+
border_values = series == border_value
1242+
if border_values.count() == 1:
1243+
# The border value occurs only once, so there is no tie to break;
1244+
# we are done.
1245+
indexer = get_indexer(indexer, values.index)
1246+
break
12401247

1241-
# Must set cur frame to include all duplicated values
1248+
safe_values = values[values != border_value]
1249+
indexer = get_indexer(indexer, safe_values.index)
1250+
# Must set cur frame to include all border values
12421251
# to consider for the next column, we also can reduce
12431252
# cur_n by the current length of the indexer
1244-
cur_frame = cur_frame[series.isin(duplicated)]
1253+
cur_frame = cur_frame[border_values]
12451254
cur_n = n - len(indexer)
12461255

12471256
frame = frame.take(indexer)

pandas/tests/series/test_analytics.py

+23
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,29 @@ def test_duplicate_keep_all_ties(self):
19961996
expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
19971997
assert_series_equal(result, expected)
19981998

1999+
@pytest.mark.parametrize('method,expected', [
2000+
('nlargest',
2001+
pd.DataFrame({'a': [2, 2, 2, 1], 'b': [3, 2, 1, 3]},
2002+
index=[2, 1, 0, 3])),
2003+
('nsmallest',
2004+
pd.DataFrame({'a': [2, 1, 1, 1], 'b': [1, 3, 2, 1]},
2005+
index=[0, 3, 4, 5]))])
2006+
def test_duplicates_on_starter_columns(self, method, expected):
2007+
# regression test for #22752
2008+
2009+
df = pd.DataFrame({
2010+
'a': [2, 2, 2, 1, 1, 1],
2011+
'b': [1, 2, 3, 3, 2, 1]
2012+
})
2013+
2014+
result = getattr(df, method)(
2015+
4, columns=['a', 'b']
2016+
).sort_values(
2017+
['a', 'b'], ascending=False
2018+
)
2019+
2020+
assert_frame_equal(result, expected)
2021+
19992022

20002023
class TestCategoricalSeriesAnalytics(object):
20012024

0 commit comments

Comments
 (0)