Skip to content

Commit ea1e118

Browse files
BUG/API (string dtype): return float dtype for series[str].rank() (#59768)
* BUG/API (string dtype): return float dtype for series[str].rank() * update frame tests * add whatsnew * correct whatsnew note
1 parent 2f1caf5 commit ea1e118

File tree

5 files changed

+76
-36
lines changed

5 files changed

+76
-36
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Conversion
102102

103103
Strings
104104
^^^^^^^
105+
- Bug in :meth:`Series.rank` for :class:`StringDtype` with ``storage="pyarrow"`` incorrectly returning integer results in case of ``method="average"`` and raising an error if it would truncate results (:issue:`59768`)
105106
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`StringDtype` with ``storage="pyarrow"`` (:issue:`59628`)
106107
- Bug in ``ser.str.slice`` with negative ``step`` with :class:`ArrowDtype` and :class:`StringDtype` with ``storage="pyarrow"`` giving incorrect results (:issue:`59710`)
107108
- Bug in the ``center`` method on :class:`Series` and :class:`Index` object ``str`` accessors with pyarrow-backed dtype not matching the python behavior in corner cases with an odd number of fill characters (:issue:`54792`)

pandas/core/arrays/arrow/array.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1989,7 +1989,7 @@ def _rank(
19891989
"""
19901990
See Series.rank.__doc__.
19911991
"""
1992-
return self._convert_int_result(
1992+
return self._convert_rank_result(
19931993
self._rank_calc(
19941994
axis=axis,
19951995
method=method,
@@ -2291,6 +2291,9 @@ def _convert_bool_result(self, result):
22912291
def _convert_int_result(self, result):
22922292
return type(self)(result)
22932293

2294+
def _convert_rank_result(self, result):
2295+
return type(self)(result)
2296+
22942297
def _str_count(self, pat: str, flags: int = 0):
22952298
if flags:
22962299
raise NotImplementedError(f"count not implemented with {flags=}")

pandas/core/arrays/string_arrow.py

+11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
3131
from pandas.core.arrays.arrow import ArrowExtensionArray
3232
from pandas.core.arrays.boolean import BooleanDtype
33+
from pandas.core.arrays.floating import Float64Dtype
3334
from pandas.core.arrays.integer import Int64Dtype
3435
from pandas.core.arrays.numeric import NumericDtype
3536
from pandas.core.arrays.string_ import (
@@ -388,6 +389,16 @@ def _convert_int_result(self, result):
388389

389390
return Int64Dtype().__from_arrow__(result)
390391

392+
def _convert_rank_result(self, result):
393+
if self.dtype.na_value is np.nan:
394+
if isinstance(result, pa.Array):
395+
result = result.to_numpy(zero_copy_only=False)
396+
else:
397+
result = result.to_numpy()
398+
return result.astype("float64", copy=False)
399+
400+
return Float64Dtype().__from_arrow__(result)
401+
391402
def _reduce(
392403
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
393404
):

pandas/tests/frame/methods/test_rank.py

+4-19
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@
66
import numpy as np
77
import pytest
88

9-
from pandas._config import using_string_dtype
10-
119
from pandas._libs.algos import (
1210
Infinity,
1311
NegInfinity,
1412
)
15-
from pandas.compat import HAS_PYARROW
1613

17-
import pandas as pd
1814
from pandas import (
1915
DataFrame,
2016
Index,
@@ -474,23 +470,10 @@ def test_rank_inf_nans_na_option(
474470
("top", False, [2.0, 3.0, 1.0, 4.0]),
475471
],
476472
)
477-
def test_rank_object_first(
478-
self,
479-
request,
480-
frame_or_series,
481-
na_option,
482-
ascending,
483-
expected,
484-
using_infer_string,
485-
):
473+
def test_rank_object_first(self, frame_or_series, na_option, ascending, expected):
486474
obj = frame_or_series(["foo", "foo", None, "foo"])
487-
if using_string_dtype() and not HAS_PYARROW and isinstance(obj, Series):
488-
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
489-
490475
result = obj.rank(method="first", na_option=na_option, ascending=ascending)
491476
expected = frame_or_series(expected)
492-
if using_infer_string and isinstance(obj, Series):
493-
expected = expected.astype("uint64")
494477
tm.assert_equal(result, expected)
495478

496479
@pytest.mark.parametrize(
@@ -514,7 +497,9 @@ def test_rank_string_dtype(self, string_dtype_no_object):
514497
# GH#55362
515498
obj = Series(["foo", "foo", None, "foo"], dtype=string_dtype_no_object)
516499
result = obj.rank(method="first")
517-
exp_dtype = "Int64" if string_dtype_no_object.na_value is pd.NA else "float64"
500+
exp_dtype = (
501+
"Float64" if string_dtype_no_object == "string[pyarrow]" else "float64"
502+
)
518503
if string_dtype_no_object.storage == "python":
519504
# TODO nullable string[python] should also return nullable Int64
520505
exp_dtype = "float64"

pandas/tests/series/methods/test_rank.py

+56-16
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def ser():
3333
["max", np.array([2, 6, 7, 4, np.nan, 4, 2, 8, np.nan, 6])],
3434
["first", np.array([1, 5, 7, 3, np.nan, 4, 2, 8, np.nan, 6])],
3535
["dense", np.array([1, 3, 4, 2, np.nan, 2, 1, 5, np.nan, 3])],
36-
]
36+
],
37+
ids=lambda x: x[0],
3738
)
3839
def results(request):
3940
return request.param
@@ -48,12 +49,29 @@ def results(request):
4849
"Int64",
4950
pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
5051
pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
52+
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
53+
"string[python]",
54+
"str",
5155
]
5256
)
5357
def dtype(request):
5458
return request.param
5559

5660

61+
def expected_dtype(dtype, method, pct=False):
62+
exp_dtype = "float64"
63+
# elif dtype in ["Int64", "Float64", "string[pyarrow]", "string[python]"]:
64+
if dtype in ["string[pyarrow]"]:
65+
exp_dtype = "Float64"
66+
elif dtype in ["float64[pyarrow]", "int64[pyarrow]"]:
67+
if method == "average" or pct:
68+
exp_dtype = "double[pyarrow]"
69+
else:
70+
exp_dtype = "uint64[pyarrow]"
71+
72+
return exp_dtype
73+
74+
5775
class TestSeriesRank:
5876
def test_rank(self, datetime_series):
5977
sp_stats = pytest.importorskip("scipy.stats")
@@ -241,12 +259,14 @@ def test_rank_signature(self):
241259
with pytest.raises(ValueError, match=msg):
242260
s.rank("average")
243261

244-
@pytest.mark.parametrize("dtype", [None, object])
245-
def test_rank_tie_methods(self, ser, results, dtype):
262+
def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
246263
method, exp = results
264+
if dtype == "int64" or (not using_infer_string and dtype == "str"):
265+
pytest.skip("int64/str does not support NaN")
266+
247267
ser = ser if dtype is None else ser.astype(dtype)
248268
result = ser.rank(method=method)
249-
tm.assert_series_equal(result, Series(exp))
269+
tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
250270

251271
@pytest.mark.parametrize("ascending", [True, False])
252272
@pytest.mark.parametrize("method", ["average", "min", "max", "first", "dense"])
@@ -346,25 +366,35 @@ def test_rank_methods_series(self, method, op, value):
346366
],
347367
)
348368
def test_rank_dense_method(self, dtype, ser, exp):
369+
if ser[0] < 0 and dtype.startswith("str"):
370+
exp = exp[::-1]
349371
s = Series(ser).astype(dtype)
350372
result = s.rank(method="dense")
351-
expected = Series(exp).astype(result.dtype)
373+
expected = Series(exp).astype(expected_dtype(dtype, "dense"))
352374
tm.assert_series_equal(result, expected)
353375

354-
def test_rank_descending(self, ser, results, dtype):
376+
def test_rank_descending(self, ser, results, dtype, using_infer_string):
355377
method, _ = results
356-
if "i" in dtype:
378+
if dtype == "int64" or (not using_infer_string and dtype == "str"):
357379
s = ser.dropna()
358380
else:
359381
s = ser.astype(dtype)
360382

361383
res = s.rank(ascending=False)
362-
expected = (s.max() - s).rank()
363-
tm.assert_series_equal(res, expected)
384+
if dtype.startswith("str"):
385+
expected = (s.astype("float64").max() - s.astype("float64")).rank()
386+
else:
387+
expected = (s.max() - s).rank()
388+
tm.assert_series_equal(res, expected.astype(expected_dtype(dtype, "average")))
364389

365-
expected = (s.max() - s).rank(method=method)
390+
if dtype.startswith("str"):
391+
expected = (s.astype("float64").max() - s.astype("float64")).rank(
392+
method=method
393+
)
394+
else:
395+
expected = (s.max() - s).rank(method=method)
366396
res2 = s.rank(method=method, ascending=False)
367-
tm.assert_series_equal(res2, expected)
397+
tm.assert_series_equal(res2, expected.astype(expected_dtype(dtype, method)))
368398

369399
def test_rank_int(self, ser, results):
370400
method, exp = results
@@ -421,9 +451,11 @@ def test_rank_ea_small_values(self):
421451
],
422452
)
423453
def test_rank_dense_pct(dtype, ser, exp):
454+
if ser[0] < 0 and dtype.startswith("str"):
455+
exp = exp[::-1]
424456
s = Series(ser).astype(dtype)
425457
result = s.rank(method="dense", pct=True)
426-
expected = Series(exp).astype(result.dtype)
458+
expected = Series(exp).astype(expected_dtype(dtype, "dense", pct=True))
427459
tm.assert_series_equal(result, expected)
428460

429461

@@ -442,9 +474,11 @@ def test_rank_dense_pct(dtype, ser, exp):
442474
],
443475
)
444476
def test_rank_min_pct(dtype, ser, exp):
477+
if ser[0] < 0 and dtype.startswith("str"):
478+
exp = exp[::-1]
445479
s = Series(ser).astype(dtype)
446480
result = s.rank(method="min", pct=True)
447-
expected = Series(exp).astype(result.dtype)
481+
expected = Series(exp).astype(expected_dtype(dtype, "min", pct=True))
448482
tm.assert_series_equal(result, expected)
449483

450484

@@ -463,9 +497,11 @@ def test_rank_min_pct(dtype, ser, exp):
463497
],
464498
)
465499
def test_rank_max_pct(dtype, ser, exp):
500+
if ser[0] < 0 and dtype.startswith("str"):
501+
exp = exp[::-1]
466502
s = Series(ser).astype(dtype)
467503
result = s.rank(method="max", pct=True)
468-
expected = Series(exp).astype(result.dtype)
504+
expected = Series(exp).astype(expected_dtype(dtype, "max", pct=True))
469505
tm.assert_series_equal(result, expected)
470506

471507

@@ -484,9 +520,11 @@ def test_rank_max_pct(dtype, ser, exp):
484520
],
485521
)
486522
def test_rank_average_pct(dtype, ser, exp):
523+
if ser[0] < 0 and dtype.startswith("str"):
524+
exp = exp[::-1]
487525
s = Series(ser).astype(dtype)
488526
result = s.rank(method="average", pct=True)
489-
expected = Series(exp).astype(result.dtype)
527+
expected = Series(exp).astype(expected_dtype(dtype, "average", pct=True))
490528
tm.assert_series_equal(result, expected)
491529

492530

@@ -505,9 +543,11 @@ def test_rank_average_pct(dtype, ser, exp):
505543
],
506544
)
507545
def test_rank_first_pct(dtype, ser, exp):
546+
if ser[0] < 0 and dtype.startswith("str"):
547+
exp = exp[::-1]
508548
s = Series(ser).astype(dtype)
509549
result = s.rank(method="first", pct=True)
510-
expected = Series(exp).astype(result.dtype)
550+
expected = Series(exp).astype(expected_dtype(dtype, "first", pct=True))
511551
tm.assert_series_equal(result, expected)
512552

513553

0 commit comments

Comments
 (0)