diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 0b29843735189..cd61e893ea84b 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -734,6 +734,7 @@ Performance improvements - Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`, :issue:`49577`) - Performance improvement in :meth:`MultiIndex.putmask` (:issue:`49830`) - Performance improvement in :meth:`Index.union` and :meth:`MultiIndex.union` when index contains duplicates (:issue:`48900`) +- Performance improvement in :meth:`Series.rank` for pyarrow-backed dtypes (:issue:`50264`) - Performance improvement in :meth:`Series.fillna` for extension array dtypes (:issue:`49722`, :issue:`50078`) - Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`) - Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 254ff8894b36c..1d3c31a129647 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -11,6 +11,7 @@ from pandas._typing import ( ArrayLike, + AxisInt, Dtype, FillnaOptions, Iterator, @@ -22,6 +23,7 @@ from pandas.compat import ( pa_version_under6p0, pa_version_under7p0, + pa_version_under9p0, ) from pandas.util._decorators import doc from pandas.util._validators import validate_fillna_kwargs @@ -949,7 +951,72 @@ def _indexing_key_to_indices( indices = np.arange(n)[key] return indices - # TODO: redefine _rank using pc.rank with pyarrow 9.0 + def _rank( + self, + *, + axis: AxisInt = 0, + method: str = "average", + na_option: str = "keep", + ascending: bool = True, + pct: bool = False, + ): + """ + See Series.rank.__doc__. + """ + if pa_version_under9p0 or axis != 0: + ranked = super()._rank( + axis=axis, + method=method, + na_option=na_option, + ascending=ascending, + pct=pct, + ) + # keep dtypes consistent with the implementation below + if method == "average" or pct: + pa_type = pa.float64() + else: + pa_type = pa.uint64() + result = pa.array(ranked, type=pa_type, from_pandas=True) + return type(self)(result) + + data = self._data.combine_chunks() + sort_keys = "ascending" if ascending else "descending" + null_placement = "at_start" if na_option == "top" else "at_end" + tiebreaker = "min" if method == "average" else method + + result = pc.rank( + data, + sort_keys=sort_keys, + null_placement=null_placement, + tiebreaker=tiebreaker, + ) + + if na_option == "keep": + mask = pc.is_null(self._data) + null = pa.scalar(None, type=result.type) + result = pc.if_else(mask, null, result) + + if method == "average": + result_max = pc.rank( + data, + sort_keys=sort_keys, + null_placement=null_placement, + tiebreaker="max", + ) + result_max = result_max.cast(pa.float64()) + result_min = result.cast(pa.float64()) + result = pc.divide(pc.add(result_min, result_max), 2) + + if pct: + if not pa.types.is_floating(result.type): + result = result.cast(pa.float64()) + if method == "dense": + divisor = pc.max(result) + else: + divisor = pc.count(result) + result = pc.divide(result, divisor) + + return type(self)(result) def _quantile( self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index c36728391ba21..6954c97007d23 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1576,8 +1576,6 @@ def _rank( if axis != 0: raise NotImplementedError - # TODO: we only have tests that get here with dt64 and td64 - # TODO: all tests that get here use the defaults for all the kwds return rank( self, axis=axis, diff --git a/pandas/tests/series/methods/test_rank.py b/pandas/tests/series/methods/test_rank.py index 3a66bf1adf25b..3183ba24bb88d 100644 --- a/pandas/tests/series/methods/test_rank.py +++ b/pandas/tests/series/methods/test_rank.py @@ -11,6 +11,7 @@ import pandas.util._test_decorators as td from pandas import ( + NA, NaT, Series, Timestamp, @@ -38,6 +39,21 @@ def results(request): return request.param +@pytest.fixture( + params=[ + "object", + "float64", + "int64", + "Float64", + "Int64", + pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")), + pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")), + ] +) +def dtype(request): + return request.param + + class TestSeriesRank: @td.skip_if_no_scipy def test_rank(self, datetime_series): @@ -238,13 +254,28 @@ def test_rank_tie_methods(self, ser, results, dtype): [ ("object", None, Infinity(), NegInfinity()), ("float64", np.nan, np.inf, -np.inf), + ("Float64", NA, np.inf, -np.inf), + pytest.param( + "float64[pyarrow]", + NA, + np.inf, + -np.inf, + marks=td.skip_if_no("pyarrow"), + ), ], ) def test_rank_tie_methods_on_infs_nans( self, method, na_option, ascending, dtype, na_value, pos_inf, neg_inf ): - chunk = 3 + if dtype == "float64[pyarrow]": + if method == "average": + exp_dtype = "float64[pyarrow]" + else: + exp_dtype = "uint64[pyarrow]" + else: + exp_dtype = "float64" + chunk = 3 in_arr = [neg_inf] * chunk + [na_value] * chunk + [pos_inf] * chunk iseries = Series(in_arr, dtype=dtype) exp_ranks = { @@ -264,7 +295,7 @@ def test_rank_tie_methods_on_infs_nans( expected = order if ascending else order[::-1] expected = list(chain.from_iterable(expected)) result = iseries.rank(method=method, na_option=na_option, ascending=ascending) - tm.assert_series_equal(result, Series(expected, dtype="float64")) + tm.assert_series_equal(result, Series(expected, dtype=exp_dtype)) def test_rank_desc_mix_nans_infs(self): # GH 19538 @@ -299,7 +330,6 @@ def test_rank_methods_series(self, method, op, value): expected = Series(sprank, index=index).astype("float64") tm.assert_series_equal(result, expected) - @pytest.mark.parametrize("dtype", ["O", "f8", "i8"]) @pytest.mark.parametrize( "ser, exp", [ @@ -319,7 +349,6 @@ def test_rank_dense_method(self, dtype, ser, exp): expected = Series(exp).astype(result.dtype) tm.assert_series_equal(result, expected) - @pytest.mark.parametrize("dtype", ["O", "f8", "i8"]) def test_rank_descending(self, ser, results, dtype): method, _ = results if "i" in dtype: @@ -365,7 +394,6 @@ def test_rank_modify_inplace(self): # GH15630, pct should be on 100% basis when method='dense' -@pytest.mark.parametrize("dtype", ["O", "f8", "i8"]) @pytest.mark.parametrize( "ser, exp", [ @@ -387,7 +415,6 @@ def test_rank_dense_pct(dtype, ser, exp): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("dtype", ["O", "f8", "i8"]) @pytest.mark.parametrize( "ser, exp", [ @@ -409,7 +436,6 @@ def test_rank_min_pct(dtype, ser, exp): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("dtype", ["O", "f8", "i8"]) @pytest.mark.parametrize( "ser, exp", [ @@ -431,7 +457,6 @@ def test_rank_max_pct(dtype, ser, exp): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("dtype", ["O", "f8", "i8"]) @pytest.mark.parametrize( "ser, exp", [ @@ -453,7 +478,6 @@ def test_rank_average_pct(dtype, ser, exp): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("dtype", ["f8", "i8"]) @pytest.mark.parametrize( "ser, exp", [