@@ -944,46 +944,63 @@ def select_n_frame(frame, columns, n, method, keep):
944
944
-------
945
945
nordered : DataFrame
946
946
"""
947
+ from pandas import Int64Index
947
948
if not is_list_like (columns ):
948
949
columns = [columns ]
949
950
columns = list (columns )
950
- ascending = method == 'nsmallest'
951
-
952
- # Below we save and reset the index
953
- # in case index contains duplicates
951
+ for column in columns :
952
+ dtype = frame [column ].dtype
953
+ if not issubclass (dtype .type , (np .integer , np .floating , np .datetime64 ,
954
+ np .timedelta64 )):
955
+ msg = (
956
+ "{column!r} has dtype: {dtype}, cannot use method {method!r} "
957
+ "with this dtype"
958
+ ).format (column = column , dtype = dtype , method = method )
959
+ raise TypeError (msg )
960
+
961
+ # Below we save and reset the index in case index contains duplicates
954
962
original_index = frame .index
955
- frame = frame .reset_index (drop = True )
956
-
963
+ cur_frame = frame = frame .reset_index (drop = True )
964
+ cur_n = n
965
+ indexer = Int64Index ([])
957
966
for i , column in enumerate (columns ):
958
967
959
- # For each column in columns we peform ``method`` on this frame
960
- # To guard against the possibility column has duplicate values that
961
- # must be considered for futher columns (# GH15297) we filter using
962
- # frame[isin] on the values returned by ``method``. If there are no
963
- # duplicated values, we simply take the values returned by
964
- # ``method``, otherwise we sort the isin filtered frame and continue
965
- series = frame [column ]
966
- values = getattr (series , method )(n , keep = keep )
967
- indexer = values .index
968
- if i + 1 == len (columns ):
969
-
970
- # This is the last column => duplicates here don't matter
971
- frame = frame .take (indexer )
968
+ # For each column we apply method to cur_frame[column]. If it is the
969
+ # last column in columns, or if the values returned are unique in
970
+ # frame[column] we save this index and break
971
+ # Otherwise we must save the index of the non duplicated values
972
+ # and set the next cur_frame to cur_frame filtered on all duplcicated
973
+ # values (#GH15297)
974
+ series = cur_frame [column ]
975
+ values = getattr (series , method )(cur_n , keep = keep )
976
+ is_last_column = len (columns ) - 1 == i
977
+ if is_last_column or len (values .unique ()) == sum (series .isin (values )):
978
+
979
+ # Last column in columns or values are unique in series => values
980
+ # is all that matters
981
+ if method == 'nsmallest' :
982
+ indexer = indexer .append (values .index )
983
+ else :
984
+ indexer = values .index .append (indexer )
985
+ break
986
+ duplicated_filter = series .duplicated (keep = False )
987
+ non_duplicated = values [~ duplicated_filter ]
988
+ duplicated = values [duplicated_filter ]
989
+ if method == 'nsmallest' :
990
+ indexer = indexer .append (non_duplicated .index )
972
991
else :
973
- filtered_frame = frame [series .isin (values )]
974
- if len (filtered_frame ) == len (values ):
992
+ indexer = non_duplicated .index .append (indexer )
975
993
976
- # Values are unique in series => take and break
977
- frame = frame .take (indexer )
978
- break
994
+ # Must set cur frame to include all duplicated values to consider for
995
+ # the next column, we also can reduce cur_n by the current length of
996
+ # the indexer
997
+ cur_frame = cur_frame [series .isin (duplicated )]
998
+ cur_n = n - len (indexer )
979
999
980
- # Values are not unique in series => sort and continue
981
- frame = filtered_frame .sort_values (
982
- column , ascending = ascending
983
- )
1000
+ frame = frame .take (indexer )
984
1001
985
- # Below we set the index of the returning frame to the original index
986
- frame .index = original_index .take (frame . index )
1002
+ # Restore the index on frame
1003
+ frame .index = original_index .take (indexer )
987
1004
return frame
988
1005
989
1006
0 commit comments