Skip to content

Commit ab4cec6

Browse files
committed
BUG: idxmax raising for arrow strings (pandas-dev#55384)
(cherry picked from commit 68e3c4b)
1 parent 0ee6fc9 commit ab4cec6

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

pandas/core/arrays/arrow/array.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,15 @@ def _reduce(
15961596
------
15971597
TypeError : subclass does not define reductions
15981598
"""
1599+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
1600+
if isinstance(result, pa.Array):
1601+
return type(self)(result)
1602+
else:
1603+
return result
1604+
1605+
def _reduce_calc(
1606+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1607+
):
15991608
pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
16001609

16011610
if keepdims:
@@ -1606,7 +1615,7 @@ def _reduce(
16061615
[pa_result],
16071616
type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
16081617
)
1609-
return type(self)(result)
1618+
return result
16101619

16111620
if pc.is_null(pa_result).as_py():
16121621
return self.dtype.na_value

pandas/core/arrays/string_arrow.py

+11
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,17 @@ def _str_rstrip(self, to_strip=None):
444444
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
445445
return type(self)(result)
446446

447+
def _reduce(
448+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
449+
):
450+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
451+
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
452+
return self._convert_int_dtype(result)
453+
elif isinstance(result, pa.Array):
454+
return type(self)(result)
455+
else:
456+
return result
457+
447458

448459
class ArrowStringArrayNumpySemantics(ArrowStringArray):
449460
_storage = "pyarrow_numpy"

pandas/tests/frame/test_reductions.py

+9
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,15 @@ def test_idxmax_arrow_types(self):
10731073
expected = Series([2, 1], index=["a", "b"])
10741074
tm.assert_series_equal(result, expected)
10751075

1076+
df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
1077+
result = df.idxmax(numeric_only=False)
1078+
expected = Series([1], index=["a"])
1079+
tm.assert_series_equal(result, expected)
1080+
1081+
result = df.idxmin(numeric_only=False)
1082+
expected = Series([2], index=["a"])
1083+
tm.assert_series_equal(result, expected)
1084+
10761085
def test_idxmax_axis_2(self, float_frame):
10771086
frame = float_frame
10781087
msg = "No axis named 2 for object type DataFrame"

0 commit comments

Comments
 (0)