Skip to content

Commit 0c17c96

Browse files
authored
Backport PR #55384 on branch 2.1.x (BUG: idxmax raising for arrow strings) (#55531)
BUG: idxmax raising for arrow strings (#55384) (cherry picked from commit 68e3c4b)
1 parent 5933c60 commit 0c17c96

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

pandas/core/arrays/arrow/array.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,15 @@ def _reduce(
15961596
------
15971597
TypeError : subclass does not define reductions
15981598
"""
1599+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
1600+
if isinstance(result, pa.Array):
1601+
return type(self)(result)
1602+
else:
1603+
return result
1604+
1605+
def _reduce_calc(
1606+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1607+
):
15991608
pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
16001609

16011610
if keepdims:
@@ -1606,7 +1615,7 @@ def _reduce(
16061615
[pa_result],
16071616
type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
16081617
)
1609-
return type(self)(result)
1618+
return result
16101619

16111620
if pc.is_null(pa_result).as_py():
16121621
return self.dtype.na_value

pandas/core/arrays/string_arrow.py

+11
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,17 @@ def _str_rstrip(self, to_strip=None):
445445
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
446446
return type(self)(result)
447447

448+
def _reduce(
449+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
450+
):
451+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
452+
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
453+
return self._convert_int_dtype(result)
454+
elif isinstance(result, pa.Array):
455+
return type(self)(result)
456+
else:
457+
return result
458+
448459
def _convert_int_dtype(self, result):
449460
return Int64Dtype().__from_arrow__(result)
450461

pandas/tests/frame/test_reductions.py

+9
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,15 @@ def test_idxmax_arrow_types(self):
10731073
expected = Series([2, 1], index=["a", "b"])
10741074
tm.assert_series_equal(result, expected)
10751075

1076+
df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
1077+
result = df.idxmax(numeric_only=False)
1078+
expected = Series([1], index=["a"])
1079+
tm.assert_series_equal(result, expected)
1080+
1081+
result = df.idxmin(numeric_only=False)
1082+
expected = Series([2], index=["a"])
1083+
tm.assert_series_equal(result, expected)
1084+
10761085
def test_idxmax_axis_2(self, float_frame):
10771086
frame = float_frame
10781087
msg = "No axis named 2 for object type DataFrame"

0 commit comments

Comments
 (0)