Skip to content

Commit 1497bf2

Browse files
authored
ENH/TST: Add argsort/min/max for ArrowExtensionArray (#47811)
1 parent f171c96 commit 1497bf2

File tree

5 files changed

+157
-5
lines changed

5 files changed

+157
-5
lines changed

pandas/core/arrays/arrow/array.py

+57-1
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
pa_version_under4p0,
2323
pa_version_under5p0,
2424
pa_version_under6p0,
25+
pa_version_under7p0,
26+
)
27+
from pandas.util._decorators import (
28+
deprecate_nonkeyword_arguments,
29+
doc,
2530
)
26-
from pandas.util._decorators import doc
2731

2832
from pandas.core.dtypes.common import (
2933
is_array_like,
@@ -418,6 +422,58 @@ def isna(self) -> npt.NDArray[np.bool_]:
418422
else:
419423
return self._data.is_null().to_numpy()
420424

425+
@deprecate_nonkeyword_arguments(version=None, allowed_args=["self"])
426+
def argsort(
427+
self,
428+
ascending: bool = True,
429+
kind: str = "quicksort",
430+
na_position: str = "last",
431+
*args,
432+
**kwargs,
433+
) -> np.ndarray:
434+
order = "ascending" if ascending else "descending"
435+
null_placement = {"last": "at_end", "first": "at_start"}.get(na_position, None)
436+
if null_placement is None or pa_version_under7p0:
437+
# Although pc.array_sort_indices exists in version 6
438+
# there's a bug that affects the pa.ChunkedArray backing
439+
# https://issues.apache.org/jira/browse/ARROW-12042
440+
fallback_performancewarning("7")
441+
return super().argsort(
442+
ascending=ascending, kind=kind, na_position=na_position
443+
)
444+
445+
result = pc.array_sort_indices(
446+
self._data, order=order, null_placement=null_placement
447+
)
448+
if pa_version_under2p0:
449+
np_result = result.to_pandas().values
450+
else:
451+
np_result = result.to_numpy()
452+
return np_result.astype(np.intp, copy=False)
453+
454+
def _argmin_max(self, skipna: bool, method: str) -> int:
455+
if self._data.length() in (0, self._data.null_count) or (
456+
self._hasna and not skipna
457+
):
458+
# For empty or all null, pyarrow returns -1 but pandas expects TypeError
459+
# For skipna=False and data w/ null, pandas expects NotImplementedError
460+
# let ExtensionArray.arg{max|min} raise
461+
return getattr(super(), f"arg{method}")(skipna=skipna)
462+
463+
if pa_version_under6p0:
464+
raise NotImplementedError(
465+
f"arg{method} only implemented for pyarrow version >= 6.0"
466+
)
467+
468+
value = getattr(pc, method)(self._data, skip_nulls=skipna)
469+
return pc.index(self._data, value).as_py()
470+
471+
def argmin(self, skipna: bool = True) -> int:
472+
return self._argmin_max(skipna, "min")
473+
474+
def argmax(self, skipna: bool = True) -> int:
475+
return self._argmin_max(skipna, "max")
476+
421477
def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
422478
"""
423479
Return a shallow copy of the array.

pandas/tests/extension/test_arrow.py

+47
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,11 @@ def test_value_counts_with_normalize(self, data, request):
13851385
)
13861386
super().test_value_counts_with_normalize(data)
13871387

1388+
@pytest.mark.xfail(
1389+
pa_version_under6p0,
1390+
raises=NotImplementedError,
1391+
reason="argmin/max only implemented for pyarrow version >= 6.0",
1392+
)
13881393
def test_argmin_argmax(
13891394
self, data_for_sorting, data_missing_for_sorting, na_value, request
13901395
):
@@ -1395,8 +1400,50 @@ def test_argmin_argmax(
13951400
reason=f"{pa_dtype} only has 2 unique possible values",
13961401
)
13971402
)
1403+
elif pa.types.is_duration(pa_dtype):
1404+
request.node.add_marker(
1405+
pytest.mark.xfail(
1406+
raises=pa.ArrowNotImplementedError,
1407+
reason=f"min_max not supported in pyarrow for {pa_dtype}",
1408+
)
1409+
)
13981410
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
13991411

1412+
@pytest.mark.parametrize(
1413+
"op_name, skipna, expected",
1414+
[
1415+
("idxmax", True, 0),
1416+
("idxmin", True, 2),
1417+
("argmax", True, 0),
1418+
("argmin", True, 2),
1419+
("idxmax", False, np.nan),
1420+
("idxmin", False, np.nan),
1421+
("argmax", False, -1),
1422+
("argmin", False, -1),
1423+
],
1424+
)
1425+
def test_argreduce_series(
1426+
self, data_missing_for_sorting, op_name, skipna, expected, request
1427+
):
1428+
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
1429+
if pa_version_under6p0 and skipna:
1430+
request.node.add_marker(
1431+
pytest.mark.xfail(
1432+
raises=NotImplementedError,
1433+
reason="min_max not supported in pyarrow",
1434+
)
1435+
)
1436+
elif not pa_version_under6p0 and pa.types.is_duration(pa_dtype) and skipna:
1437+
request.node.add_marker(
1438+
pytest.mark.xfail(
1439+
raises=pa.ArrowNotImplementedError,
1440+
reason=f"min_max not supported in pyarrow for {pa_dtype}",
1441+
)
1442+
)
1443+
super().test_argreduce_series(
1444+
data_missing_for_sorting, op_name, skipna, expected
1445+
)
1446+
14001447
@pytest.mark.parametrize("ascending", [True, False])
14011448
def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request):
14021449
pa_dtype = data_for_sorting.dtype.pyarrow_dtype

pandas/tests/extension/test_string.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,48 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
167167

168168

169169
class TestMethods(base.BaseMethodsTests):
170-
pass
170+
def test_argmin_argmax(
171+
self, data_for_sorting, data_missing_for_sorting, na_value, request
172+
):
173+
if pa_version_under6p0 and data_missing_for_sorting.dtype.storage == "pyarrow":
174+
request.node.add_marker(
175+
pytest.mark.xfail(
176+
raises=NotImplementedError,
177+
reason="min_max not supported in pyarrow",
178+
)
179+
)
180+
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
181+
182+
@pytest.mark.parametrize(
183+
"op_name, skipna, expected",
184+
[
185+
("idxmax", True, 0),
186+
("idxmin", True, 2),
187+
("argmax", True, 0),
188+
("argmin", True, 2),
189+
("idxmax", False, np.nan),
190+
("idxmin", False, np.nan),
191+
("argmax", False, -1),
192+
("argmin", False, -1),
193+
],
194+
)
195+
def test_argreduce_series(
196+
self, data_missing_for_sorting, op_name, skipna, expected, request
197+
):
198+
if (
199+
pa_version_under6p0
200+
and data_missing_for_sorting.dtype.storage == "pyarrow"
201+
and skipna
202+
):
203+
request.node.add_marker(
204+
pytest.mark.xfail(
205+
raises=NotImplementedError,
206+
reason="min_max not supported in pyarrow",
207+
)
208+
)
209+
super().test_argreduce_series(
210+
data_missing_for_sorting, op_name, skipna, expected
211+
)
171212

172213

173214
class TestCasting(base.BaseCastingTests):

pandas/tests/indexes/test_common.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from pandas.compat import (
1212
IS64,
13-
pa_version_under2p0,
13+
pa_version_under7p0,
1414
)
1515

1616
from pandas.core.dtypes.common import is_integer_dtype
@@ -396,11 +396,16 @@ def test_astype_preserves_name(self, index, dtype):
396396
# imaginary components discarded
397397
warn = np.ComplexWarning
398398

399+
is_pyarrow_str = (
400+
str(index.dtype) == "string[pyarrow]"
401+
and pa_version_under7p0
402+
and dtype == "category"
403+
)
399404
try:
400405
# Some of these conversions cannot succeed so we use a try / except
401406
with tm.assert_produces_warning(
402407
warn,
403-
raise_on_extra_warnings=not pa_version_under2p0,
408+
raise_on_extra_warnings=is_pyarrow_str,
404409
):
405410
result = index.astype(dtype)
406411
except (ValueError, TypeError, NotImplementedError, SystemError):

pandas/tests/indexes/test_setops.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import numpy as np
99
import pytest
1010

11+
from pandas.compat import pa_version_under7p0
12+
1113
from pandas.core.dtypes.cast import find_common_type
1214

1315
from pandas import (
@@ -177,7 +179,8 @@ def test_dunder_inplace_setops_deprecated(index):
177179
with tm.assert_produces_warning(FutureWarning):
178180
index &= index
179181

180-
with tm.assert_produces_warning(FutureWarning):
182+
is_pyarrow = str(index.dtype) == "string[pyarrow]" and pa_version_under7p0
183+
with tm.assert_produces_warning(FutureWarning, raise_on_extra_warnings=is_pyarrow):
181184
index ^= index
182185

183186

0 commit comments

Comments
 (0)