Skip to content

Commit 45ab37e

Browse files
committed
BUG: nlargest/nsmallest gave wrong result (#22752)
When asking for the n largest/smallest rows of a DataFrame using several columns, nlargest/nsmallest sometimes returned the wrong rows: when rows were tied on the earlier columns, the later columns were not always used correctly to break the tie.
1 parent fb784ca commit 45ab37e

File tree

4 files changed

+55
-24
lines changed

4 files changed

+55
-24
lines changed

asv_bench/benchmarks/frame_methods.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -505,14 +505,21 @@ class NSort(object):
505505
param_names = ['keep']
506506

507507
def setup(self, keep):
508-
self.df = DataFrame(np.random.randn(1000, 3), columns=list('ABC'))
508+
self.df = DataFrame(np.random.randn(100000, 3),
509+
columns=list('ABC'))
509510

510-
def time_nlargest(self, keep):
511+
def time_nlargest_one_column(self, keep):
511512
self.df.nlargest(100, 'A', keep=keep)
512513

513-
def time_nsmallest(self, keep):
514+
def time_nlargest_two_columns(self, keep):
515+
self.df.nlargest(100, ['A', 'B'], keep=keep)
516+
517+
def time_nsmallest_one_column(self, keep):
514518
self.df.nsmallest(100, 'A', keep=keep)
515519

520+
def time_nsmallest_two_columns(self, keep):
521+
self.df.nsmallest(100, ['A', 'B'], keep=keep)
522+
516523

517524
class Describe(object):
518525

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ Other
814814
- :meth:`~pandas.io.formats.style.Styler.background_gradient` now takes a ``text_color_threshold`` parameter to automatically lighten the text color based on the luminance of the background color. This improves readability with dark background colors without the need to limit the background colormap range. (:issue:`21258`)
815815
- Require at least 0.28.2 version of ``cython`` to support read-only memoryviews (:issue:`21688`)
816816
- :meth:`~pandas.io.formats.style.Styler.background_gradient` now also supports tablewise application (in addition to rowwise and columnwise) with ``axis=None`` (:issue:`15204`)
817+
- :meth:`DataFrame.nlargest` and :meth:`DataFrame.nsmallest` now return the correct n values when ``keep != 'all'``, including when rows are tied on the first columns (:issue:`22752`)
817818
- :meth:`~pandas.io.formats.style.Styler.bar` now also supports tablewise application (in addition to rowwise and columnwise) with ``axis=None`` and setting clipping range with ``vmin`` and ``vmax`` (:issue:`21548` and :issue:`21526`). ``NaN`` values are also handled properly.
818819
- Logical operations ``&, |, ^`` between :class:`Series` and :class:`Index` will no longer raise ``ValueError`` (:issue:`22092`)
819820
-

pandas/core/algorithms.py

+26-21
Original file line numberDiff line numberDiff line change
@@ -1214,41 +1214,46 @@ def get_indexer(current_indexer, other_indexer):
12141214
indexer = Int64Index([])
12151215

12161216
for i, column in enumerate(columns):
1217-
12181217
# For each column we apply method to cur_frame[column].
1219-
# If it is the last column in columns, or if the values
1220-
# returned are unique in frame[column] we save this index
1221-
# and break
1222-
# Otherwise we must save the index of the non duplicated values
1223-
# and set the next cur_frame to cur_frame filtered on all
1224-
# duplcicated values (#GH15297)
1218+
# If it's the last column or if we have the number of
1219+
# results desired we are done.
1220+
# Otherwise there are duplicates of the largest/smallest
1221+
# value and we need to look at the rest of the columns
1222+
# to determine which of the rows with the largest/smallest
1223+
# value in the column to keep.
12251224
series = cur_frame[column]
1226-
values = getattr(series, method)(cur_n, keep=self.keep)
12271225
is_last_column = len(columns) - 1 == i
1228-
if is_last_column or values.nunique() == series.isin(values).sum():
1226+
values = getattr(series, method)(
1227+
cur_n,
1228+
keep=self.keep if is_last_column else 'all')
12291229

1230-
# Last column in columns or values are unique in
1231-
# series => values
1232-
# is all that matters
1230+
if is_last_column or len(values) <= cur_n:
12331231
indexer = get_indexer(indexer, values.index)
12341232
break
12351233

1236-
duplicated_filter = series.duplicated(keep=False)
1237-
duplicated = values[duplicated_filter]
1238-
non_duplicated = values[~duplicated_filter]
1239-
indexer = get_indexer(indexer, non_duplicated.index)
1234+
last_value = values == values[values.index[-1]]
1235+
safe_values = values[~last_value]
1236+
unsafe_values = values[last_value]
12401237

1241-
# Must set cur frame to include all duplicated values
1242-
# to consider for the next column, we also can reduce
1243-
# cur_n by the current length of the indexer
1244-
cur_frame = cur_frame[series.isin(duplicated)]
1238+
indexer = get_indexer(indexer, safe_values.index)
1239+
cur_frame = cur_frame.loc[unsafe_values.index]
12451240
cur_n = n - len(indexer)
12461241

12471242
frame = frame.take(indexer)
12481243

12491244
# Restore the index on frame
12501245
frame.index = original_index.take(indexer)
1251-
return frame
1246+
1247+
# If there is only one column, the frame is already sorted.
1248+
if len(columns) == 1:
1249+
return frame
1250+
1251+
ascending = method == 'nsmallest'
1252+
1253+
return frame.sort_values(
1254+
columns,
1255+
ascending=ascending,
1256+
kind='mergesort')
12521257

12531258

12541259
# ------- ## ---- #

pandas/tests/frame/test_analytics.py

+18
Original file line numberDiff line numberDiff line change
@@ -2095,6 +2095,24 @@ def test_n_all_dtypes(self, df_main_dtypes):
20952095
df.nsmallest(2, list(set(df) - {'category_string', 'string'}))
20962096
df.nlargest(2, list(set(df) - {'category_string', 'string'}))
20972097

2098+
@pytest.mark.parametrize('method,expected', [
2099+
('nlargest',
2100+
pd.DataFrame({'a': [2, 2, 2, 1], 'b': [3, 2, 1, 3]},
2101+
index=[2, 1, 0, 3])),
2102+
('nsmallest',
2103+
pd.DataFrame({'a': [1, 1, 1, 2], 'b': [1, 2, 3, 1]},
2104+
index=[5, 4, 3, 0]))])
2105+
def test_duplicates_on_starter_columns(self, method, expected):
2106+
# regression test for #22752
2107+
2108+
df = pd.DataFrame({
2109+
'a': [2, 2, 2, 1, 1, 1],
2110+
'b': [1, 2, 3, 3, 2, 1]
2111+
})
2112+
2113+
result = getattr(df, method)(4, columns=['a', 'b'])
2114+
tm.assert_frame_equal(result, expected)
2115+
20982116
def test_n_identical_values(self):
20992117
# GH15297
21002118
df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]})

0 commit comments

Comments
 (0)