Skip to content

Commit c86ce45

Browse files
author
Roger Thomas
committed
Fix Some Stuff
1 parent 7f8cd04 commit c86ce45

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

pandas/core/algorithms.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -948,21 +948,22 @@ def select_n_frame(frame, columns, n, method, keep):
948948
if not is_list_like(columns):
949949
columns = [columns]
950950
columns = list(columns)
951-
for column in columns:
952-
dtype = frame[column].dtype
953-
if not issubclass(dtype.type, (np.integer, np.floating, np.datetime64,
954-
np.timedelta64)):
955-
msg = (
956-
"{column!r} has dtype: {dtype}, cannot use method {method!r} "
957-
"with this dtype"
958-
).format(column=column, dtype=dtype, method=method)
959-
raise TypeError(msg)
951+
952+
def get_indexer(current_indexer, other_indexer):
953+
"""Helper function to concat `current_indexer` and `other_indexer`
954+
depending on `method`
955+
"""
956+
if method == 'nsmallest':
957+
return current_indexer.append(other_indexer)
958+
else:
959+
return other_indexer.append(current_indexer)
960960

961961
# Below we save and reset the index in case index contains duplicates
962962
original_index = frame.index
963963
cur_frame = frame = frame.reset_index(drop=True)
964964
cur_n = n
965965
indexer = Int64Index([])
966+
966967
for i, column in enumerate(columns):
967968

968969
# For each column we apply method to cur_frame[column]. If it is the
@@ -974,22 +975,17 @@ def select_n_frame(frame, columns, n, method, keep):
974975
series = cur_frame[column]
975976
values = getattr(series, method)(cur_n, keep=keep)
976977
is_last_column = len(columns) - 1 == i
977-
if is_last_column or len(values.unique()) == sum(series.isin(values)):
978+
if is_last_column or values.nunique() == series.isin(values).sum():
978979

979980
# Last column in columns or values are unique in series => values
980981
# is all that matters
981-
if method == 'nsmallest':
982-
indexer = indexer.append(values.index)
983-
else:
984-
indexer = values.index.append(indexer)
982+
indexer = get_indexer(indexer, values.index)
985983
break
984+
986985
duplicated_filter = series.duplicated(keep=False)
987-
non_duplicated = values[~duplicated_filter]
988986
duplicated = values[duplicated_filter]
989-
if method == 'nsmallest':
990-
indexer = indexer.append(non_duplicated.index)
991-
else:
992-
indexer = non_duplicated.index.append(indexer)
987+
non_duplicated = values[~duplicated_filter]
988+
indexer = get_indexer(indexer, non_duplicated.index)
993989

994990
# Must set cur frame to include all duplicated values to consider for
995991
# the next column, we also can reduce cur_n by the current length of

0 commit comments

Comments
 (0)