Skip to content

Commit 93631a9

Browse files
authored
REF: dispatch Series.rank to EA (#45037)
1 parent 6bc6366 commit 93631a9

File tree

4 files changed

+79
-17
lines changed

4 files changed

+79
-17
lines changed

pandas/core/algorithms.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,6 @@ def _get_hashtable_algo(values: np.ndarray):
286286

287287

288288
def _get_values_for_rank(values: ArrayLike) -> np.ndarray:
289-
if is_categorical_dtype(values):
290-
values = cast("Categorical", values)._values_for_rank()
291289

292290
values = _ensure_data(values)
293291
if values.dtype.kind in ["i", "u", "f"]:
@@ -992,13 +990,13 @@ def rank(
992990
na_option: str = "keep",
993991
ascending: bool = True,
994992
pct: bool = False,
995-
) -> np.ndarray:
993+
) -> npt.NDArray[np.float64]:
996994
"""
997995
Rank the values along a given axis.
998996
999997
Parameters
1000998
----------
1001-
values : array-like
999+
values : np.ndarray or ExtensionArray
10021000
Array whose values will be ranked. The number of dimensions in this
10031001
array must not exceed 2.
10041002
axis : int, default 0

pandas/core/arrays/base.py

+27
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from pandas.core.algorithms import (
7474
factorize_array,
7575
isin,
76+
rank,
7677
unique,
7778
)
7879
from pandas.core.array_algos.quantile import quantile_with_mask
@@ -1496,6 +1497,32 @@ def _fill_mask_inplace(
14961497
self[mask] = new_values[mask]
14971498
return
14981499

1500+
def _rank(
1501+
self,
1502+
*,
1503+
axis: int = 0,
1504+
method: str = "average",
1505+
na_option: str = "keep",
1506+
ascending: bool = True,
1507+
pct: bool = False,
1508+
):
1509+
"""
1510+
See Series.rank.__doc__.
1511+
"""
1512+
if axis != 0:
1513+
raise NotImplementedError
1514+
1515+
# TODO: we only have tests that get here with dt64 and td64
1516+
# TODO: all tests that get here use the defaults for all the kwds
1517+
return rank(
1518+
self,
1519+
axis=axis,
1520+
method=method,
1521+
na_option=na_option,
1522+
ascending=ascending,
1523+
pct=pct,
1524+
)
1525+
14991526
@classmethod
15001527
def _empty(cls, shape: Shape, dtype: ExtensionDtype):
15011528
"""

pandas/core/arrays/categorical.py

+24
Original file line numberDiff line numberDiff line change
@@ -1842,6 +1842,30 @@ def sort_values(
18421842
codes = self._codes[sorted_idx]
18431843
return self._from_backing_data(codes)
18441844

1845+
def _rank(
1846+
self,
1847+
*,
1848+
axis: int = 0,
1849+
method: str = "average",
1850+
na_option: str = "keep",
1851+
ascending: bool = True,
1852+
pct: bool = False,
1853+
):
1854+
"""
1855+
See Series.rank.__doc__.
1856+
"""
1857+
if axis != 0:
1858+
raise NotImplementedError
1859+
vff = self._values_for_rank()
1860+
return algorithms.rank(
1861+
vff,
1862+
axis=axis,
1863+
method=method,
1864+
na_option=na_option,
1865+
ascending=ascending,
1866+
pct=pct,
1867+
)
1868+
18451869
def _values_for_rank(self):
18461870
"""
18471871
For correctly ranking ordered categorical data. See GH#15420

pandas/core/generic.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -8513,19 +8513,32 @@ def rank(
85138513
raise ValueError(msg)
85148514

85158515
def ranker(data):
8516-
ranks = algos.rank(
8517-
data.values,
8518-
axis=axis,
8519-
method=method,
8520-
ascending=ascending,
8521-
na_option=na_option,
8522-
pct=pct,
8523-
)
8524-
# error: Argument 1 to "NDFrame" has incompatible type "ndarray"; expected
8525-
# "Union[ArrayManager, BlockManager]"
8526-
ranks_obj = self._constructor(
8527-
ranks, **data._construct_axes_dict() # type: ignore[arg-type]
8528-
)
8516+
if data.ndim == 2:
8517+
# i.e. DataFrame, we cast to ndarray
8518+
values = data.values
8519+
else:
8520+
# i.e. Series, can dispatch to EA
8521+
values = data._values
8522+
8523+
if isinstance(values, ExtensionArray):
8524+
ranks = values._rank(
8525+
axis=axis,
8526+
method=method,
8527+
ascending=ascending,
8528+
na_option=na_option,
8529+
pct=pct,
8530+
)
8531+
else:
8532+
ranks = algos.rank(
8533+
values,
8534+
axis=axis,
8535+
method=method,
8536+
ascending=ascending,
8537+
na_option=na_option,
8538+
pct=pct,
8539+
)
8540+
8541+
ranks_obj = self._constructor(ranks, **data._construct_axes_dict())
85298542
return ranks_obj.__finalize__(self, method="rank")
85308543

85318544
# if numeric_only is None, and we can't get anything, we try with

0 commit comments

Comments
 (0)