|
8 | 8 | from textwrap import dedent
|
9 | 9 | from typing import (
|
10 | 10 | TYPE_CHECKING,
|
11 |
| - Hashable, |
12 | 11 | Literal,
|
13 |
| - Sequence, |
14 | 12 | cast,
|
15 |
| - final, |
16 | 13 | )
|
17 | 14 | import warnings
|
18 | 15 |
|
|
29 | 26 | ArrayLike,
|
30 | 27 | AxisInt,
|
31 | 28 | DtypeObj,
|
32 |
| - IndexLabel, |
33 | 29 | TakeIndexer,
|
34 | 30 | npt,
|
35 | 31 | )
|
|
97 | 93 |
|
98 | 94 | from pandas import (
|
99 | 95 | Categorical,
|
100 |
| - DataFrame, |
101 | 96 | Index,
|
102 | 97 | Series,
|
103 | 98 | )
|
@@ -1167,227 +1162,6 @@ def checked_add_with_arr(
|
1167 | 1162 | return result
|
1168 | 1163 |
|
1169 | 1164 |
|
1170 |
| -# --------------- # |
1171 |
| -# select n # |
1172 |
| -# --------------- # |
1173 |
| - |
1174 |
| - |
1175 |
| -class SelectN: |
1176 |
| - def __init__(self, obj, n: int, keep: str) -> None: |
1177 |
| - self.obj = obj |
1178 |
| - self.n = n |
1179 |
| - self.keep = keep |
1180 |
| - |
1181 |
| - if self.keep not in ("first", "last", "all"): |
1182 |
| - raise ValueError('keep must be either "first", "last" or "all"') |
1183 |
| - |
1184 |
| - def compute(self, method: str) -> DataFrame | Series: |
1185 |
| - raise NotImplementedError |
1186 |
| - |
1187 |
| - @final |
1188 |
| - def nlargest(self): |
1189 |
| - return self.compute("nlargest") |
1190 |
| - |
1191 |
| - @final |
1192 |
| - def nsmallest(self): |
1193 |
| - return self.compute("nsmallest") |
1194 |
| - |
1195 |
| - @final |
1196 |
| - @staticmethod |
1197 |
| - def is_valid_dtype_n_method(dtype: DtypeObj) -> bool: |
1198 |
| - """ |
1199 |
| - Helper function to determine if dtype is valid for |
1200 |
| - nsmallest/nlargest methods |
1201 |
| - """ |
1202 |
| - return ( |
1203 |
| - not is_complex_dtype(dtype) |
1204 |
| - if is_numeric_dtype(dtype) |
1205 |
| - else needs_i8_conversion(dtype) |
1206 |
| - ) |
1207 |
| - |
1208 |
| - |
1209 |
| -class SelectNSeries(SelectN): |
1210 |
| - """ |
1211 |
| - Implement n largest/smallest for Series |
1212 |
| -
|
1213 |
| - Parameters |
1214 |
| - ---------- |
1215 |
| - obj : Series |
1216 |
| - n : int |
1217 |
| - keep : {'first', 'last'}, default 'first' |
1218 |
| -
|
1219 |
| - Returns |
1220 |
| - ------- |
1221 |
| - nordered : Series |
1222 |
| - """ |
1223 |
| - |
1224 |
| - def compute(self, method: str) -> Series: |
1225 |
| - from pandas.core.reshape.concat import concat |
1226 |
| - |
1227 |
| - n = self.n |
1228 |
| - dtype = self.obj.dtype |
1229 |
| - if not self.is_valid_dtype_n_method(dtype): |
1230 |
| - raise TypeError(f"Cannot use method '{method}' with dtype {dtype}") |
1231 |
| - |
1232 |
| - if n <= 0: |
1233 |
| - return self.obj[[]] |
1234 |
| - |
1235 |
| - dropped = self.obj.dropna() |
1236 |
| - nan_index = self.obj.drop(dropped.index) |
1237 |
| - |
1238 |
| - # slow method |
1239 |
| - if n >= len(self.obj): |
1240 |
| - ascending = method == "nsmallest" |
1241 |
| - return self.obj.sort_values(ascending=ascending).head(n) |
1242 |
| - |
1243 |
| - # fast method |
1244 |
| - new_dtype = dropped.dtype |
1245 |
| - arr = _ensure_data(dropped.values) |
1246 |
| - if method == "nlargest": |
1247 |
| - arr = -arr |
1248 |
| - if is_integer_dtype(new_dtype): |
1249 |
| - # GH 21426: ensure reverse ordering at boundaries |
1250 |
| - arr -= 1 |
1251 |
| - |
1252 |
| - elif is_bool_dtype(new_dtype): |
1253 |
| - # GH 26154: ensure False is smaller than True |
1254 |
| - arr = 1 - (-arr) |
1255 |
| - |
1256 |
| - if self.keep == "last": |
1257 |
| - arr = arr[::-1] |
1258 |
| - |
1259 |
| - nbase = n |
1260 |
| - narr = len(arr) |
1261 |
| - n = min(n, narr) |
1262 |
| - |
1263 |
| - # arr passed into kth_smallest must be contiguous. We copy |
1264 |
| - # here because kth_smallest will modify its input |
1265 |
| - kth_val = algos.kth_smallest(arr.copy(order="C"), n - 1) |
1266 |
| - (ns,) = np.nonzero(arr <= kth_val) |
1267 |
| - inds = ns[arr[ns].argsort(kind="mergesort")] |
1268 |
| - |
1269 |
| - if self.keep != "all": |
1270 |
| - inds = inds[:n] |
1271 |
| - findex = nbase |
1272 |
| - else: |
1273 |
| - if len(inds) < nbase <= len(nan_index) + len(inds): |
1274 |
| - findex = len(nan_index) + len(inds) |
1275 |
| - else: |
1276 |
| - findex = len(inds) |
1277 |
| - |
1278 |
| - if self.keep == "last": |
1279 |
| - # reverse indices |
1280 |
| - inds = narr - 1 - inds |
1281 |
| - |
1282 |
| - return concat([dropped.iloc[inds], nan_index]).iloc[:findex] |
1283 |
| - |
1284 |
| - |
1285 |
| -class SelectNFrame(SelectN): |
1286 |
| - """ |
1287 |
| - Implement n largest/smallest for DataFrame |
1288 |
| -
|
1289 |
| - Parameters |
1290 |
| - ---------- |
1291 |
| - obj : DataFrame |
1292 |
| - n : int |
1293 |
| - keep : {'first', 'last'}, default 'first' |
1294 |
| - columns : list or str |
1295 |
| -
|
1296 |
| - Returns |
1297 |
| - ------- |
1298 |
| - nordered : DataFrame |
1299 |
| - """ |
1300 |
| - |
1301 |
| - def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None: |
1302 |
| - super().__init__(obj, n, keep) |
1303 |
| - if not is_list_like(columns) or isinstance(columns, tuple): |
1304 |
| - columns = [columns] |
1305 |
| - |
1306 |
| - columns = cast(Sequence[Hashable], columns) |
1307 |
| - columns = list(columns) |
1308 |
| - self.columns = columns |
1309 |
| - |
1310 |
| - def compute(self, method: str) -> DataFrame: |
1311 |
| - from pandas.core.api import Index |
1312 |
| - |
1313 |
| - n = self.n |
1314 |
| - frame = self.obj |
1315 |
| - columns = self.columns |
1316 |
| - |
1317 |
| - for column in columns: |
1318 |
| - dtype = frame[column].dtype |
1319 |
| - if not self.is_valid_dtype_n_method(dtype): |
1320 |
| - raise TypeError( |
1321 |
| - f"Column {repr(column)} has dtype {dtype}, " |
1322 |
| - f"cannot use method {repr(method)} with this dtype" |
1323 |
| - ) |
1324 |
| - |
1325 |
| - def get_indexer(current_indexer, other_indexer): |
1326 |
| - """ |
1327 |
| - Helper function to concat `current_indexer` and `other_indexer` |
1328 |
| - depending on `method` |
1329 |
| - """ |
1330 |
| - if method == "nsmallest": |
1331 |
| - return current_indexer.append(other_indexer) |
1332 |
| - else: |
1333 |
| - return other_indexer.append(current_indexer) |
1334 |
| - |
1335 |
| - # Below we save and reset the index in case index contains duplicates |
1336 |
| - original_index = frame.index |
1337 |
| - cur_frame = frame = frame.reset_index(drop=True) |
1338 |
| - cur_n = n |
1339 |
| - indexer = Index([], dtype=np.int64) |
1340 |
| - |
1341 |
| - for i, column in enumerate(columns): |
1342 |
| - # For each column we apply method to cur_frame[column]. |
1343 |
| - # If it's the last column or if we have the number of |
1344 |
| - # results desired we are done. |
1345 |
| - # Otherwise there are duplicates of the largest/smallest |
1346 |
| - # value and we need to look at the rest of the columns |
1347 |
| - # to determine which of the rows with the largest/smallest |
1348 |
| - # value in the column to keep. |
1349 |
| - series = cur_frame[column] |
1350 |
| - is_last_column = len(columns) - 1 == i |
1351 |
| - values = getattr(series, method)( |
1352 |
| - cur_n, keep=self.keep if is_last_column else "all" |
1353 |
| - ) |
1354 |
| - |
1355 |
| - if is_last_column or len(values) <= cur_n: |
1356 |
| - indexer = get_indexer(indexer, values.index) |
1357 |
| - break |
1358 |
| - |
1359 |
| - # Now find all values which are equal to |
1360 |
| - # the (nsmallest: largest)/(nlargest: smallest) |
1361 |
| - # from our series. |
1362 |
| - border_value = values == values[values.index[-1]] |
1363 |
| - |
1364 |
| - # Some of these values are among the top-n |
1365 |
| - # some aren't. |
1366 |
| - unsafe_values = values[border_value] |
1367 |
| - |
1368 |
| - # These values are definitely among the top-n |
1369 |
| - safe_values = values[~border_value] |
1370 |
| - indexer = get_indexer(indexer, safe_values.index) |
1371 |
| - |
1372 |
| - # Go on and separate the unsafe_values on the remaining |
1373 |
| - # columns. |
1374 |
| - cur_frame = cur_frame.loc[unsafe_values.index] |
1375 |
| - cur_n = n - len(indexer) |
1376 |
| - |
1377 |
| - frame = frame.take(indexer) |
1378 |
| - |
1379 |
| - # Restore the index on frame |
1380 |
| - frame.index = original_index.take(indexer) |
1381 |
| - |
1382 |
| - # If there is only one column, the frame is already sorted. |
1383 |
| - if len(columns) == 1: |
1384 |
| - return frame |
1385 |
| - |
1386 |
| - ascending = method == "nsmallest" |
1387 |
| - |
1388 |
| - return frame.sort_values(columns, ascending=ascending, kind="mergesort") |
1389 |
| - |
1390 |
| - |
1391 | 1165 | # ---- #
|
1392 | 1166 | # take #
|
1393 | 1167 | # ---- #
|
|
0 commit comments