Skip to content

Commit d08a60e

Browse files
author
Roger Thomas
committed
Update Algorithm and add error checking
1 parent 1c89a5b commit d08a60e

File tree

2 files changed

+68
-36
lines changed

2 files changed

+68
-36
lines changed

pandas/core/algorithms.py

+47-30
Original file line numberDiff line numberDiff line change
@@ -944,46 +944,63 @@ def select_n_frame(frame, columns, n, method, keep):
944944
-------
945945
nordered : DataFrame
946946
"""
947+
from pandas import Int64Index
947948
if not is_list_like(columns):
948949
columns = [columns]
949950
columns = list(columns)
950-
ascending = method == 'nsmallest'
951-
952-
# Below we save and reset the index
953-
# in case index contains duplicates
951+
for column in columns:
952+
dtype = frame[column].dtype
953+
if not issubclass(dtype.type, (np.integer, np.floating, np.datetime64,
954+
np.timedelta64)):
955+
msg = (
956+
"{column!r} has dtype: {dtype}, cannot use method {method!r} "
957+
"with this dtype"
958+
).format(column=column, dtype=dtype, method=method)
959+
raise TypeError(msg)
960+
961+
# Below we save and reset the index in case index contains duplicates
954962
original_index = frame.index
955-
frame = frame.reset_index(drop=True)
956-
963+
cur_frame = frame = frame.reset_index(drop=True)
964+
cur_n = n
965+
indexer = Int64Index([])
957966
for i, column in enumerate(columns):
958967

959-
# For each column in columns we perform ``method`` on this frame
960-
# To guard against the possibility column has duplicate values that
961-
# must be considered for further columns (# GH15297) we filter using
962-
# frame[isin] on the values returned by ``method``. If there are no
963-
# duplicated values, we simply take the values returned by
964-
# ``method``, otherwise we sort the isin filtered frame and continue
965-
series = frame[column]
966-
values = getattr(series, method)(n, keep=keep)
967-
indexer = values.index
968-
if i + 1 == len(columns):
969-
970-
# This is the last column => duplicates here don't matter
971-
frame = frame.take(indexer)
968+
# For each column we apply method to cur_frame[column]. If it is the
969+
# last column in columns, or if the values returned are unique in
970+
# frame[column] we save this index and break
971+
# Otherwise we must save the index of the non duplicated values
972+
# and set the next cur_frame to cur_frame filtered on all duplicated
973+
# values (#GH15297)
974+
series = cur_frame[column]
975+
values = getattr(series, method)(cur_n, keep=keep)
976+
is_last_column = len(columns) - 1 == i
977+
if is_last_column or len(values.unique()) == sum(series.isin(values)):
978+
979+
# Last column in columns or values are unique in series => values
980+
# is all that matters
981+
if method == 'nsmallest':
982+
indexer = indexer.append(values.index)
983+
else:
984+
indexer = values.index.append(indexer)
985+
break
986+
duplicated_filter = series.duplicated(keep=False)
987+
non_duplicated = values[~duplicated_filter]
988+
duplicated = values[duplicated_filter]
989+
if method == 'nsmallest':
990+
indexer = indexer.append(non_duplicated.index)
972991
else:
973-
filtered_frame = frame[series.isin(values)]
974-
if len(filtered_frame) == len(values):
992+
indexer = non_duplicated.index.append(indexer)
975993

976-
# Values are unique in series => take and break
977-
frame = frame.take(indexer)
978-
break
994+
# Must set cur frame to include all duplicated values to consider for
995+
# the next column, we also can reduce cur_n by the current length of
996+
# the indexer
997+
cur_frame = cur_frame[series.isin(duplicated)]
998+
cur_n = n - len(indexer)
979999

980-
# Values are not unique in series => sort and continue
981-
frame = filtered_frame.sort_values(
982-
column, ascending=ascending
983-
)
1000+
frame = frame.take(indexer)
9841001

985-
# Below we set the index of the returning frame to the original index
986-
frame.index = original_index.take(frame.index)
1002+
# Restore the index on frame
1003+
frame.index = original_index.take(indexer)
9871004
return frame
9881005

9891006

pandas/tests/frame/test_analytics.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -1942,13 +1942,28 @@ class TestNLargestNSmallest(object):
19421942
def test_n(self, df_strings, n, order):
19431943
# GH10393
19441944
df = df_strings
1945-
result = df.nsmallest(n, order)
1946-
expected = df.sort_values(order).head(n)
1947-
tm.assert_frame_equal(result, expected)
19481945

1949-
result = df.nlargest(n, order)
1950-
expected = df.sort_values(order, ascending=False).head(n)
1951-
tm.assert_frame_equal(result, expected)
1946+
error_msg = (
1947+
"'b' has dtype: object, cannot use method 'nsmallest' "
1948+
"with this dtype"
1949+
)
1950+
if 'b' in order:
1951+
with pytest.raises(TypeError) as exception:
1952+
df.nsmallest(n, order)
1953+
assert exception.value, error_msg
1954+
else:
1955+
result = df.nsmallest(n, order)
1956+
expected = df.sort_values(order).head(n)
1957+
tm.assert_frame_equal(result, expected)
1958+
1959+
if 'b' in order:
1960+
with pytest.raises(TypeError) as exception:
1961+
df.nsmallest(n, order)
1962+
assert exception.value, error_msg
1963+
else:
1964+
result = df.nlargest(n, order)
1965+
expected = df.sort_values(order, ascending=False).head(n)
1966+
tm.assert_frame_equal(result, expected)
19521967

19531968
def test_n_error(self, df_strings):
19541969
# b alone raises a TypeError

0 commit comments

Comments
 (0)