Skip to content

Commit 5f772db

Browse files
author
Roger Thomas
committed
Fix nsmallest/nlargest With Identical Values
Remove Add comments
1 parent 1b53d88 commit 5f772db

File tree

4 files changed

+85
-20
lines changed

4 files changed

+85
-20
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,7 @@ Reshaping
10311031
- Bug in ``pd.pivot_table()`` where no error was raised when values argument was not in the columns (:issue:`14938`)
10321032
- Bug in ``pd.concat()`` in which concatting with an empty dataframe with ``join='inner'`` was being improperly handled (:issue:`15328`)
10331033
- Bug with ``sort=True`` in ``DataFrame.join`` and ``pd.merge`` when joining on indexes (:issue:`15582`)
1034+
- Bug in ``DataFrame.nsmallest`` and ``DataFrame.nlargest`` where identical values resulted in duplicated rows (:issue:`15297`)
10341035

10351036
Numeric
10361037
^^^^^^^

pandas/core/algorithms.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -948,10 +948,35 @@ def select_n_frame(frame, columns, n, method, keep):
948948
if not is_list_like(columns):
949949
columns = [columns]
950950
columns = list(columns)
951-
ser = getattr(frame[columns[0]], method)(n, keep=keep)
952-
if isinstance(ser, Series):
953-
ser = ser.to_frame()
954-
return ser.merge(frame, on=columns[0], left_index=True)[frame.columns]
951+
ascending = method == 'nsmallest'
952+
original_frame, original_index = frame, frame.index
953+
frame.reset_index(drop=True, inplace=True)
954+
for i, column in enumerate(columns):
955+
# For each column in columns we perform ``method`` on this frame.
956+
# To guard against the possibility ``method`` column has duplicate
957+
# values that must be considered for further columns (# GH15297) we
958+
# filter using isin on the values returned by ``method``. If there are
959+
# no duplicated values, we simply reindex like the values returned
960+
# by ``method``, otherwise we sort the frame and continue
961+
series = frame[column]
962+
values = getattr(series, method)(n, keep=keep)
963+
if i + 1 == len(columns):
964+
# This is the last column => duplicates here don't matter
965+
frame = frame.reindex(values.index)
966+
else:
967+
filtered_frame = frame[series.isin(values)]
968+
if len(filtered_frame) == len(values):
969+
# Values are unique in series => reindex and break
970+
frame = frame.reindex(values.index)
971+
break
972+
# Values are not unique in series => sort and continue
973+
frame = filtered_frame.sort_values(
974+
column, ascending=ascending
975+
)
976+
original_frame.index = original_index # Restore the index
977+
# Below we set the index of the returning frame to the original index
978+
frame.index = original_index[frame.index]
979+
return frame
955980

956981

957982
def _finalize_nsmallest(arr, kth_val, n, keep, narr):

pandas/tests/frame/test_analytics.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,20 @@ def test_nlargest_multiple_columns(self):
11401140
expected = df.sort_values(['a', 'b'], ascending=False).head(5)
11411141
tm.assert_frame_equal(result, expected)
11421142

1143+
def test_nlargest_nsmallest_identical_values(self):
1144+
# GH15297
1145+
df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]})
1146+
1147+
result = df.nlargest(3, 'a')
1148+
expected = pd.DataFrame(
1149+
{'a': [1] * 3, 'b': [1, 2, 3]}, index=[0, 1, 2]
1150+
)
1151+
tm.assert_frame_equal(result, expected)
1152+
1153+
result = df.nsmallest(3, 'a')
1154+
expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]})
1155+
tm.assert_frame_equal(result, expected)
1156+
11431157
def test_nsmallest(self):
11441158
from string import ascii_lowercase
11451159
df = pd.DataFrame({'a': np.random.permutation(10),
@@ -1159,33 +1173,41 @@ def test_nsmallest_multiple_columns(self):
11591173

11601174
def test_nsmallest_nlargest_duplicate_index(self):
11611175
# GH 13412
1162-
df = pd.DataFrame({'a': [1, 2, 3, 4],
1163-
'b': [4, 3, 2, 1],
1164-
'c': [0, 1, 2, 3]},
1165-
index=[0, 0, 1, 1])
1166-
result = df.nsmallest(4, 'a')
1167-
expected = df.sort_values('a').head(4)
1176+
df = pd.DataFrame({'a': [1, 2, 3, 4, 4],
1177+
'b': [1, 1, 1, 1, 1],
1178+
'c': [0, 1, 2, 5, 4]},
1179+
index=[0, 0, 1, 1, 1])
1180+
1181+
result = df.nsmallest(4, ['a', 'b', 'c'])
1182+
expected = df.sort_values(['a', 'b', 'c']).head(4)
11681183
tm.assert_frame_equal(result, expected)
11691184

1170-
result = df.nlargest(4, 'a')
1171-
expected = df.sort_values('a', ascending=False).head(4)
1185+
result = df.nlargest(4, ['a', 'b', 'c'])
1186+
expected = df.sort_values(['a', 'b', 'c'], ascending=False).head(4)
11721187
tm.assert_frame_equal(result, expected)
11731188

1174-
result = df.nsmallest(4, ['a', 'c'])
1175-
expected = df.sort_values(['a', 'c']).head(4)
1189+
result = df.nlargest(4, ['c', 'b', 'a'])
1190+
expected = df.sort_values(['c', 'b', 'a'], ascending=False).head(4)
11761191
tm.assert_frame_equal(result, expected)
11771192

1178-
result = df.nsmallest(4, ['c', 'a'])
1179-
expected = df.sort_values(['c', 'a']).head(4)
1193+
result = df.nsmallest(4, ['c', 'b', 'a'])
1194+
expected = df.sort_values(['c', 'b', 'a']).head(4)
11801195
tm.assert_frame_equal(result, expected)
11811196

1182-
result = df.nlargest(4, ['a', 'c'])
1183-
expected = df.sort_values(['a', 'c'], ascending=False).head(4)
1197+
# Test all duplicates still returns df of size n
1198+
result = df.nsmallest(2, 'b')
1199+
expected = df.sort_values('b').head(2)
11841200
tm.assert_frame_equal(result, expected)
11851201

1186-
result = df.nlargest(4, ['c', 'a'])
1187-
expected = df.sort_values(['c', 'a'], ascending=False).head(4)
1202+
def test_nsmallest_nlargest_duplicate_multi_index(self):
1203+
df = pd.DataFrame({'a': [1, 2, 3, 3, 3],
1204+
'b': [1, 1, 1, 1, 1],
1205+
'c': [0, 1, 2, 5, 4]},
1206+
index=[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]])
1207+
result = df.nsmallest(4, ['a', 'b', 'c'])
1208+
expected = df.sort_values(['a', 'b', 'c']).head(4)
11881209
tm.assert_frame_equal(result, expected)
1210+
11891211
# ----------------------------------------------------------------------
11901212
# Isin
11911213

pandas/tests/series/test_analytics.py

+17
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,23 @@ def test_nsmallest_nlargest(self):
14551455
expected = s.sort_values().head(3)
14561456
assert_series_equal(result, expected)
14571457

1458+
# GH 15297
1459+
s = Series([1] * 5, index=[1, 2, 3, 4, 5])
1460+
expected_first = Series([1] * 3, index=[1, 2, 3])
1461+
expected_last = Series([1] * 3, index=[5, 4, 3])
1462+
1463+
result = s.nsmallest(3)
1464+
assert_series_equal(result, expected_first)
1465+
1466+
result = s.nsmallest(3, keep='last')
1467+
assert_series_equal(result, expected_last)
1468+
1469+
result = s.nlargest(3)
1470+
assert_series_equal(result, expected_first)
1471+
1472+
result = s.nlargest(3, keep='last')
1473+
assert_series_equal(result, expected_last)
1474+
14581475
def test_sort_index_level(self):
14591476
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list('ABC'))
14601477
s = Series([1, 2], mi)

0 commit comments

Comments
 (0)