Skip to content

Commit c112252

Browse files
Roger Thomas authored and jreback committed
BUG: Fix nsmallest/nlargest With Identical Values
Closes pandas-dev#15297.
Author: Roger Thomas <[email protected]>
Closes pandas-dev#15299 from RogerThomas/fix_nsmallest_nlargest_with_n_identical_values and squashes the following commit:
d3964f8 [Roger Thomas] Fix nsmallest/nlargest With Identical Values
1 parent a0b089e commit c112252

File tree

3 files changed

+200
-75
lines changed

3 files changed

+200
-75
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,7 @@ Reshaping
10971097
- Bug in ``pd.pivot_table()`` where no error was raised when values argument was not in the columns (:issue:`14938`)
10981098
- Bug in ``pd.concat()`` in which concatting with an empty dataframe with ``join='inner'`` was being improperly handled (:issue:`15328`)
10991099
- Bug with ``sort=True`` in ``DataFrame.join`` and ``pd.merge`` when joining on indexes (:issue:`15582`)
1100+
- Bug in ``DataFrame.nsmallest`` and ``DataFrame.nlargest`` where identical values resulted in duplicated rows (:issue:`15297`)
11001101

11011102
Numeric
11021103
^^^^^^^

pandas/core/algorithms.py

+68-7
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,15 @@ def select_n_slow(dropped, n, keep, method):
931931
_select_methods = {'nsmallest': nsmallest, 'nlargest': nlargest}
932932

933933

934+
def _is_valid_dtype_n_method(dtype):
    """
    Helper function to determine if dtype is valid for
    nsmallest/nlargest methods

    Parameters
    ----------
    dtype : dtype
        The dtype of the values that are about to be ranked.

    Returns
    -------
    bool
        True when ``dtype`` is numeric but not complex, or when
        ``needs_i8_conversion`` reports True for it; False otherwise.
    """
    # Complex numbers count as "numeric" for is_numeric_dtype, so they are
    # excluded explicitly here.
    return ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
            needs_i8_conversion(dtype))
941+
942+
934943
def select_n_series(series, n, keep, method):
935944
"""Implement n largest/smallest for pandas Series
936945
@@ -946,8 +955,7 @@ def select_n_series(series, n, keep, method):
946955
nordered : Series
947956
"""
948957
dtype = series.dtype
949-
if not ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
950-
needs_i8_conversion(dtype)):
958+
if not _is_valid_dtype_n_method(dtype):
951959
raise TypeError("Cannot use method '{method}' with "
952960
"dtype {dtype}".format(method=method, dtype=dtype))
953961

@@ -981,14 +989,67 @@ def select_n_frame(frame, columns, n, method, keep):
981989
-------
982990
nordered : DataFrame
983991
"""
984-
from pandas.core.series import Series
992+
from pandas import Int64Index
985993
if not is_list_like(columns):
986994
columns = [columns]
987995
columns = list(columns)
988-
ser = getattr(frame[columns[0]], method)(n, keep=keep)
989-
if isinstance(ser, Series):
990-
ser = ser.to_frame()
991-
return ser.merge(frame, on=columns[0], left_index=True)[frame.columns]
996+
for column in columns:
997+
dtype = frame[column].dtype
998+
if not _is_valid_dtype_n_method(dtype):
999+
raise TypeError((
1000+
"Column {column!r} has dtype {dtype}, cannot use method "
1001+
"{method!r} with this dtype"
1002+
).format(column=column, dtype=dtype, method=method))
1003+
1004+
def get_indexer(current_indexer, other_indexer):
1005+
"""Helper function to concat `current_indexer` and `other_indexer`
1006+
depending on `method`
1007+
"""
1008+
if method == 'nsmallest':
1009+
return current_indexer.append(other_indexer)
1010+
else:
1011+
return other_indexer.append(current_indexer)
1012+
1013+
# Below we save and reset the index in case index contains duplicates
1014+
original_index = frame.index
1015+
cur_frame = frame = frame.reset_index(drop=True)
1016+
cur_n = n
1017+
indexer = Int64Index([])
1018+
1019+
for i, column in enumerate(columns):
1020+
1021+
# For each column we apply method to cur_frame[column]. If it is the
1022+
# last column in columns, or if the values returned are unique in
1023+
# frame[column] we save this index and break
1024+
# Otherwise we must save the index of the non duplicated values
1025+
# and set the next cur_frame to cur_frame filtered on all duplicated
1026+
# values (#GH15297)
1027+
series = cur_frame[column]
1028+
values = getattr(series, method)(cur_n, keep=keep)
1029+
is_last_column = len(columns) - 1 == i
1030+
if is_last_column or values.nunique() == series.isin(values).sum():
1031+
1032+
# Last column in columns or values are unique in series => values
1033+
# is all that matters
1034+
indexer = get_indexer(indexer, values.index)
1035+
break
1036+
1037+
duplicated_filter = series.duplicated(keep=False)
1038+
duplicated = values[duplicated_filter]
1039+
non_duplicated = values[~duplicated_filter]
1040+
indexer = get_indexer(indexer, non_duplicated.index)
1041+
1042+
# Must set cur frame to include all duplicated values to consider for
1043+
# the next column, we also can reduce cur_n by the current length of
1044+
# the indexer
1045+
cur_frame = cur_frame[series.isin(duplicated)]
1046+
cur_n = n - len(indexer)
1047+
1048+
frame = frame.take(indexer)
1049+
1050+
# Restore the index on frame
1051+
frame.index = original_index.take(indexer)
1052+
return frame
9921053

9931054

9941055
def _finalize_nsmallest(arr, kth_val, n, keep, narr):

pandas/tests/frame/test_analytics.py

+131-68
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import sys
88
import pytest
99

10+
from string import ascii_lowercase
1011
from numpy import nan
1112
from numpy.random import randn
1213
import numpy as np
1314

14-
from pandas.compat import lrange
15+
from pandas.compat import lrange, product
1516
from pandas import (compat, isnull, notnull, DataFrame, Series,
1617
MultiIndex, date_range, Timestamp)
1718
import pandas as pd
@@ -1120,73 +1121,6 @@ def __nonzero__(self):
11201121
self.assertTrue(r1.all())
11211122

11221123
# ----------------------------------------------------------------------
1123-
# Top / bottom
1124-
1125-
def test_nlargest(self):
1126-
# GH10393
1127-
from string import ascii_lowercase
1128-
df = pd.DataFrame({'a': np.random.permutation(10),
1129-
'b': list(ascii_lowercase[:10])})
1130-
result = df.nlargest(5, 'a')
1131-
expected = df.sort_values('a', ascending=False).head(5)
1132-
tm.assert_frame_equal(result, expected)
1133-
1134-
def test_nlargest_multiple_columns(self):
1135-
from string import ascii_lowercase
1136-
df = pd.DataFrame({'a': np.random.permutation(10),
1137-
'b': list(ascii_lowercase[:10]),
1138-
'c': np.random.permutation(10).astype('float64')})
1139-
result = df.nlargest(5, ['a', 'b'])
1140-
expected = df.sort_values(['a', 'b'], ascending=False).head(5)
1141-
tm.assert_frame_equal(result, expected)
1142-
1143-
def test_nsmallest(self):
1144-
from string import ascii_lowercase
1145-
df = pd.DataFrame({'a': np.random.permutation(10),
1146-
'b': list(ascii_lowercase[:10])})
1147-
result = df.nsmallest(5, 'a')
1148-
expected = df.sort_values('a').head(5)
1149-
tm.assert_frame_equal(result, expected)
1150-
1151-
def test_nsmallest_multiple_columns(self):
1152-
from string import ascii_lowercase
1153-
df = pd.DataFrame({'a': np.random.permutation(10),
1154-
'b': list(ascii_lowercase[:10]),
1155-
'c': np.random.permutation(10).astype('float64')})
1156-
result = df.nsmallest(5, ['a', 'c'])
1157-
expected = df.sort_values(['a', 'c']).head(5)
1158-
tm.assert_frame_equal(result, expected)
1159-
1160-
def test_nsmallest_nlargest_duplicate_index(self):
1161-
# GH 13412
1162-
df = pd.DataFrame({'a': [1, 2, 3, 4],
1163-
'b': [4, 3, 2, 1],
1164-
'c': [0, 1, 2, 3]},
1165-
index=[0, 0, 1, 1])
1166-
result = df.nsmallest(4, 'a')
1167-
expected = df.sort_values('a').head(4)
1168-
tm.assert_frame_equal(result, expected)
1169-
1170-
result = df.nlargest(4, 'a')
1171-
expected = df.sort_values('a', ascending=False).head(4)
1172-
tm.assert_frame_equal(result, expected)
1173-
1174-
result = df.nsmallest(4, ['a', 'c'])
1175-
expected = df.sort_values(['a', 'c']).head(4)
1176-
tm.assert_frame_equal(result, expected)
1177-
1178-
result = df.nsmallest(4, ['c', 'a'])
1179-
expected = df.sort_values(['c', 'a']).head(4)
1180-
tm.assert_frame_equal(result, expected)
1181-
1182-
result = df.nlargest(4, ['a', 'c'])
1183-
expected = df.sort_values(['a', 'c'], ascending=False).head(4)
1184-
tm.assert_frame_equal(result, expected)
1185-
1186-
result = df.nlargest(4, ['c', 'a'])
1187-
expected = df.sort_values(['c', 'a'], ascending=False).head(4)
1188-
tm.assert_frame_equal(result, expected)
1189-
# ----------------------------------------------------------------------
11901124
# Isin
11911125

11921126
def test_isin(self):
@@ -1965,3 +1899,132 @@ def test_dot(self):
19651899

19661900
with tm.assertRaisesRegexp(ValueError, 'aligned'):
19671901
df.dot(df2)
1902+
1903+
1904+
@pytest.fixture
def df_duplicates():
    """Frame with a repeated index and tied values.

    Column ``a`` ends in a tie (4, 4), column ``b`` is constant and the
    index labels repeat (0, 0, 1, 1, 1).
    """
    return pd.DataFrame({'a': [1, 2, 3, 4, 4],
                         'b': [1, 1, 1, 1, 1],
                         'c': [0, 1, 2, 5, 4]},
                        index=[0, 0, 1, 1, 1])
1910+
1911+
1912+
@pytest.fixture
def df_strings():
    """Ten-row frame mixing numeric columns with a string column.

    ``a`` is a random permutation of 0..9, ``b`` holds the first ten
    lowercase ASCII letters and ``c`` is a random permutation of 0..9 cast
    to float64.
    """
    return pd.DataFrame({'a': np.random.permutation(10),
                         'b': list(ascii_lowercase[:10]),
                         'c': np.random.permutation(10).astype('float64')})
1917+
1918+
1919+
@pytest.fixture
def df_main_dtypes():
    """Three-row frame with one column per commonly used dtype.

    The ``columns`` argument fixes the column order explicitly so it does
    not depend on dict ordering.
    """
    return pd.DataFrame(
        {'group': [1, 1, 2],
         'int': [1, 2, 3],
         'float': [4., 5., 6.],
         'string': list('abc'),
         'category_string': pd.Series(list('abc')).astype('category'),
         # NOTE(review): despite its name this column holds plain integers,
         # not a categorical -- confirm whether that is intended.
         'category_int': [7, 8, 9],
         'datetime': pd.date_range('20130101', periods=3),
         'datetimetz': pd.date_range('20130101',
                                     periods=3,
                                     tz='US/Eastern'),
         'timedelta': pd.timedelta_range('1 s', periods=3, freq='s')},
        columns=['group', 'int', 'float', 'string',
                 'category_string', 'category_int',
                 'datetime', 'datetimetz',
                 'timedelta'])
1937+
1938+
1939+
class TestNLargestNSmallest(object):
    """Tests for ``DataFrame.nsmallest`` and ``DataFrame.nlargest``."""

    # Expected TypeError text when a column's dtype is rejected; formatted
    # with the offending column, its dtype and the method name.
    dtype_error_msg_template = ("Column {column!r} has dtype {dtype}, cannot "
                                "use method {method!r} with this dtype")

    # ----------------------------------------------------------------------
    # Top / bottom
    @pytest.mark.parametrize(
        'method, n, order',
        product(['nsmallest', 'nlargest'], range(1, 11),
                [['a'],
                 ['c'],
                 ['a', 'b'],
                 ['a', 'c'],
                 ['b', 'a'],
                 ['b', 'c'],
                 ['a', 'b', 'c'],
                 ['c', 'a', 'b'],
                 ['c', 'b', 'a'],
                 ['b', 'c', 'a'],
                 ['b', 'a', 'c'],

                 # dups!
                 ['b', 'c', 'c'],

                 ]))
    def test_n(self, df_strings, method, n, order):
        """Result matches sort_values().head(n); string column 'b' raises."""
        # GH10393
        df = df_strings
        if 'b' in order:
            # 'b' holds strings, so any ordering that mentions it must be
            # rejected with the dtype error.
            error_msg = self.dtype_error_msg_template.format(
                column='b', method=method, dtype='object')
            with tm.assertRaisesRegexp(TypeError, error_msg):
                getattr(df, method)(n, order)
        else:
            ascending = method == 'nsmallest'
            result = getattr(df, method)(n, order)
            expected = df.sort_values(order, ascending=ascending).head(n)
            tm.assert_frame_equal(result, expected)

    @pytest.mark.parametrize(
        'method, columns',
        product(['nsmallest', 'nlargest'],
                product(['group'], ['category_string', 'string'])
                ))
    def test_n_error(self, df_main_dtypes, method, columns):
        """The second column of each pair must trigger the dtype error."""
        df = df_main_dtypes
        error_msg = self.dtype_error_msg_template.format(
            column=columns[1], method=method, dtype=df[columns[1]].dtype)
        with tm.assertRaisesRegexp(TypeError, error_msg):
            getattr(df, method)(2, columns)

    def test_n_all_dtypes(self, df_main_dtypes):
        """Smoke test: every column except the two string-like ones works."""
        df = df_main_dtypes
        # Keep the frame's own column order; building the list from a set
        # would make the order (and so the code path) vary between runs.
        columns = [column for column in df.columns
                   if column not in ('category_string', 'string')]
        df.nsmallest(2, columns)
        df.nlargest(2, columns)

    def test_n_identical_values(self):
        """Identical values in the ordering column must not duplicate rows."""
        # GH15297
        df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]})

        result = df.nlargest(3, 'a')
        expected = pd.DataFrame(
            {'a': [1] * 3, 'b': [1, 2, 3]}, index=[0, 1, 2]
        )
        tm.assert_frame_equal(result, expected)

        result = df.nsmallest(3, 'a')
        expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]})
        tm.assert_frame_equal(result, expected)

    @pytest.mark.parametrize(
        'n, order',
        product([1, 2, 3, 4, 5],
                [['a', 'b', 'c'],
                 ['c', 'b', 'a'],
                 ['a'],
                 ['b'],
                 ['a', 'b'],
                 ['c', 'b']]))
    def test_n_duplicate_index(self, df_duplicates, n, order):
        """A repeated index must not change which rows are selected."""
        # GH 13412

        df = df_duplicates
        result = df.nsmallest(n, order)
        expected = df.sort_values(order).head(n)
        tm.assert_frame_equal(result, expected)

        result = df.nlargest(n, order)
        expected = df.sort_values(order, ascending=False).head(n)
        tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)