Skip to content

Commit 68e3c4b

Browse files
authored
BUG: idxmax raising for arrow strings (#55384)
1 parent 10cf330 commit 68e3c4b

File tree

3 files changed

+30 -1 lines changed

pandas/core/arrays/arrow/array.py

+10 -1
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,15 @@ def _reduce(
         ------
         TypeError : subclass does not define reductions
         """
+        result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        if isinstance(result, pa.Array):
+            return type(self)(result)
+        else:
+            return result
+
+    def _reduce_calc(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
         pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)

         if keepdims:
@@ -1637,7 +1646,7 @@ def _reduce(
                     [pa_result],
                     type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
                 )
-            return type(self)(result)
+            return result

         if pc.is_null(pa_result).as_py():
             return self.dtype.na_value

pandas/core/arrays/string_arrow.py

+11
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,17 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
     def _convert_int_dtype(self, result):
         return Int64Dtype().__from_arrow__(result)

+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
+        result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        if name in ("argmin", "argmax") and isinstance(result, pa.Array):
+            return self._convert_int_dtype(result)
+        elif isinstance(result, pa.Array):
+            return type(self)(result)
+        else:
+            return result
+
     def _rank(
         self,
         *,

pandas/tests/frame/test_reductions.py

+9
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,15 @@ def test_idxmax_arrow_types(self):
         expected = Series([2, 1], index=["a", "b"])
         tm.assert_series_equal(result, expected)

+        df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
+        result = df.idxmax(numeric_only=False)
+        expected = Series([1], index=["a"])
+        tm.assert_series_equal(result, expected)
+
+        result = df.idxmin(numeric_only=False)
+        expected = Series([2], index=["a"])
+        tm.assert_series_equal(result, expected)
+
     def test_idxmax_axis_2(self, float_frame):
         frame = float_frame
         msg = "No axis named 2 for object type DataFrame"

0 commit comments

Comments
 (0)