Skip to content

REF: move SelectN from core.algorithms #51460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 0 additions & 226 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
from textwrap import dedent
from typing import (
TYPE_CHECKING,
Hashable,
Literal,
Sequence,
cast,
final,
)
import warnings

Expand All @@ -29,7 +26,6 @@
ArrayLike,
AxisInt,
DtypeObj,
IndexLabel,
TakeIndexer,
npt,
)
Expand Down Expand Up @@ -97,7 +93,6 @@

from pandas import (
Categorical,
DataFrame,
Index,
Series,
)
Expand Down Expand Up @@ -1167,227 +1162,6 @@ def checked_add_with_arr(
return result


# --------------- #
# select n #
# --------------- #


class SelectN:
def __init__(self, obj, n: int, keep: str) -> None:
self.obj = obj
self.n = n
self.keep = keep

if self.keep not in ("first", "last", "all"):
raise ValueError('keep must be either "first", "last" or "all"')

def compute(self, method: str) -> DataFrame | Series:
raise NotImplementedError

@final
def nlargest(self):
return self.compute("nlargest")

@final
def nsmallest(self):
return self.compute("nsmallest")

@final
@staticmethod
def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
"""
Helper function to determine if dtype is valid for
nsmallest/nlargest methods
"""
return (
not is_complex_dtype(dtype)
if is_numeric_dtype(dtype)
else needs_i8_conversion(dtype)
)


class SelectNSeries(SelectN):
"""
Implement n largest/smallest for Series

Parameters
----------
obj : Series
n : int
keep : {'first', 'last'}, default 'first'

Returns
-------
nordered : Series
"""

def compute(self, method: str) -> Series:
from pandas.core.reshape.concat import concat

n = self.n
dtype = self.obj.dtype
if not self.is_valid_dtype_n_method(dtype):
raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")

if n <= 0:
return self.obj[[]]

dropped = self.obj.dropna()
nan_index = self.obj.drop(dropped.index)

# slow method
if n >= len(self.obj):
ascending = method == "nsmallest"
return self.obj.sort_values(ascending=ascending).head(n)

# fast method
new_dtype = dropped.dtype
arr = _ensure_data(dropped.values)
if method == "nlargest":
arr = -arr
if is_integer_dtype(new_dtype):
# GH 21426: ensure reverse ordering at boundaries
arr -= 1

elif is_bool_dtype(new_dtype):
# GH 26154: ensure False is smaller than True
arr = 1 - (-arr)

if self.keep == "last":
arr = arr[::-1]

nbase = n
narr = len(arr)
n = min(n, narr)

# arr passed into kth_smallest must be contiguous. We copy
# here because kth_smallest will modify its input
kth_val = algos.kth_smallest(arr.copy(order="C"), n - 1)
(ns,) = np.nonzero(arr <= kth_val)
inds = ns[arr[ns].argsort(kind="mergesort")]

if self.keep != "all":
inds = inds[:n]
findex = nbase
else:
if len(inds) < nbase <= len(nan_index) + len(inds):
findex = len(nan_index) + len(inds)
else:
findex = len(inds)

if self.keep == "last":
# reverse indices
inds = narr - 1 - inds

return concat([dropped.iloc[inds], nan_index]).iloc[:findex]


class SelectNFrame(SelectN):
"""
Implement n largest/smallest for DataFrame

Parameters
----------
obj : DataFrame
n : int
keep : {'first', 'last'}, default 'first'
columns : list or str

Returns
-------
nordered : DataFrame
"""

def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None:
super().__init__(obj, n, keep)
if not is_list_like(columns) or isinstance(columns, tuple):
columns = [columns]

columns = cast(Sequence[Hashable], columns)
columns = list(columns)
self.columns = columns

def compute(self, method: str) -> DataFrame:
from pandas.core.api import Index

n = self.n
frame = self.obj
columns = self.columns

for column in columns:
dtype = frame[column].dtype
if not self.is_valid_dtype_n_method(dtype):
raise TypeError(
f"Column {repr(column)} has dtype {dtype}, "
f"cannot use method {repr(method)} with this dtype"
)

def get_indexer(current_indexer, other_indexer):
"""
Helper function to concat `current_indexer` and `other_indexer`
depending on `method`
"""
if method == "nsmallest":
return current_indexer.append(other_indexer)
else:
return other_indexer.append(current_indexer)

# Below we save and reset the index in case index contains duplicates
original_index = frame.index
cur_frame = frame = frame.reset_index(drop=True)
cur_n = n
indexer = Index([], dtype=np.int64)

for i, column in enumerate(columns):
# For each column we apply method to cur_frame[column].
# If it's the last column or if we have the number of
# results desired we are done.
# Otherwise there are duplicates of the largest/smallest
# value and we need to look at the rest of the columns
# to determine which of the rows with the largest/smallest
# value in the column to keep.
series = cur_frame[column]
is_last_column = len(columns) - 1 == i
values = getattr(series, method)(
cur_n, keep=self.keep if is_last_column else "all"
)

if is_last_column or len(values) <= cur_n:
indexer = get_indexer(indexer, values.index)
break

# Now find all values which are equal to
# the (nsmallest: largest)/(nlargest: smallest)
# from our series.
border_value = values == values[values.index[-1]]

# Some of these values are among the top-n
# some aren't.
unsafe_values = values[border_value]

# These values are definitely among the top-n
safe_values = values[~border_value]
indexer = get_indexer(indexer, safe_values.index)

# Go on and separate the unsafe_values on the remaining
# columns.
cur_frame = cur_frame.loc[unsafe_values.index]
cur_n = n - len(indexer)

frame = frame.take(indexer)

# Restore the index on frame
frame.index = original_index.take(indexer)

# If there is only one column, the frame is already sorted.
if len(columns) == 1:
return frame

ascending = method == "nsmallest"

return frame.sort_values(columns, ascending=ascending, kind="mergesort")


# ---- #
# take #
# ---- #
Expand Down
7 changes: 3 additions & 4 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
to_arrays,
treat_as_nested,
)
from pandas.core.methods import selectn
from pandas.core.reshape.melt import melt
from pandas.core.series import Series
from pandas.core.shared_docs import _shared_docs
Expand Down Expand Up @@ -7178,7 +7179,7 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram
Italy 59000000 1937894 IT
Brunei 434000 12128 BN
"""
return algorithms.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()

def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
"""
Expand Down Expand Up @@ -7276,9 +7277,7 @@ def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFra
Anguilla 11300 311 AI
Nauru 337000 182 NR
"""
return algorithms.SelectNFrame(
self, n=n, keep=keep, columns=columns
).nsmallest()
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nsmallest()

@doc(
Series.swaplevel,
Expand Down
Loading