Skip to content

Commit 4955d67

Browse files
author
Roger Thomas
committed
Fix nsmallest/nlargest With Identical Values
1 parent de589c2 commit 4955d67

File tree

4 files changed

+60
-23
lines changed

4 files changed

+60
-23
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1057,3 +1057,4 @@ Bug Fixes
10571057
- Bug in ``pd.melt()`` where passing a tuple value for ``value_vars`` caused a ``TypeError`` (:issue:`15348`)
10581058
- Bug in ``.eval()`` which caused multiline evals to fail with local variables not on the first line (:issue:`15342`)
10591059
- Bug in ``pd.read_msgpack()`` which did not allow to load dataframe with an index of type ``CategoricalIndex`` (:issue:`15487`)
1060+
- Bug in ``DataFrame.nsmallest`` and ``DataFrame.nlargest`` where identical values resulted in duplicated rows (:issue:`15297`)

pandas/core/algorithms.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -944,14 +944,27 @@ def select_n_frame(frame, columns, n, method, keep):
944944
-------
945945
nordered : DataFrame
946946
"""
947-
from pandas.core.series import Series
948947
if not is_list_like(columns):
949948
columns = [columns]
950-
columns = list(columns)
951-
ser = getattr(frame[columns[0]], method)(n, keep=keep)
952-
if isinstance(ser, Series):
953-
ser = ser.to_frame()
954-
return ser.merge(frame, on=columns[0], left_index=True)[frame.columns]
949+
else:
950+
columns = list(columns)
951+
reverse = method == 'nlargest'
952+
for i, column in enumerate(columns):
953+
series = frame[column]
954+
if reverse:
955+
inds = series.argsort()[::-1][:n]
956+
else:
957+
inds = series.argsort()[:n]
958+
values = series.iloc[inds]
959+
if i != len(columns) - 1 and values.duplicated().any():
960+
# This series has duplicate values => we must consider all rows in
961+
# frame that match `values`
962+
# The first condition is for the last column. In this case we don't
963+
# care if there are duplicates => no need to do the check
964+
frame = frame[series.isin(values)]
965+
else:
966+
break
967+
return frame.take(inds)
955968

956969

957970
def _finalize_nsmallest(arr, kth_val, n, keep, narr):

pandas/tests/frame/test_analytics.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,20 @@ def test_nlargest_multiple_columns(self):
11401140
expected = df.sort_values(['a', 'b'], ascending=False).head(5)
11411141
tm.assert_frame_equal(result, expected)
11421142

1143+
def test_nlargest_nsmallest_identical_values(self):
1144+
# GH15297
1145+
df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]})
1146+
1147+
result = df.nlargest(3, 'a')
1148+
expected = pd.DataFrame(
1149+
{'a': [1] * 3, 'b': [5, 4, 3]}, index=[4, 3, 2]
1150+
)
1151+
tm.assert_frame_equal(result, expected)
1152+
1153+
result = df.nsmallest(3, 'a')
1154+
expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]})
1155+
tm.assert_frame_equal(result, expected)
1156+
11431157
def test_nsmallest(self):
11441158
from string import ascii_lowercase
11451159
df = pd.DataFrame({'a': np.random.permutation(10),
@@ -1159,33 +1173,32 @@ def test_nsmallest_multiple_columns(self):
11591173

11601174
def test_nsmallest_nlargest_duplicate_index(self):
11611175
# GH 13412
1162-
df = pd.DataFrame({'a': [1, 2, 3, 4],
1163-
'b': [4, 3, 2, 1],
1164-
'c': [0, 1, 2, 3]},
1165-
index=[0, 0, 1, 1])
1166-
result = df.nsmallest(4, 'a')
1167-
expected = df.sort_values('a').head(4)
1168-
tm.assert_frame_equal(result, expected)
1176+
df = pd.DataFrame({'a': [1, 2, 3, 3, 3],
1177+
'b': [1, 1, 1, 1, 1],
1178+
'c': [0, 1, 2, 5, 4]},
1179+
index=[0, 0, 1, 1, 1])
11691180

1170-
result = df.nlargest(4, 'a')
1171-
expected = df.sort_values('a', ascending=False).head(4)
1181+
result = df.nsmallest(4, ['a', 'b', 'c'])
1182+
expected = df.sort_values(['a', 'b', 'c']).head(4)
11721183
tm.assert_frame_equal(result, expected)
11731184

1174-
result = df.nsmallest(4, ['a', 'c'])
1175-
expected = df.sort_values(['a', 'c']).head(4)
1185+
result = df.nlargest(4, ['a', 'b', 'c'])
1186+
expected = df.sort_values(['a', 'b', 'c'], ascending=False).head(4)
11761187
tm.assert_frame_equal(result, expected)
11771188

1178-
result = df.nsmallest(4, ['c', 'a'])
1179-
expected = df.sort_values(['c', 'a']).head(4)
1189+
result = df.nlargest(4, ['c', 'b', 'a'])
1190+
expected = df.sort_values(['c', 'b', 'a'], ascending=False).head(4)
11801191
tm.assert_frame_equal(result, expected)
11811192

1182-
result = df.nlargest(4, ['a', 'c'])
1183-
expected = df.sort_values(['a', 'c'], ascending=False).head(4)
1193+
result = df.nsmallest(4, ['c', 'b', 'a'])
1194+
expected = df.sort_values(['c', 'b', 'a']).head(4)
11841195
tm.assert_frame_equal(result, expected)
11851196

1186-
result = df.nlargest(4, ['c', 'a'])
1187-
expected = df.sort_values(['c', 'a'], ascending=False).head(4)
1197+
# Test all duplicates still returns df of size n
1198+
result = df.nsmallest(2, 'b')
1199+
expected = df.sort_values('b').head(2)
11881200
tm.assert_frame_equal(result, expected)
1201+
11891202
# ----------------------------------------------------------------------
11901203
# Isin
11911204

pandas/tests/series/test_analytics.py

+10
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,16 @@ def test_nsmallest_nlargest(self):
14551455
expected = s.sort_values().head(3)
14561456
assert_series_equal(result, expected)
14571457

1458+
# GH 15297
1459+
s = Series([1] * 5, index=[1, 2, 3, 4, 5])
1460+
expected = Series([1] * 3, index=[1, 2, 3])
1461+
1462+
result = s.nsmallest(3)
1463+
assert_series_equal(result, expected)
1464+
1465+
result = s.nlargest(3)
1466+
assert_series_equal(result, expected)
1467+
14581468
def test_sort_index_level(self):
14591469
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list('ABC'))
14601470
s = Series([1, 2], mi)

0 commit comments

Comments
 (0)