Skip to content

Commit 8117a55

Browse files
authored
PERF: Series(pyarrow-backed).rank (#50264)
* ArrowExtensionArray._rank * gh ref * skip pyarrow tests if not installed * defer to pc.rank output types * fix test * more consistency * use pyarrow for method="average" * fix call to super * call super with axis != 0
1 parent 075947f commit 8117a55

File tree

4 files changed

+102
-12
lines changed

4 files changed

+102
-12
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ Performance improvements
736736
- Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`, :issue:`49577`)
737737
- Performance improvement in :meth:`MultiIndex.putmask` (:issue:`49830`)
738738
- Performance improvement in :meth:`Index.union` and :meth:`MultiIndex.union` when index contains duplicates (:issue:`48900`)
739+
- Performance improvement in :meth:`Series.rank` for pyarrow-backed dtypes (:issue:`50264`)
739740
- Performance improvement in :meth:`Series.fillna` for extension array dtypes (:issue:`49722`, :issue:`50078`)
740741
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
741742
- Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`)

pandas/core/arrays/arrow/array.py

+68-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pandas._libs import lib
1313
from pandas._typing import (
1414
ArrayLike,
15+
AxisInt,
1516
Dtype,
1617
FillnaOptions,
1718
Iterator,
@@ -24,6 +25,7 @@
2425
from pandas.compat import (
2526
pa_version_under6p0,
2627
pa_version_under7p0,
28+
pa_version_under9p0,
2729
)
2830
from pandas.util._decorators import doc
2931
from pandas.util._validators import validate_fillna_kwargs
@@ -1006,7 +1008,72 @@ def _indexing_key_to_indices(
10061008
indices = np.arange(n)[key]
10071009
return indices
10081010

1009-
# TODO: redefine _rank using pc.rank with pyarrow 9.0
1011+
def _rank(
1012+
self,
1013+
*,
1014+
axis: AxisInt = 0,
1015+
method: str = "average",
1016+
na_option: str = "keep",
1017+
ascending: bool = True,
1018+
pct: bool = False,
1019+
):
1020+
"""
1021+
See Series.rank.__doc__.
1022+
"""
1023+
if pa_version_under9p0 or axis != 0:
1024+
ranked = super()._rank(
1025+
axis=axis,
1026+
method=method,
1027+
na_option=na_option,
1028+
ascending=ascending,
1029+
pct=pct,
1030+
)
1031+
# keep dtypes consistent with the implementation below
1032+
if method == "average" or pct:
1033+
pa_type = pa.float64()
1034+
else:
1035+
pa_type = pa.uint64()
1036+
result = pa.array(ranked, type=pa_type, from_pandas=True)
1037+
return type(self)(result)
1038+
1039+
data = self._data.combine_chunks()
1040+
sort_keys = "ascending" if ascending else "descending"
1041+
null_placement = "at_start" if na_option == "top" else "at_end"
1042+
tiebreaker = "min" if method == "average" else method
1043+
1044+
result = pc.rank(
1045+
data,
1046+
sort_keys=sort_keys,
1047+
null_placement=null_placement,
1048+
tiebreaker=tiebreaker,
1049+
)
1050+
1051+
if na_option == "keep":
1052+
mask = pc.is_null(self._data)
1053+
null = pa.scalar(None, type=result.type)
1054+
result = pc.if_else(mask, null, result)
1055+
1056+
if method == "average":
1057+
result_max = pc.rank(
1058+
data,
1059+
sort_keys=sort_keys,
1060+
null_placement=null_placement,
1061+
tiebreaker="max",
1062+
)
1063+
result_max = result_max.cast(pa.float64())
1064+
result_min = result.cast(pa.float64())
1065+
result = pc.divide(pc.add(result_min, result_max), 2)
1066+
1067+
if pct:
1068+
if not pa.types.is_floating(result.type):
1069+
result = result.cast(pa.float64())
1070+
if method == "dense":
1071+
divisor = pc.max(result)
1072+
else:
1073+
divisor = pc.count(result)
1074+
result = pc.divide(result, divisor)
1075+
1076+
return type(self)(result)
10101077

10111078
def _quantile(
10121079
self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str

pandas/core/arrays/base.py

-2
Original file line numberDiff line numberDiff line change
@@ -1576,8 +1576,6 @@ def _rank(
15761576
if axis != 0:
15771577
raise NotImplementedError
15781578

1579-
# TODO: we only have tests that get here with dt64 and td64
1580-
# TODO: all tests that get here use the defaults for all the kwds
15811579
return rank(
15821580
self,
15831581
axis=axis,

pandas/tests/series/methods/test_rank.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pandas.util._test_decorators as td
1212

1313
from pandas import (
14+
NA,
1415
NaT,
1516
Series,
1617
Timestamp,
@@ -38,6 +39,21 @@ def results(request):
3839
return request.param
3940

4041

42+
@pytest.fixture(
43+
params=[
44+
"object",
45+
"float64",
46+
"int64",
47+
"Float64",
48+
"Int64",
49+
pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
50+
pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
51+
]
52+
)
53+
def dtype(request):
54+
return request.param
55+
56+
4157
class TestSeriesRank:
4258
@td.skip_if_no_scipy
4359
def test_rank(self, datetime_series):
@@ -238,13 +254,28 @@ def test_rank_tie_methods(self, ser, results, dtype):
238254
[
239255
("object", None, Infinity(), NegInfinity()),
240256
("float64", np.nan, np.inf, -np.inf),
257+
("Float64", NA, np.inf, -np.inf),
258+
pytest.param(
259+
"float64[pyarrow]",
260+
NA,
261+
np.inf,
262+
-np.inf,
263+
marks=td.skip_if_no("pyarrow"),
264+
),
241265
],
242266
)
243267
def test_rank_tie_methods_on_infs_nans(
244268
self, method, na_option, ascending, dtype, na_value, pos_inf, neg_inf
245269
):
246-
chunk = 3
270+
if dtype == "float64[pyarrow]":
271+
if method == "average":
272+
exp_dtype = "float64[pyarrow]"
273+
else:
274+
exp_dtype = "uint64[pyarrow]"
275+
else:
276+
exp_dtype = "float64"
247277

278+
chunk = 3
248279
in_arr = [neg_inf] * chunk + [na_value] * chunk + [pos_inf] * chunk
249280
iseries = Series(in_arr, dtype=dtype)
250281
exp_ranks = {
@@ -264,7 +295,7 @@ def test_rank_tie_methods_on_infs_nans(
264295
expected = order if ascending else order[::-1]
265296
expected = list(chain.from_iterable(expected))
266297
result = iseries.rank(method=method, na_option=na_option, ascending=ascending)
267-
tm.assert_series_equal(result, Series(expected, dtype="float64"))
298+
tm.assert_series_equal(result, Series(expected, dtype=exp_dtype))
268299

269300
def test_rank_desc_mix_nans_infs(self):
270301
# GH 19538
@@ -299,7 +330,6 @@ def test_rank_methods_series(self, method, op, value):
299330
expected = Series(sprank, index=index).astype("float64")
300331
tm.assert_series_equal(result, expected)
301332

302-
@pytest.mark.parametrize("dtype", ["O", "f8", "i8"])
303333
@pytest.mark.parametrize(
304334
"ser, exp",
305335
[
@@ -319,7 +349,6 @@ def test_rank_dense_method(self, dtype, ser, exp):
319349
expected = Series(exp).astype(result.dtype)
320350
tm.assert_series_equal(result, expected)
321351

322-
@pytest.mark.parametrize("dtype", ["O", "f8", "i8"])
323352
def test_rank_descending(self, ser, results, dtype):
324353
method, _ = results
325354
if "i" in dtype:
@@ -365,7 +394,6 @@ def test_rank_modify_inplace(self):
365394
# GH15630, pct should be on 100% basis when method='dense'
366395

367396

368-
@pytest.mark.parametrize("dtype", ["O", "f8", "i8"])
369397
@pytest.mark.parametrize(
370398
"ser, exp",
371399
[
@@ -387,7 +415,6 @@ def test_rank_dense_pct(dtype, ser, exp):
387415
tm.assert_series_equal(result, expected)
388416

389417

390-
@pytest.mark.parametrize("dtype", ["O", "f8", "i8"])
391418
@pytest.mark.parametrize(
392419
"ser, exp",
393420
[
@@ -409,7 +436,6 @@ def test_rank_min_pct(dtype, ser, exp):
409436
tm.assert_series_equal(result, expected)
410437

411438

412-
@pytest.mark.parametrize("dtype", ["O", "f8", "i8"])
413439
@pytest.mark.parametrize(
414440
"ser, exp",
415441
[
@@ -431,7 +457,6 @@ def test_rank_max_pct(dtype, ser, exp):
431457
tm.assert_series_equal(result, expected)
432458

433459

434-
@pytest.mark.parametrize("dtype", ["O", "f8", "i8"])
435460
@pytest.mark.parametrize(
436461
"ser, exp",
437462
[
@@ -453,7 +478,6 @@ def test_rank_average_pct(dtype, ser, exp):
453478
tm.assert_series_equal(result, expected)
454479

455480

456-
@pytest.mark.parametrize("dtype", ["f8", "i8"])
457481
@pytest.mark.parametrize(
458482
"ser, exp",
459483
[

0 commit comments

Comments
 (0)