@@ -948,21 +948,22 @@ def select_n_frame(frame, columns, n, method, keep):
948
948
if not is_list_like (columns ):
949
949
columns = [columns ]
950
950
columns = list (columns )
951
- for column in columns :
952
- dtype = frame [ column ]. dtype
953
- if not issubclass ( dtype . type , ( np . integer , np . floating , np . datetime64 ,
954
- np . timedelta64 )):
955
- msg = (
956
- "{column!r} has dtype: {dtype}, cannot use method {method!r} "
957
- "with this dtype"
958
- ). format ( column = column , dtype = dtype , method = method )
959
- raise TypeError ( msg )
951
+
952
+ def get_indexer ( current_indexer , other_indexer ):
953
+ """Helper function to concat `current_indexer` and `other_indexer`
954
+ depending on `method`
955
+ """
956
+ if method == 'nsmallest' :
957
+ return current_indexer . append ( other_indexer )
958
+ else :
959
+ return other_indexer . append ( current_indexer )
960
960
961
961
# Below we save and reset the index in case index contains duplicates
962
962
original_index = frame .index
963
963
cur_frame = frame = frame .reset_index (drop = True )
964
964
cur_n = n
965
965
indexer = Int64Index ([])
966
+
966
967
for i , column in enumerate (columns ):
967
968
968
969
# For each column we apply method to cur_frame[column]. If it is the
@@ -974,22 +975,17 @@ def select_n_frame(frame, columns, n, method, keep):
974
975
series = cur_frame [column ]
975
976
values = getattr (series , method )(cur_n , keep = keep )
976
977
is_last_column = len (columns ) - 1 == i
977
- if is_last_column or len ( values .unique ()) == sum ( series .isin (values )):
978
+ if is_last_column or values .nunique () == series .isin (values ). sum ( ):
978
979
979
980
# Last column in columns or values are unique in series => values
980
981
# is all that matters
981
- if method == 'nsmallest' :
982
- indexer = indexer .append (values .index )
983
- else :
984
- indexer = values .index .append (indexer )
982
+ indexer = get_indexer (indexer , values .index )
985
983
break
984
+
986
985
duplicated_filter = series .duplicated (keep = False )
987
- non_duplicated = values [~ duplicated_filter ]
988
986
duplicated = values [duplicated_filter ]
989
- if method == 'nsmallest' :
990
- indexer = indexer .append (non_duplicated .index )
991
- else :
992
- indexer = non_duplicated .index .append (indexer )
987
+ non_duplicated = values [~ duplicated_filter ]
988
+ indexer = get_indexer (indexer , non_duplicated .index )
993
989
994
990
# Must set cur frame to include all duplicated values to consider for
995
991
# the next column, we also can reduce cur_n by the current length of
0 commit comments