Skip to content

Commit 6a83910

Browse files
authored
BUG: rank raising for arrow string dtypes (#55362)
1 parent 6b84daa commit 6a83910

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

doc/source/whatsnew/v2.1.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Bug fixes
2929
- Fixed bug in :meth:`DataFrame.resample` not respecting ``closed`` and ``label`` arguments for :class:`~pandas.tseries.offsets.BusinessDay` (:issue:`55282`)
3030
- Fixed bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.BusinessDay` (:issue:`55281`)
3131
- Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`)
32+
- Fixed bug in :meth:`Series.rank` raising for ``string[pyarrow_numpy]`` dtype (:issue:`55362`)
3233
- Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`)
3334
-
3435

pandas/core/arrays/arrow/array.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1747,7 +1747,7 @@ def __setitem__(self, key, value) -> None:
17471747
data = pa.chunked_array([data])
17481748
self._pa_array = data
17491749

1750-
def _rank(
1750+
def _rank_calc(
17511751
self,
17521752
*,
17531753
axis: AxisInt = 0,
@@ -1756,9 +1756,6 @@ def _rank(
17561756
ascending: bool = True,
17571757
pct: bool = False,
17581758
):
1759-
"""
1760-
See Series.rank.__doc__.
1761-
"""
17621759
if pa_version_under9p0 or axis != 0:
17631760
ranked = super()._rank(
17641761
axis=axis,
@@ -1773,7 +1770,7 @@ def _rank(
17731770
else:
17741771
pa_type = pa.uint64()
17751772
result = pa.array(ranked, type=pa_type, from_pandas=True)
1776-
return type(self)(result)
1773+
return result
17771774

17781775
data = self._pa_array.combine_chunks()
17791776
sort_keys = "ascending" if ascending else "descending"
@@ -1812,7 +1809,29 @@ def _rank(
18121809
divisor = pc.count(result)
18131810
result = pc.divide(result, divisor)
18141811

1815-
return type(self)(result)
1812+
return result
1813+
1814+
def _rank(
1815+
self,
1816+
*,
1817+
axis: AxisInt = 0,
1818+
method: str = "average",
1819+
na_option: str = "keep",
1820+
ascending: bool = True,
1821+
pct: bool = False,
1822+
):
1823+
"""
1824+
See Series.rank.__doc__.
1825+
"""
1826+
return type(self)(
1827+
self._rank_calc(
1828+
axis=axis,
1829+
method=method,
1830+
na_option=na_option,
1831+
ascending=ascending,
1832+
pct=pct,
1833+
)
1834+
)
18161835

18171836
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
18181837
"""

pandas/core/arrays/string_arrow.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from collections.abc import Sequence
5454

5555
from pandas._typing import (
56+
AxisInt,
5657
Dtype,
5758
Scalar,
5859
npt,
@@ -501,6 +502,28 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
501502
def _convert_int_dtype(self, result):
502503
return Int64Dtype().__from_arrow__(result)
503504

505+
def _rank(
506+
self,
507+
*,
508+
axis: AxisInt = 0,
509+
method: str = "average",
510+
na_option: str = "keep",
511+
ascending: bool = True,
512+
pct: bool = False,
513+
):
514+
"""
515+
See Series.rank.__doc__.
516+
"""
517+
return self._convert_int_dtype(
518+
self._rank_calc(
519+
axis=axis,
520+
method=method,
521+
na_option=na_option,
522+
ascending=ascending,
523+
pct=pct,
524+
)
525+
)
526+
504527

505528
class ArrowStringArrayNumpySemantics(ArrowStringArray):
506529
_storage = "pyarrow_numpy"
@@ -584,7 +607,10 @@ def _str_map(
584607
return lib.map_infer_mask(arr, f, mask.view("uint8"))
585608

586609
def _convert_int_dtype(self, result):
587-
result = result.to_numpy()
610+
if isinstance(result, pa.Array):
611+
result = result.to_numpy(zero_copy_only=False)
612+
else:
613+
result = result.to_numpy()
588614
if result.dtype == np.int32:
589615
result = result.astype(np.int64)
590616
return result

pandas/tests/frame/methods/test_rank.py

+12
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,15 @@ def test_rank_mixed_axis_zero(self, data, expected):
488488
df.rank()
489489
result = df.rank(numeric_only=True)
490490
tm.assert_frame_equal(result, expected)
491+
492+
@pytest.mark.parametrize(
493+
"dtype, exp_dtype",
494+
[("string[pyarrow]", "Int64"), ("string[pyarrow_numpy]", "float64")],
495+
)
496+
def test_rank_string_dtype(self, dtype, exp_dtype):
497+
# GH#55362
498+
pytest.importorskip("pyarrow")
499+
obj = Series(["foo", "foo", None, "foo"], dtype=dtype)
500+
result = obj.rank(method="first")
501+
expected = Series([1, 2, None, 3], dtype=exp_dtype)
502+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)