Skip to content

Commit 1a8043f

Browse files
author
Roger Thomas
committed
Fix nsmallest/nlargest With Identical Values
1 parent de589c2 commit 1a8043f

File tree

4 files changed

+86
-20
lines changed

4 files changed

+86
-20
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1057,3 +1057,4 @@ Bug Fixes
10571057
- Bug in ``pd.melt()`` where passing a tuple value for ``value_vars`` caused a ``TypeError`` (:issue:`15348`)
10581058
- Bug in ``.eval()`` which caused multiline evals to fail with local variables not on the first line (:issue:`15342`)
10591059
- Bug in ``pd.read_msgpack()`` which did not allow loading a DataFrame with an index of type ``CategoricalIndex`` (:issue:`15487`)
1060+
- Bug in ``DataFrame.nsmallest`` and ``DataFrame.nlargest`` where identical values resulted in duplicated rows (:issue:`15297`)

pandas/core/algorithms.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -948,10 +948,36 @@ def select_n_frame(frame, columns, n, method, keep):
948948
if not is_list_like(columns):
949949
columns = [columns]
950950
columns = list(columns)
951-
ser = getattr(frame[columns[0]], method)(n, keep=keep)
952-
if isinstance(ser, Series):
953-
ser = ser.to_frame()
954-
return ser.merge(frame, on=columns[0], left_index=True)[frame.columns]
951+
ascending = method == 'nsmallest'
952+
index_is_unique = frame.index.is_unique
953+
if not index_is_unique:
954+
# If index not unique we must reset index to allow re-indexing below
955+
# We must save frame's index to tmp
956+
tmp = Series(np.arange(len(frame)), index=frame.index)
957+
frame = frame.reset_index(drop=True)
958+
for i, column in enumerate(columns):
959+
series = frame[column]
960+
values = getattr(series, method)(n, keep=keep)
961+
if i + 1 == len(columns):
962+
frame = frame.reindex(values.index)
963+
else:
964+
filtered_frame = frame[series.isin(values)]
965+
if len(filtered_frame) == len(values):
966+
# Values are unique in series => reindex and break
967+
frame = frame.reindex(values.index)
968+
break
969+
frame = filtered_frame.sort_values(
970+
column, ascending=ascending
971+
)
972+
if not index_is_unique:
973+
# The line of code below is a little obfuscated. We are setting the
974+
# index of the frame back to its original index using the saved original
975+
# index stored in tmp. Because we reset the index on frame (above)
976+
# frame's index is now purely a unique integer index (as is tmp) =>
977+
# to restore the index to frame we can index tmp's index with frame's
978+
# index...
979+
frame.index = tmp.index[frame.index]
980+
return frame
955981

956982

957983
def _finalize_nsmallest(arr, kth_val, n, keep, narr):

pandas/tests/frame/test_analytics.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,20 @@ def test_nlargest_multiple_columns(self):
11401140
expected = df.sort_values(['a', 'b'], ascending=False).head(5)
11411141
tm.assert_frame_equal(result, expected)
11421142

1143+
def test_nlargest_nsmallest_identical_values(self):
1144+
# GH15297
1145+
df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]})
1146+
1147+
result = df.nlargest(3, 'a')
1148+
expected = pd.DataFrame(
1149+
{'a': [1] * 3, 'b': [1, 2, 3]}, index=[0, 1, 2]
1150+
)
1151+
tm.assert_frame_equal(result, expected)
1152+
1153+
result = df.nsmallest(3, 'a')
1154+
expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]})
1155+
tm.assert_frame_equal(result, expected)
1156+
11431157
def test_nsmallest(self):
11441158
from string import ascii_lowercase
11451159
df = pd.DataFrame({'a': np.random.permutation(10),
@@ -1159,33 +1173,41 @@ def test_nsmallest_multiple_columns(self):
11591173

11601174
def test_nsmallest_nlargest_duplicate_index(self):
11611175
# GH 13412
1162-
df = pd.DataFrame({'a': [1, 2, 3, 4],
1163-
'b': [4, 3, 2, 1],
1164-
'c': [0, 1, 2, 3]},
1165-
index=[0, 0, 1, 1])
1166-
result = df.nsmallest(4, 'a')
1167-
expected = df.sort_values('a').head(4)
1176+
df = pd.DataFrame({'a': [1, 2, 3, 4, 4],
1177+
'b': [1, 1, 1, 1, 1],
1178+
'c': [0, 1, 2, 5, 4]},
1179+
index=[0, 0, 1, 1, 1])
1180+
1181+
result = df.nsmallest(4, ['a', 'b', 'c'])
1182+
expected = df.sort_values(['a', 'b', 'c']).head(4)
11681183
tm.assert_frame_equal(result, expected)
11691184

1170-
result = df.nlargest(4, 'a')
1171-
expected = df.sort_values('a', ascending=False).head(4)
1185+
result = df.nlargest(4, ['a', 'b', 'c'])
1186+
expected = df.sort_values(['a', 'b', 'c'], ascending=False).head(4)
11721187
tm.assert_frame_equal(result, expected)
11731188

1174-
result = df.nsmallest(4, ['a', 'c'])
1175-
expected = df.sort_values(['a', 'c']).head(4)
1189+
result = df.nlargest(4, ['c', 'b', 'a'])
1190+
expected = df.sort_values(['c', 'b', 'a'], ascending=False).head(4)
11761191
tm.assert_frame_equal(result, expected)
11771192

1178-
result = df.nsmallest(4, ['c', 'a'])
1179-
expected = df.sort_values(['c', 'a']).head(4)
1193+
result = df.nsmallest(4, ['c', 'b', 'a'])
1194+
expected = df.sort_values(['c', 'b', 'a']).head(4)
11801195
tm.assert_frame_equal(result, expected)
11811196

1182-
result = df.nlargest(4, ['a', 'c'])
1183-
expected = df.sort_values(['a', 'c'], ascending=False).head(4)
1197+
# Test that all-duplicate values still return a df of size n
1198+
result = df.nsmallest(2, 'b')
1199+
expected = df.sort_values('b').head(2)
11841200
tm.assert_frame_equal(result, expected)
11851201

1186-
result = df.nlargest(4, ['c', 'a'])
1187-
expected = df.sort_values(['c', 'a'], ascending=False).head(4)
1202+
def test_nsmallest_nlargest_duplicate_multi_index(self):
1203+
df = pd.DataFrame({'a': [1, 2, 3, 3, 3],
1204+
'b': [1, 1, 1, 1, 1],
1205+
'c': [0, 1, 2, 5, 4]},
1206+
index=[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]])
1207+
result = df.nsmallest(4, ['a', 'b', 'c'])
1208+
expected = df.sort_values(['a', 'b', 'c']).head(4)
11881209
tm.assert_frame_equal(result, expected)
1210+
11891211
# ----------------------------------------------------------------------
11901212
# Isin
11911213

pandas/tests/series/test_analytics.py

+17
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,23 @@ def test_nsmallest_nlargest(self):
14551455
expected = s.sort_values().head(3)
14561456
assert_series_equal(result, expected)
14571457

1458+
# GH 15297
1459+
s = Series([1] * 5, index=[1, 2, 3, 4, 5])
1460+
expected_first = Series([1] * 3, index=[1, 2, 3])
1461+
expected_last = Series([1] * 3, index=[5, 4, 3])
1462+
1463+
result = s.nsmallest(3)
1464+
assert_series_equal(result, expected_first)
1465+
1466+
result = s.nsmallest(3, keep='last')
1467+
assert_series_equal(result, expected_last)
1468+
1469+
result = s.nlargest(3)
1470+
assert_series_equal(result, expected_first)
1471+
1472+
result = s.nlargest(3, keep='last')
1473+
assert_series_equal(result, expected_last)
1474+
14581475
def test_sort_index_level(self):
14591476
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list('ABC'))
14601477
s = Series([1, 2], mi)

0 commit comments

Comments
 (0)