Skip to content

Commit a451108

Browse files
author
Roger Thomas
committed
Fix nsmallest/nlargest With Identical Values
1 parent de589c2 commit a451108

File tree

4 files changed

+58
-23
lines changed

4 files changed

+58
-23
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1057,3 +1057,4 @@ Bug Fixes
10571057
- Bug in ``pd.melt()`` where passing a tuple value for ``value_vars`` caused a ``TypeError`` (:issue:`15348`)
10581058
- Bug in ``.eval()`` which caused multiline evals to fail with local variables not on the first line (:issue:`15342`)
10591059
- Bug in ``pd.read_msgpack()`` which did not allow to load dataframe with an index of type ``CategoricalIndex`` (:issue:`15487`)
1060+
- Bug in ``DataFrame.nsmallest`` and ``DataFrame.nlargest`` where identical values resulted in duplicated rows (:issue:`15297`)

pandas/core/algorithms.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -944,14 +944,27 @@ def select_n_frame(frame, columns, n, method, keep):
944944
-------
945945
nordered : DataFrame
946946
"""
947-
from pandas.core.series import Series
948947
if not is_list_like(columns):
949948
columns = [columns]
950-
columns = list(columns)
951-
ser = getattr(frame[columns[0]], method)(n, keep=keep)
952-
if isinstance(ser, Series):
953-
ser = ser.to_frame()
954-
return ser.merge(frame, on=columns[0], left_index=True)[frame.columns]
949+
else:
950+
columns = list(columns)
951+
reverse = method == 'nlargest'
952+
for i, column in enumerate(columns):
953+
series = frame[column]
954+
if reverse:
955+
inds = series.argsort()[::-1][:n]
956+
else:
957+
inds = series.argsort()[:n]
958+
values = series.take(inds)
959+
if i != len(columns) - 1 and values.duplicated().any():
960+
# This series has duplicate values => we must consider all rows in
961+
# frame that match `values`
962+
# The first condition is for the last column. In this case we don't
963+
# care if there are duplicates => no need to do the check
964+
frame = frame[series.isin(values)]
965+
else:
966+
break
967+
return frame.take(inds)
955968

956969

957970
def _finalize_nsmallest(arr, kth_val, n, keep, narr):

pandas/tests/frame/test_analytics.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,18 @@ def test_nlargest_multiple_columns(self):
11401140
expected = df.sort_values(['a', 'b'], ascending=False).head(5)
11411141
tm.assert_frame_equal(result, expected)
11421142

1143+
def test_nlargest_nsmallest_identical_values(self):
1144+
# GH15297
1145+
df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]})
1146+
1147+
result = df.nlargest(3, 'a')
1148+
expected = pd.DataFrame({'a': [1] * 3, 'b': [5, 4, 3]}, index=[4, 3, 2])
1149+
tm.assert_frame_equal(result, expected)
1150+
1151+
result = df.nsmallest(3, 'a')
1152+
expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]})
1153+
tm.assert_frame_equal(result, expected)
1154+
11431155
def test_nsmallest(self):
11441156
from string import ascii_lowercase
11451157
df = pd.DataFrame({'a': np.random.permutation(10),
@@ -1159,33 +1171,32 @@ def test_nsmallest_multiple_columns(self):
11591171

11601172
def test_nsmallest_nlargest_duplicate_index(self):
11611173
# GH 13412
1162-
df = pd.DataFrame({'a': [1, 2, 3, 4],
1163-
'b': [4, 3, 2, 1],
1164-
'c': [0, 1, 2, 3]},
1165-
index=[0, 0, 1, 1])
1166-
result = df.nsmallest(4, 'a')
1167-
expected = df.sort_values('a').head(4)
1168-
tm.assert_frame_equal(result, expected)
1174+
df = pd.DataFrame({'a': [1, 2, 3, 3, 3],
1175+
'b': [1, 1, 1, 1, 1],
1176+
'c': [0, 1, 2, 5, 4]},
1177+
index=[0, 0, 1, 1, 1])
11691178

1170-
result = df.nlargest(4, 'a')
1171-
expected = df.sort_values('a', ascending=False).head(4)
1179+
result = df.nsmallest(4, ['a', 'b', 'c'])
1180+
expected = df.sort_values(['a', 'b', 'c']).head(4)
11721181
tm.assert_frame_equal(result, expected)
11731182

1174-
result = df.nsmallest(4, ['a', 'c'])
1175-
expected = df.sort_values(['a', 'c']).head(4)
1183+
result = df.nlargest(4, ['a', 'b', 'c'])
1184+
expected = df.sort_values(['a', 'b', 'c'], ascending=False).head(4)
11761185
tm.assert_frame_equal(result, expected)
11771186

1178-
result = df.nsmallest(4, ['c', 'a'])
1179-
expected = df.sort_values(['c', 'a']).head(4)
1187+
result = df.nlargest(4, ['c', 'b', 'a'])
1188+
expected = df.sort_values(['c', 'b', 'a'], ascending=False).head(4)
11801189
tm.assert_frame_equal(result, expected)
11811190

1182-
result = df.nlargest(4, ['a', 'c'])
1183-
expected = df.sort_values(['a', 'c'], ascending=False).head(4)
1191+
result = df.nsmallest(4, ['c', 'b', 'a'])
1192+
expected = df.sort_values(['c', 'b', 'a']).head(4)
11841193
tm.assert_frame_equal(result, expected)
11851194

1186-
result = df.nlargest(4, ['c', 'a'])
1187-
expected = df.sort_values(['c', 'a'], ascending=False).head(4)
1195+
# Test all duplicates still returns df of size n
1196+
result = df.nsmallest(2, 'b')
1197+
expected = df.sort_values('b').head(2)
11881198
tm.assert_frame_equal(result, expected)
1199+
11891200
# ----------------------------------------------------------------------
11901201
# Isin
11911202

pandas/tests/series/test_analytics.py

+10
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,16 @@ def test_nsmallest_nlargest(self):
14551455
expected = s.sort_values().head(3)
14561456
assert_series_equal(result, expected)
14571457

1458+
# GH 15297
1459+
s = Series([1] * 5, index=[1, 2, 3, 4, 5])
1460+
expected = Series([1] * 3, index=[1, 2, 3])
1461+
1462+
result = s.nsmallest(3)
1463+
assert_series_equal(result, expected)
1464+
1465+
result = s.nlargest(3)
1466+
assert_series_equal(result, expected)
1467+
14581468
def test_sort_index_level(self):
14591469
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list('ABC'))
14601470
s = Series([1, 2], mi)

0 commit comments

Comments
 (0)