@@ -859,6 +859,12 @@ def quantile(x, q, interpolation_method='fraction'):
859
859
860
860
values = np .sort (x )
861
861
862
+ def _interpolate (a , b , fraction ):
863
+ """Returns the point at the given fraction between a and b, where
864
+ 'fraction' must be between 0 and 1.
865
+ """
866
+ return a + (b - a ) * fraction
867
+
862
868
def _get_score (at ):
863
869
if len (values ) == 0 :
864
870
return np .nan
@@ -887,179 +893,183 @@ def _get_score(at):
887
893
return algos .arrmap_float64 (q , _get_score )
888
894
889
895
890
- def _interpolate (a , b , fraction ):
891
- """Returns the point at the given fraction between a and b, where
892
- 'fraction' must be between 0 and 1.
893
- """
894
- return a + (b - a ) * fraction
895
-
896
-
897
- def nsmallest (arr , n , keep = 'first' ):
898
- """
899
- Find the indices of the n smallest values of a numpy array.
900
-
901
- Note: Fails silently with NaN.
902
- """
903
- if keep == 'last' :
904
- arr = arr [::- 1 ]
905
-
906
- narr = len (arr )
907
- n = min (n , narr )
908
-
909
- arr = _ensure_data_view (arr )
910
- kth_val = algos .kth_smallest (arr .copy (), n - 1 )
911
- return _finalize_nsmallest (arr , kth_val , n , keep , narr )
896
+ # --------------- #
897
+ # select n #
898
+ # --------------- #
912
899
900
class SelectN(object):
    """
    Base class for the n largest/smallest selection helpers.

    Stores the object to select from together with ``n`` and ``keep``,
    and exposes ``nlargest`` / ``nsmallest``, both of which delegate to
    ``compute``.  Subclasses must override ``compute``.

    Parameters
    ----------
    obj : object to select from
    n : int
    keep : {'first', 'last'}

    Raises
    ------
    ValueError
        If `keep` is neither 'first' nor 'last'.
    """

    def __init__(self, obj, n, keep):
        self.obj = obj
        self.n = n
        self.keep = keep

        if self.keep not in ('first', 'last'):
            raise ValueError('keep must be either "first", "last"')

    def compute(self, method):
        """
        Perform the selection; must be overridden by subclasses.

        Parameters
        ----------
        method : str, {'nlargest', 'nsmallest'}
        """
        # Defined here so that using the base class directly fails with a
        # clear message instead of an AttributeError on ``self.compute``.
        raise NotImplementedError(
            "{name} must implement compute".format(
                name=type(self).__name__))

    def nlargest(self):
        """Return the result of ``compute('nlargest')``."""
        return self.compute('nlargest')

    def nsmallest(self):
        """Return the result of ``compute('nsmallest')``."""
        return self.compute('nsmallest')

    @staticmethod
    def is_valid_dtype_n_method(dtype):
        """
        Helper function to determine if dtype is valid for
        nsmallest/nlargest methods

        A dtype is accepted when it is numeric but not complex, or when
        it needs i8 conversion.
        """
        return ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
                needs_i8_conversion(dtype))
932
924
933
925
934
- def _is_valid_dtype_n_method (dtype ):
935
- """
936
- Helper function to determine if dtype is valid for
937
- nsmallest/nlargest methods
926
class SelectNSeries(SelectN):
    """
    Implement n largest/smallest for Series

    Parameters
    ----------
    obj : pandas.Series object
    n : int
    keep : {'first', 'last'}, default 'first'

    Returns
    -------
    nordered : Series
    """

    def compute(self, method):
        """
        Select the `n` largest or smallest non-missing values.

        Parameters
        ----------
        method : str, {'nlargest', 'nsmallest'}

        Returns
        -------
        nordered : Series

        Raises
        ------
        TypeError
            If the dtype of the Series is rejected by
            ``is_valid_dtype_n_method``.
        """
        n = self.n
        dtype = self.obj.dtype
        if not self.is_valid_dtype_n_method(dtype):
            raise TypeError("Cannot use method '{method}' with "
                            "dtype {dtype}".format(method=method,
                                                   dtype=dtype))

        if n <= 0:
            return self.obj[[]]

        dropped = self.obj.dropna()

        # slow method: n covers the whole Series, so sort what is left
        # after dropping missing values and take the head
        if n >= len(self.obj):
            reverse_it = (self.keep == 'last' or method == 'nlargest')
            ascending = method == 'nsmallest'
            slc = np.s_[::-1] if reverse_it else np.s_[:]
            return dropped[slc].sort_values(ascending=ascending).head(n)

        # fast method
        arr = _ensure_data_view(dropped.values)
        narr = len(arr)
        if narr == 0:
            # Every value was missing.  Without this guard n would be
            # clamped to 0 below and kth_smallest would be asked for
            # position -1 of an empty array.
            return dropped

        if method == 'nlargest':
            # Reverse the ordering so the largest values become the
            # smallest.  Plain negation wraps around for unsigned
            # integers and for the minimum signed integer; bitwise NOT
            # (-x - 1) reverses integer order without overflow.
            if arr.dtype.kind in ('i', 'u'):
                arr = ~arr
            else:
                arr = -arr

        if self.keep == 'last':
            arr = arr[::-1]

        n = min(n, narr)

        # a copy is handed over -- presumably kth_smallest reorders its
        # input in place; verify against its implementation
        kth_val = algos.kth_smallest(arr.copy(), n - 1)
        ns, = np.nonzero(arr <= kth_val)
        # mergesort is stable, so ties keep their positional order
        inds = ns[arr[ns].argsort(kind='mergesort')][:n]
        if self.keep == 'last':
            # reverse indices
            inds = narr - 1 - inds

        return dropped.iloc[inds]
982
+
983
+
984
class SelectNFrame(SelectN):
    """
    Implement n largest/smallest for DataFrame

    Parameters
    ----------
    obj : pandas.DataFrame object
    n : int
    keep : {'first', 'last'}, default 'first'
    columns : list or str

    Returns
    -------
    nordered : DataFrame
    """

    def __init__(self, obj, n, keep, columns):
        super(SelectNFrame, self).__init__(obj, n, keep)
        # a single column label is wrapped so that ``columns`` is
        # always stored as a list
        if is_list_like(columns):
            self.columns = list(columns)
        else:
            self.columns = [columns]

    def compute(self, method):
        """
        Select rows by applying `method` to the columns in order, moving
        on to the next column only for values that are tied.

        Parameters
        ----------
        method : str, {'nlargest', 'nsmallest'}

        Returns
        -------
        nordered : DataFrame
        """
        from pandas import Int64Index

        total = self.n
        keep = self.keep
        frame = self.obj
        columns = self.columns

        # reject unsupported dtypes before doing any work
        for name in columns:
            col_dtype = frame[name].dtype
            if self.is_valid_dtype_n_method(col_dtype):
                continue
            raise TypeError((
                "Column {column!r} has dtype {dtype}, cannot use method "
                "{method!r} with this dtype"
            ).format(column=name, dtype=col_dtype, method=method))

        smallest_first = method == 'nsmallest'

        def combine(collected, fresh):
            """Join two indexers; for 'nsmallest' the new labels go after
            the collected ones, otherwise before them
            """
            if smallest_first:
                return collected.append(fresh)
            return fresh.append(collected)

        # The index is saved and replaced by a default one in case it
        # contains duplicates; it is restored on the result below
        saved_index = frame.index
        frame = frame.reset_index(drop=True)
        remaining = frame
        wanted = total
        picked = Int64Index([])
        last_pos = len(columns) - 1

        for pos, name in enumerate(columns):

            # Apply method to the current column of the remaining rows.
            # On the last column, or when the returned values are unique
            # within the column, their index settles the result.
            # Otherwise keep the index of the non duplicated values and
            # narrow the remaining rows down to the duplicated values
            # (#GH15297)
            series = remaining[name]
            values = getattr(series, method)(wanted, keep=keep)
            if (pos == last_pos or
                    values.nunique() == series.isin(values).sum()):
                picked = combine(picked, values.index)
                break

            tied_mask = series.duplicated(keep=False)
            tied = values[tied_mask]
            untied = values[~tied_mask]
            picked = combine(picked, untied.index)

            # Only rows holding a tied value are considered for the next
            # column, and fewer rows are now needed
            remaining = remaining[series.isin(tied)]
            wanted = total - len(picked)

        result = frame.take(picked)

        # Restore the original index labels on the result
        result.index = saved_index.take(picked)
        return result
1063
1073
1064
1074
1065
1075
# ------- #
0 commit comments