Skip to content

Commit 2070bb8

Browse files
authored
REF: move SelectN from core.algorithms (#51460)
1 parent 41d937d commit 2070bb8

File tree

4 files changed

+268
-232
lines changed

4 files changed

+268
-232
lines changed

pandas/core/algorithms.py

-226
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88
from textwrap import dedent
99
from typing import (
1010
TYPE_CHECKING,
11-
Hashable,
1211
Literal,
13-
Sequence,
1412
cast,
15-
final,
1613
)
1714
import warnings
1815

@@ -29,7 +26,6 @@
2926
ArrayLike,
3027
AxisInt,
3128
DtypeObj,
32-
IndexLabel,
3329
TakeIndexer,
3430
npt,
3531
)
@@ -97,7 +93,6 @@
9793

9894
from pandas import (
9995
Categorical,
100-
DataFrame,
10196
Index,
10297
Series,
10398
)
@@ -1167,227 +1162,6 @@ def checked_add_with_arr(
11671162
return result
11681163

11691164

1170-
# --------------- #
1171-
# select n #
1172-
# --------------- #
1173-
1174-
1175-
class SelectN:
1176-
def __init__(self, obj, n: int, keep: str) -> None:
1177-
self.obj = obj
1178-
self.n = n
1179-
self.keep = keep
1180-
1181-
if self.keep not in ("first", "last", "all"):
1182-
raise ValueError('keep must be either "first", "last" or "all"')
1183-
1184-
def compute(self, method: str) -> DataFrame | Series:
1185-
raise NotImplementedError
1186-
1187-
@final
1188-
def nlargest(self):
1189-
return self.compute("nlargest")
1190-
1191-
@final
1192-
def nsmallest(self):
1193-
return self.compute("nsmallest")
1194-
1195-
@final
1196-
@staticmethod
1197-
def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
1198-
"""
1199-
Helper function to determine if dtype is valid for
1200-
nsmallest/nlargest methods
1201-
"""
1202-
return (
1203-
not is_complex_dtype(dtype)
1204-
if is_numeric_dtype(dtype)
1205-
else needs_i8_conversion(dtype)
1206-
)
1207-
1208-
1209-
class SelectNSeries(SelectN):
1210-
"""
1211-
Implement n largest/smallest for Series
1212-
1213-
Parameters
1214-
----------
1215-
obj : Series
1216-
n : int
1217-
keep : {'first', 'last'}, default 'first'
1218-
1219-
Returns
1220-
-------
1221-
nordered : Series
1222-
"""
1223-
1224-
def compute(self, method: str) -> Series:
1225-
from pandas.core.reshape.concat import concat
1226-
1227-
n = self.n
1228-
dtype = self.obj.dtype
1229-
if not self.is_valid_dtype_n_method(dtype):
1230-
raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")
1231-
1232-
if n <= 0:
1233-
return self.obj[[]]
1234-
1235-
dropped = self.obj.dropna()
1236-
nan_index = self.obj.drop(dropped.index)
1237-
1238-
# slow method
1239-
if n >= len(self.obj):
1240-
ascending = method == "nsmallest"
1241-
return self.obj.sort_values(ascending=ascending).head(n)
1242-
1243-
# fast method
1244-
new_dtype = dropped.dtype
1245-
arr = _ensure_data(dropped.values)
1246-
if method == "nlargest":
1247-
arr = -arr
1248-
if is_integer_dtype(new_dtype):
1249-
# GH 21426: ensure reverse ordering at boundaries
1250-
arr -= 1
1251-
1252-
elif is_bool_dtype(new_dtype):
1253-
# GH 26154: ensure False is smaller than True
1254-
arr = 1 - (-arr)
1255-
1256-
if self.keep == "last":
1257-
arr = arr[::-1]
1258-
1259-
nbase = n
1260-
narr = len(arr)
1261-
n = min(n, narr)
1262-
1263-
# arr passed into kth_smallest must be contiguous. We copy
1264-
# here because kth_smallest will modify its input
1265-
kth_val = algos.kth_smallest(arr.copy(order="C"), n - 1)
1266-
(ns,) = np.nonzero(arr <= kth_val)
1267-
inds = ns[arr[ns].argsort(kind="mergesort")]
1268-
1269-
if self.keep != "all":
1270-
inds = inds[:n]
1271-
findex = nbase
1272-
else:
1273-
if len(inds) < nbase <= len(nan_index) + len(inds):
1274-
findex = len(nan_index) + len(inds)
1275-
else:
1276-
findex = len(inds)
1277-
1278-
if self.keep == "last":
1279-
# reverse indices
1280-
inds = narr - 1 - inds
1281-
1282-
return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
1283-
1284-
1285-
class SelectNFrame(SelectN):
1286-
"""
1287-
Implement n largest/smallest for DataFrame
1288-
1289-
Parameters
1290-
----------
1291-
obj : DataFrame
1292-
n : int
1293-
keep : {'first', 'last'}, default 'first'
1294-
columns : list or str
1295-
1296-
Returns
1297-
-------
1298-
nordered : DataFrame
1299-
"""
1300-
1301-
def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None:
1302-
super().__init__(obj, n, keep)
1303-
if not is_list_like(columns) or isinstance(columns, tuple):
1304-
columns = [columns]
1305-
1306-
columns = cast(Sequence[Hashable], columns)
1307-
columns = list(columns)
1308-
self.columns = columns
1309-
1310-
def compute(self, method: str) -> DataFrame:
1311-
from pandas.core.api import Index
1312-
1313-
n = self.n
1314-
frame = self.obj
1315-
columns = self.columns
1316-
1317-
for column in columns:
1318-
dtype = frame[column].dtype
1319-
if not self.is_valid_dtype_n_method(dtype):
1320-
raise TypeError(
1321-
f"Column {repr(column)} has dtype {dtype}, "
1322-
f"cannot use method {repr(method)} with this dtype"
1323-
)
1324-
1325-
def get_indexer(current_indexer, other_indexer):
1326-
"""
1327-
Helper function to concat `current_indexer` and `other_indexer`
1328-
depending on `method`
1329-
"""
1330-
if method == "nsmallest":
1331-
return current_indexer.append(other_indexer)
1332-
else:
1333-
return other_indexer.append(current_indexer)
1334-
1335-
# Below we save and reset the index in case index contains duplicates
1336-
original_index = frame.index
1337-
cur_frame = frame = frame.reset_index(drop=True)
1338-
cur_n = n
1339-
indexer = Index([], dtype=np.int64)
1340-
1341-
for i, column in enumerate(columns):
1342-
# For each column we apply method to cur_frame[column].
1343-
# If it's the last column or if we have the number of
1344-
# results desired we are done.
1345-
# Otherwise there are duplicates of the largest/smallest
1346-
# value and we need to look at the rest of the columns
1347-
# to determine which of the rows with the largest/smallest
1348-
# value in the column to keep.
1349-
series = cur_frame[column]
1350-
is_last_column = len(columns) - 1 == i
1351-
values = getattr(series, method)(
1352-
cur_n, keep=self.keep if is_last_column else "all"
1353-
)
1354-
1355-
if is_last_column or len(values) <= cur_n:
1356-
indexer = get_indexer(indexer, values.index)
1357-
break
1358-
1359-
# Now find all values which are equal to
1360-
# the (nsmallest: largest)/(nlargest: smallest)
1361-
# from our series.
1362-
border_value = values == values[values.index[-1]]
1363-
1364-
# Some of these values are among the top-n
1365-
# some aren't.
1366-
unsafe_values = values[border_value]
1367-
1368-
# These values are definitely among the top-n
1369-
safe_values = values[~border_value]
1370-
indexer = get_indexer(indexer, safe_values.index)
1371-
1372-
# Go on and separate the unsafe_values on the remaining
1373-
# columns.
1374-
cur_frame = cur_frame.loc[unsafe_values.index]
1375-
cur_n = n - len(indexer)
1376-
1377-
frame = frame.take(indexer)
1378-
1379-
# Restore the index on frame
1380-
frame.index = original_index.take(indexer)
1381-
1382-
# If there is only one column, the frame is already sorted.
1383-
if len(columns) == 1:
1384-
return frame
1385-
1386-
ascending = method == "nsmallest"
1387-
1388-
return frame.sort_values(columns, ascending=ascending, kind="mergesort")
1389-
1390-
13911165
# ---- #
13921166
# take #
13931167
# ---- #

pandas/core/frame.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@
213213
to_arrays,
214214
treat_as_nested,
215215
)
216+
from pandas.core.methods import selectn
216217
from pandas.core.reshape.melt import melt
217218
from pandas.core.series import Series
218219
from pandas.core.shared_docs import _shared_docs
@@ -7175,7 +7176,7 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram
71757176
Italy 59000000 1937894 IT
71767177
Brunei 434000 12128 BN
71777178
"""
7178-
return algorithms.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()
7179+
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()
71797180

71807181
def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
71817182
"""
@@ -7273,9 +7274,7 @@ def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFra
72737274
Anguilla 11300 311 AI
72747275
Nauru 337000 182 NR
72757276
"""
7276-
return algorithms.SelectNFrame(
7277-
self, n=n, keep=keep, columns=columns
7278-
).nsmallest()
7277+
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nsmallest()
72797278

72807279
@doc(
72817280
Series.swaplevel,

0 commit comments

Comments
 (0)