Skip to content

Commit cfe5709

Browse files
committed
CLN: clean up select_n algos
1 parent 4502e82 commit cfe5709

File tree

3 files changed

+156
-142
lines changed

3 files changed

+156
-142
lines changed

pandas/core/algorithms.py

+146-136
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,12 @@ def quantile(x, q, interpolation_method='fraction'):
859859

860860
values = np.sort(x)
861861

862+
def _interpolate(a, b, fraction):
863+
"""Returns the point at the given fraction between a and b, where
864+
'fraction' must be between 0 and 1.
865+
"""
866+
return a + (b - a) * fraction
867+
862868
def _get_score(at):
863869
if len(values) == 0:
864870
return np.nan
@@ -887,179 +893,183 @@ def _get_score(at):
887893
return algos.arrmap_float64(q, _get_score)
888894

889895

890-
def _interpolate(a, b, fraction):
891-
"""Returns the point at the given fraction between a and b, where
892-
'fraction' must be between 0 and 1.
893-
"""
894-
return a + (b - a) * fraction
895-
896-
897-
def nsmallest(arr, n, keep='first'):
898-
"""
899-
Find the indices of the n smallest values of a numpy array.
900-
901-
Note: Fails silently with NaN.
902-
"""
903-
if keep == 'last':
904-
arr = arr[::-1]
905-
906-
narr = len(arr)
907-
n = min(n, narr)
908-
909-
arr = _ensure_data_view(arr)
910-
kth_val = algos.kth_smallest(arr.copy(), n - 1)
911-
return _finalize_nsmallest(arr, kth_val, n, keep, narr)
896+
# --------------- #
897+
# select n #
898+
# --------------- #
912899

900+
class SelectN(object):
913901

914-
def nlargest(arr, n, keep='first'):
915-
"""
916-
Find the indices of the n largest values of a numpy array.
902+
def __init__(self, obj, n, keep):
903+
self.obj = obj
904+
self.n = n
905+
self.keep = keep
917906

918-
Note: Fails silently with NaN.
919-
"""
920-
arr = _ensure_data_view(arr)
921-
return nsmallest(-arr, n, keep=keep)
907+
if self.keep not in ('first', 'last'):
908+
raise ValueError('keep must be either "first", "last"')
922909

910+
def nlargest(self):
911+
return self.compute('nlargest')
923912

924-
def select_n_slow(dropped, n, keep, method):
925-
reverse_it = (keep == 'last' or method == 'nlargest')
926-
ascending = method == 'nsmallest'
927-
slc = np.s_[::-1] if reverse_it else np.s_[:]
928-
return dropped[slc].sort_values(ascending=ascending).head(n)
913+
def nsmallest(self):
914+
return self.compute('nsmallest')
929915

930-
931-
_select_methods = {'nsmallest': nsmallest, 'nlargest': nlargest}
916+
@staticmethod
917+
def is_valid_dtype_n_method(dtype):
918+
"""
919+
Helper function to determine if dtype is valid for
920+
nsmallest/nlargest methods
921+
"""
922+
return ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
923+
needs_i8_conversion(dtype))
932924

933925

934-
def _is_valid_dtype_n_method(dtype):
935-
"""
936-
Helper function to determine if dtype is valid for
937-
nsmallest/nlargest methods
926+
class SelectNSeries(SelectN):
938927
"""
939-
return ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
940-
needs_i8_conversion(dtype))
941-
942-
943-
def select_n_series(series, n, keep, method):
944-
"""Implement n largest/smallest for pandas Series
928+
Implement n largest/smallest for Series
945929
946930
Parameters
947931
----------
948-
series : pandas.Series object
932+
obj : pandas.Series object
949933
n : int
950934
keep : {'first', 'last'}, default 'first'
951-
method : str, {'nlargest', 'nsmallest'}
952935
953936
Returns
954937
-------
955938
nordered : Series
956939
"""
957-
dtype = series.dtype
958-
if not _is_valid_dtype_n_method(dtype):
959-
raise TypeError("Cannot use method '{method}' with "
960-
"dtype {dtype}".format(method=method, dtype=dtype))
961940

962-
if keep not in ('first', 'last'):
963-
raise ValueError('keep must be either "first", "last"')
941+
def compute(self, method):
942+
943+
n = self.n
944+
dtype = self.obj.dtype
945+
if not self.is_valid_dtype_n_method(dtype):
946+
raise TypeError("Cannot use method '{method}' with "
947+
"dtype {dtype}".format(method=method,
948+
dtype=dtype))
949+
950+
if n <= 0:
951+
return self.obj[[]]
964952

965-
if n <= 0:
966-
return series[[]]
953+
dropped = self.obj.dropna()
967954

968-
dropped = series.dropna()
955+
# slow method
956+
if n >= len(self.obj):
969957

970-
if n >= len(series):
971-
return select_n_slow(dropped, n, keep, method)
958+
reverse_it = (self.keep == 'last' or method == 'nlargest')
959+
ascending = method == 'nsmallest'
960+
slc = np.s_[::-1] if reverse_it else np.s_[:]
961+
return dropped[slc].sort_values(ascending=ascending).head(n)
972962

973-
inds = _select_methods[method](dropped.values, n, keep)
974-
return dropped.iloc[inds]
963+
# fast method
964+
arr = _ensure_data_view(dropped.values)
965+
if method == 'nlargest':
966+
arr = -arr
975967

968+
if self.keep == 'last':
969+
arr = arr[::-1]
976970

977-
def select_n_frame(frame, columns, n, method, keep):
978-
"""Implement n largest/smallest for pandas DataFrame
971+
narr = len(arr)
972+
n = min(n, narr)
973+
974+
kth_val = algos.kth_smallest(arr.copy(), n - 1)
975+
ns, = np.nonzero(arr <= kth_val)
976+
inds = ns[arr[ns].argsort(kind='mergesort')][:n]
977+
if self.keep == 'last':
978+
# reverse indices
979+
inds = narr - 1 - inds
980+
981+
return dropped.iloc[inds]
982+
983+
984+
class SelectNFrame(SelectN):
985+
"""
986+
Implement n largest/smallest for DataFrame
979987
980988
Parameters
981989
----------
982990
obj : pandas.DataFrame object
983-
columns : list or str
984991
n : int
985992
keep : {'first', 'last'}, default 'first'
986-
method : str, {'nlargest', 'nsmallest'}
993+
columns : list or str
987994
988995
Returns
989996
-------
990997
nordered : DataFrame
991998
"""
992-
from pandas import Int64Index
993-
if not is_list_like(columns):
994-
columns = [columns]
995-
columns = list(columns)
996-
for column in columns:
997-
dtype = frame[column].dtype
998-
if not _is_valid_dtype_n_method(dtype):
999-
raise TypeError((
1000-
"Column {column!r} has dtype {dtype}, cannot use method "
1001-
"{method!r} with this dtype"
1002-
).format(column=column, dtype=dtype, method=method))
1003-
1004-
def get_indexer(current_indexer, other_indexer):
1005-
"""Helper function to concat `current_indexer` and `other_indexer`
1006-
depending on `method`
1007-
"""
1008-
if method == 'nsmallest':
1009-
return current_indexer.append(other_indexer)
1010-
else:
1011-
return other_indexer.append(current_indexer)
1012-
1013-
# Below we save and reset the index in case index contains duplicates
1014-
original_index = frame.index
1015-
cur_frame = frame = frame.reset_index(drop=True)
1016-
cur_n = n
1017-
indexer = Int64Index([])
1018-
1019-
for i, column in enumerate(columns):
1020-
1021-
# For each column we apply method to cur_frame[column]. If it is the
1022-
# last column in columns, or if the values returned are unique in
1023-
# frame[column] we save this index and break
1024-
# Otherwise we must save the index of the non duplicated values
1025-
# and set the next cur_frame to cur_frame filtered on all duplcicated
1026-
# values (#GH15297)
1027-
series = cur_frame[column]
1028-
values = getattr(series, method)(cur_n, keep=keep)
1029-
is_last_column = len(columns) - 1 == i
1030-
if is_last_column or values.nunique() == series.isin(values).sum():
1031-
1032-
# Last column in columns or values are unique in series => values
1033-
# is all that matters
1034-
indexer = get_indexer(indexer, values.index)
1035-
break
1036-
1037-
duplicated_filter = series.duplicated(keep=False)
1038-
duplicated = values[duplicated_filter]
1039-
non_duplicated = values[~duplicated_filter]
1040-
indexer = get_indexer(indexer, non_duplicated.index)
1041-
1042-
# Must set cur frame to include all duplicated values to consider for
1043-
# the next column, we also can reduce cur_n by the current length of
1044-
# the indexer
1045-
cur_frame = cur_frame[series.isin(duplicated)]
1046-
cur_n = n - len(indexer)
1047-
1048-
frame = frame.take(indexer)
1049-
1050-
# Restore the index on frame
1051-
frame.index = original_index.take(indexer)
1052-
return frame
1053-
1054-
1055-
def _finalize_nsmallest(arr, kth_val, n, keep, narr):
1056-
ns, = np.nonzero(arr <= kth_val)
1057-
inds = ns[arr[ns].argsort(kind='mergesort')][:n]
1058-
if keep == 'last':
1059-
# reverse indices
1060-
return narr - 1 - inds
1061-
else:
1062-
return inds
999+
1000+
def __init__(self, obj, n, keep, columns):
1001+
super(SelectNFrame, self).__init__(obj, n, keep)
1002+
if not is_list_like(columns):
1003+
columns = [columns]
1004+
columns = list(columns)
1005+
self.columns = columns
1006+
1007+
def compute(self, method):
1008+
1009+
from pandas import Int64Index
1010+
n = self.n
1011+
frame = self.obj
1012+
columns = self.columns
1013+
1014+
for column in columns:
1015+
dtype = frame[column].dtype
1016+
if not self.is_valid_dtype_n_method(dtype):
1017+
raise TypeError((
1018+
"Column {column!r} has dtype {dtype}, cannot use method "
1019+
"{method!r} with this dtype"
1020+
).format(column=column, dtype=dtype, method=method))
1021+
1022+
def get_indexer(current_indexer, other_indexer):
1023+
"""Helper function to concat `current_indexer` and `other_indexer`
1024+
depending on `method`
1025+
"""
1026+
if method == 'nsmallest':
1027+
return current_indexer.append(other_indexer)
1028+
else:
1029+
return other_indexer.append(current_indexer)
1030+
1031+
# Below we save and reset the index in case index contains duplicates
1032+
original_index = frame.index
1033+
cur_frame = frame = frame.reset_index(drop=True)
1034+
cur_n = n
1035+
indexer = Int64Index([])
1036+
1037+
for i, column in enumerate(columns):
1038+
1039+
# For each column we apply method to cur_frame[column].
1040+
# If it is the last column in columns, or if the values
1041+
# returned are unique in frame[column] we save this index
1042+
# and break
1043+
# Otherwise we must save the index of the non duplicated values
1044+
# and set the next cur_frame to cur_frame filtered on all
1045+
# duplicated values (#GH15297)
1046+
series = cur_frame[column]
1047+
values = getattr(series, method)(cur_n, keep=self.keep)
1048+
is_last_column = len(columns) - 1 == i
1049+
if is_last_column or values.nunique() == series.isin(values).sum():
1050+
1051+
# Last column in columns or values are unique in
1052+
# series => values
1053+
# is all that matters
1054+
indexer = get_indexer(indexer, values.index)
1055+
break
1056+
1057+
duplicated_filter = series.duplicated(keep=False)
1058+
duplicated = values[duplicated_filter]
1059+
non_duplicated = values[~duplicated_filter]
1060+
indexer = get_indexer(indexer, non_duplicated.index)
1061+
1062+
# Must set cur frame to include all duplicated values
1063+
# to consider for the next column, we also can reduce
1064+
# cur_n by the current length of the indexer
1065+
cur_frame = cur_frame[series.isin(duplicated)]
1066+
cur_n = n - len(indexer)
1067+
1068+
frame = frame.take(indexer)
1069+
1070+
# Restore the index on frame
1071+
frame.index = original_index.take(indexer)
1072+
return frame
10631073

10641074

10651075
# ------- #

pandas/core/frame.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -3441,7 +3441,10 @@ def nlargest(self, n, columns, keep='first'):
34413441
1 10 b 2
34423442
2 8 d NaN
34433443
"""
3444-
return algorithms.select_n_frame(self, columns, n, 'nlargest', keep)
3444+
return algorithms.SelectNFrame(self,
3445+
n=n,
3446+
keep=keep,
3447+
columns=columns).nlargest()
34453448

34463449
def nsmallest(self, n, columns, keep='first'):
34473450
"""Get the rows of a DataFrame sorted by the `n` smallest
@@ -3475,7 +3478,10 @@ def nsmallest(self, n, columns, keep='first'):
34753478
0 1 a 1
34763479
2 8 d NaN
34773480
"""
3478-
return algorithms.select_n_frame(self, columns, n, 'nsmallest', keep)
3481+
return algorithms.SelectNFrame(self,
3482+
n=n,
3483+
keep=keep,
3484+
columns=columns).nsmallest()
34793485

34803486
def swaplevel(self, i=-2, j=-1, axis=0):
34813487
"""

pandas/core/series.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1856,8 +1856,7 @@ def nlargest(self, n=5, keep='first'):
18561856
121637 4.240952
18571857
dtype: float64
18581858
"""
1859-
return algorithms.select_n_series(self, n=n, keep=keep,
1860-
method='nlargest')
1859+
return algorithms.SelectNSeries(self, n=n, keep=keep).nlargest()
18611860

18621861
def nsmallest(self, n=5, keep='first'):
18631862
"""Return the smallest `n` elements.
@@ -1903,8 +1902,7 @@ def nsmallest(self, n=5, keep='first'):
19031902
359919 -4.331927
19041903
dtype: float64
19051904
"""
1906-
return algorithms.select_n_series(self, n=n, keep=keep,
1907-
method='nsmallest')
1905+
return algorithms.SelectNSeries(self, n=n, keep=keep).nsmallest()
19081906

19091907
def sortlevel(self, level=0, ascending=True, sort_remaining=True):
19101908
"""

0 commit comments

Comments
 (0)