Skip to content

Commit 2c49f55

Browse files
BUG/API (string dtype): return float dtype for series[str].rank() (#59768)
* BUG/API (string dtype): return float dtype for series[str].rank() * update frame tests * add whatsnew * correct whatsnew note
1 parent 5927bd8 commit 2c49f55

File tree

5 files changed

+76
-36
lines changed

5 files changed

+76
-36
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Conversion
102102

103103
Strings
104104
^^^^^^^
105+
- Bug in :meth:`Series.rank` for :class:`StringDtype` with ``storage="pyarrow"`` incorrectly returning integer results in case of ``method="average"`` and raising an error if it would truncate results (:issue:`59768`)
105106
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`StringDtype` with ``storage="pyarrow"`` (:issue:`59628`)
106107
- Bug in ``ser.str.slice`` with negative ``step`` with :class:`ArrowDtype` and :class:`StringDtype` with ``storage="pyarrow"`` giving incorrect results (:issue:`59710`)
107108
- Bug in the ``center`` method on :class:`Series` and :class:`Index` object ``str`` accessors with pyarrow-backed dtype not matching the python behavior in corner cases with an odd number of fill characters (:issue:`54792`)

pandas/core/arrays/arrow/array.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1999,7 +1999,7 @@ def _rank(
19991999
"""
20002000
See Series.rank.__doc__.
20012001
"""
2002-
return self._convert_int_result(
2002+
return self._convert_rank_result(
20032003
self._rank_calc(
20042004
axis=axis,
20052005
method=method,
@@ -2318,6 +2318,9 @@ def _convert_bool_result(self, result):
23182318
def _convert_int_result(self, result):
23192319
return type(self)(result)
23202320

2321+
def _convert_rank_result(self, result):
2322+
return type(self)(result)
2323+
23212324
def _str_count(self, pat: str, flags: int = 0) -> Self:
23222325
if flags:
23232326
raise NotImplementedError(f"count not implemented with {flags=}")

pandas/core/arrays/string_arrow.py

+11
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
3030
from pandas.core.arrays.arrow import ArrowExtensionArray
3131
from pandas.core.arrays.boolean import BooleanDtype
32+
from pandas.core.arrays.floating import Float64Dtype
3233
from pandas.core.arrays.integer import Int64Dtype
3334
from pandas.core.arrays.numeric import NumericDtype
3435
from pandas.core.arrays.string_ import (
@@ -395,6 +396,16 @@ def _convert_int_result(self, result):
395396

396397
return Int64Dtype().__from_arrow__(result)
397398

399+
def _convert_rank_result(self, result):
400+
if self.dtype.na_value is np.nan:
401+
if isinstance(result, pa.Array):
402+
result = result.to_numpy(zero_copy_only=False)
403+
else:
404+
result = result.to_numpy()
405+
return result.astype("float64", copy=False)
406+
407+
return Float64Dtype().__from_arrow__(result)
408+
398409
def _reduce(
399410
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
400411
):

pandas/tests/frame/methods/test_rank.py

+4-19
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@
66
import numpy as np
77
import pytest
88

9-
from pandas._config import using_string_dtype
10-
119
from pandas._libs.algos import (
1210
Infinity,
1311
NegInfinity,
1412
)
15-
from pandas.compat import HAS_PYARROW
1613

17-
import pandas as pd
1814
from pandas import (
1915
DataFrame,
2016
Index,
@@ -467,23 +463,10 @@ def test_rank_inf_nans_na_option(
467463
("top", False, [2.0, 3.0, 1.0, 4.0]),
468464
],
469465
)
470-
def test_rank_object_first(
471-
self,
472-
request,
473-
frame_or_series,
474-
na_option,
475-
ascending,
476-
expected,
477-
using_infer_string,
478-
):
466+
def test_rank_object_first(self, frame_or_series, na_option, ascending, expected):
479467
obj = frame_or_series(["foo", "foo", None, "foo"])
480-
if using_string_dtype() and not HAS_PYARROW and isinstance(obj, Series):
481-
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
482-
483468
result = obj.rank(method="first", na_option=na_option, ascending=ascending)
484469
expected = frame_or_series(expected)
485-
if using_infer_string and isinstance(obj, Series):
486-
expected = expected.astype("uint64")
487470
tm.assert_equal(result, expected)
488471

489472
@pytest.mark.parametrize(
@@ -507,7 +490,9 @@ def test_rank_string_dtype(self, string_dtype_no_object):
507490
# GH#55362
508491
obj = Series(["foo", "foo", None, "foo"], dtype=string_dtype_no_object)
509492
result = obj.rank(method="first")
510-
exp_dtype = "Int64" if string_dtype_no_object.na_value is pd.NA else "float64"
493+
exp_dtype = (
494+
"Float64" if string_dtype_no_object == "string[pyarrow]" else "float64"
495+
)
511496
if string_dtype_no_object.storage == "python":
512497
# TODO nullable string[python] should also return nullable Int64
513498
exp_dtype = "float64"

pandas/tests/series/methods/test_rank.py

+56-16
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def ser():
3333
["max", np.array([2, 6, 7, 4, np.nan, 4, 2, 8, np.nan, 6])],
3434
["first", np.array([1, 5, 7, 3, np.nan, 4, 2, 8, np.nan, 6])],
3535
["dense", np.array([1, 3, 4, 2, np.nan, 2, 1, 5, np.nan, 3])],
36-
]
36+
],
37+
ids=lambda x: x[0],
3738
)
3839
def results(request):
3940
return request.param
@@ -48,12 +49,29 @@ def results(request):
4849
"Int64",
4950
pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
5051
pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
52+
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
53+
"string[python]",
54+
"str",
5155
]
5256
)
5357
def dtype(request):
5458
return request.param
5559

5660

61+
def expected_dtype(dtype, method, pct=False):
62+
exp_dtype = "float64"
63+
# elif dtype in ["Int64", "Float64", "string[pyarrow]", "string[python]"]:
64+
if dtype in ["string[pyarrow]"]:
65+
exp_dtype = "Float64"
66+
elif dtype in ["float64[pyarrow]", "int64[pyarrow]"]:
67+
if method == "average" or pct:
68+
exp_dtype = "double[pyarrow]"
69+
else:
70+
exp_dtype = "uint64[pyarrow]"
71+
72+
return exp_dtype
73+
74+
5775
class TestSeriesRank:
5876
def test_rank(self, datetime_series):
5977
sp_stats = pytest.importorskip("scipy.stats")
@@ -251,12 +269,14 @@ def test_rank_signature(self):
251269
with pytest.raises(ValueError, match=msg):
252270
s.rank("average")
253271

254-
@pytest.mark.parametrize("dtype", [None, object])
255-
def test_rank_tie_methods(self, ser, results, dtype):
272+
def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
256273
method, exp = results
274+
if dtype == "int64" or (not using_infer_string and dtype == "str"):
275+
pytest.skip("int64/str does not support NaN")
276+
257277
ser = ser if dtype is None else ser.astype(dtype)
258278
result = ser.rank(method=method)
259-
tm.assert_series_equal(result, Series(exp))
279+
tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
260280

261281
@pytest.mark.parametrize("na_option", ["top", "bottom", "keep"])
262282
@pytest.mark.parametrize(
@@ -357,25 +377,35 @@ def test_rank_methods_series(self, rank_method, op, value):
357377
],
358378
)
359379
def test_rank_dense_method(self, dtype, ser, exp):
380+
if ser[0] < 0 and dtype.startswith("str"):
381+
exp = exp[::-1]
360382
s = Series(ser).astype(dtype)
361383
result = s.rank(method="dense")
362-
expected = Series(exp).astype(result.dtype)
384+
expected = Series(exp).astype(expected_dtype(dtype, "dense"))
363385
tm.assert_series_equal(result, expected)
364386

365-
def test_rank_descending(self, ser, results, dtype):
387+
def test_rank_descending(self, ser, results, dtype, using_infer_string):
366388
method, _ = results
367-
if "i" in dtype:
389+
if dtype == "int64" or (not using_infer_string and dtype == "str"):
368390
s = ser.dropna()
369391
else:
370392
s = ser.astype(dtype)
371393

372394
res = s.rank(ascending=False)
373-
expected = (s.max() - s).rank()
374-
tm.assert_series_equal(res, expected)
395+
if dtype.startswith("str"):
396+
expected = (s.astype("float64").max() - s.astype("float64")).rank()
397+
else:
398+
expected = (s.max() - s).rank()
399+
tm.assert_series_equal(res, expected.astype(expected_dtype(dtype, "average")))
375400

376-
expected = (s.max() - s).rank(method=method)
401+
if dtype.startswith("str"):
402+
expected = (s.astype("float64").max() - s.astype("float64")).rank(
403+
method=method
404+
)
405+
else:
406+
expected = (s.max() - s).rank(method=method)
377407
res2 = s.rank(method=method, ascending=False)
378-
tm.assert_series_equal(res2, expected)
408+
tm.assert_series_equal(res2, expected.astype(expected_dtype(dtype, method)))
379409

380410
def test_rank_int(self, ser, results):
381411
method, exp = results
@@ -432,9 +462,11 @@ def test_rank_ea_small_values(self):
432462
],
433463
)
434464
def test_rank_dense_pct(dtype, ser, exp):
465+
if ser[0] < 0 and dtype.startswith("str"):
466+
exp = exp[::-1]
435467
s = Series(ser).astype(dtype)
436468
result = s.rank(method="dense", pct=True)
437-
expected = Series(exp).astype(result.dtype)
469+
expected = Series(exp).astype(expected_dtype(dtype, "dense", pct=True))
438470
tm.assert_series_equal(result, expected)
439471

440472

@@ -453,9 +485,11 @@ def test_rank_dense_pct(dtype, ser, exp):
453485
],
454486
)
455487
def test_rank_min_pct(dtype, ser, exp):
488+
if ser[0] < 0 and dtype.startswith("str"):
489+
exp = exp[::-1]
456490
s = Series(ser).astype(dtype)
457491
result = s.rank(method="min", pct=True)
458-
expected = Series(exp).astype(result.dtype)
492+
expected = Series(exp).astype(expected_dtype(dtype, "min", pct=True))
459493
tm.assert_series_equal(result, expected)
460494

461495

@@ -474,9 +508,11 @@ def test_rank_min_pct(dtype, ser, exp):
474508
],
475509
)
476510
def test_rank_max_pct(dtype, ser, exp):
511+
if ser[0] < 0 and dtype.startswith("str"):
512+
exp = exp[::-1]
477513
s = Series(ser).astype(dtype)
478514
result = s.rank(method="max", pct=True)
479-
expected = Series(exp).astype(result.dtype)
515+
expected = Series(exp).astype(expected_dtype(dtype, "max", pct=True))
480516
tm.assert_series_equal(result, expected)
481517

482518

@@ -495,9 +531,11 @@ def test_rank_max_pct(dtype, ser, exp):
495531
],
496532
)
497533
def test_rank_average_pct(dtype, ser, exp):
534+
if ser[0] < 0 and dtype.startswith("str"):
535+
exp = exp[::-1]
498536
s = Series(ser).astype(dtype)
499537
result = s.rank(method="average", pct=True)
500-
expected = Series(exp).astype(result.dtype)
538+
expected = Series(exp).astype(expected_dtype(dtype, "average", pct=True))
501539
tm.assert_series_equal(result, expected)
502540

503541

@@ -516,9 +554,11 @@ def test_rank_average_pct(dtype, ser, exp):
516554
],
517555
)
518556
def test_rank_first_pct(dtype, ser, exp):
557+
if ser[0] < 0 and dtype.startswith("str"):
558+
exp = exp[::-1]
519559
s = Series(ser).astype(dtype)
520560
result = s.rank(method="first", pct=True)
521-
expected = Series(exp).astype(result.dtype)
561+
expected = Series(exp).astype(expected_dtype(dtype, "first", pct=True))
522562
tm.assert_series_equal(result, expected)
523563

524564

0 commit comments

Comments
 (0)