Skip to content

ENH/TST: Add argsort/min/max for ArrowExtensionArray #47811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
pa_version_under4p0,
pa_version_under5p0,
pa_version_under6p0,
pa_version_under7p0,
)
from pandas.util._decorators import (
deprecate_nonkeyword_arguments,
doc,
)
from pandas.util._decorators import doc

from pandas.core.dtypes.common import (
is_array_like,
Expand Down Expand Up @@ -418,6 +422,58 @@ def isna(self) -> npt.NDArray[np.bool_]:
else:
return self._data.is_null().to_numpy()

@deprecate_nonkeyword_arguments(version=None, allowed_args=["self"])
def argsort(
    self,
    ascending: bool = True,
    kind: str = "quicksort",
    na_position: str = "last",
    *args,
    **kwargs,
) -> np.ndarray:
    """
    Return the indices that would sort this array.

    Parameters
    ----------
    ascending : bool, default True
        Sort in ascending order when True, descending otherwise.
    kind : str, default "quicksort"
        Only forwarded to the parent implementation on the fallback path;
        it is not passed to pyarrow.
    na_position : str, default "last"
        ``"last"`` or ``"first"``. Any other value is not translated for
        pyarrow and is handed to the parent implementation instead.
    *args, **kwargs
        Accepted for signature compatibility; not used by this method.

    Returns
    -------
    np.ndarray
        On the pyarrow path the indices are cast to ``np.intp``; on the
        fallback path whatever the parent ``argsort`` returns.
    """
    null_placement = {"last": "at_end", "first": "at_start"}.get(na_position)
    if null_placement is None or pa_version_under7p0:
        # Although pc.array_sort_indices exists in version 6
        # there's a bug that affects the pa.ChunkedArray backing
        # https://issues.apache.org/jira/browse/ARROW-12042
        fallback_performancewarning("7")
        return super().argsort(
            ascending=ascending, kind=kind, na_position=na_position
        )

    order = "ascending" if ascending else "descending"
    result = pc.array_sort_indices(
        self._data, order=order, null_placement=null_placement
    )
    # Every pyarrow older than 7 has already returned through the fallback
    # above, so no special-casing of pyarrow < 2 is needed here.
    return result.to_numpy().astype(np.intp, copy=False)

def _argmin_max(self, skipna: bool, method: str) -> int:
    """
    Shared implementation behind ``argmin`` and ``argmax``.

    Parameters
    ----------
    skipna : bool
        Forwarded to pyarrow as ``skip_nulls`` and to the parent
        implementation as ``skipna``.
    method : str
        ``"min"`` or ``"max"``; selects both the ``pyarrow.compute``
        function and the parent ``arg{method}`` method.

    Returns
    -------
    int
        Position of the first element equal to the min/max value.

    Raises
    ------
    NotImplementedError
        If pyarrow is older than 6.0 and the parent implementation was
        not used.
    """
    n_values = self._data.length()
    nothing_to_reduce = n_values in (0, self._data.null_count)
    if nothing_to_reduce or (self._hasna and not skipna):
        # For empty or all null, pyarrow returns -1 but pandas expects TypeError
        # For skipna=False and data w/ null, pandas expects NotImplementedError
        # let ExtensionArray.arg{max|min} raise
        parent_impl = getattr(super(), f"arg{method}")
        return parent_impl(skipna=skipna)

    if pa_version_under6p0:
        raise NotImplementedError(
            f"arg{method} only implemented for pyarrow version >= 6.0"
        )

    # Compute the extreme value first, then look up where it occurs.
    reducer = getattr(pc, method)
    extreme = reducer(self._data, skip_nulls=skipna)
    position = pc.index(self._data, extreme)
    return position.as_py()

def argmin(self, skipna: bool = True) -> int:
    """
    Return the position of the minimum value.

    Delegates to ``_argmin_max`` with ``method="min"``.
    """
    return self._argmin_max(skipna=skipna, method="min")

def argmax(self, skipna: bool = True) -> int:
    """
    Return the position of the maximum value.

    Delegates to ``_argmin_max`` with ``method="max"``.
    """
    return self._argmin_max(skipna=skipna, method="max")

def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
"""
Return a shallow copy of the array.
Expand Down
47 changes: 47 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,11 @@ def test_value_counts_with_normalize(self, data, request):
)
super().test_value_counts_with_normalize(data)

@pytest.mark.xfail(
pa_version_under6p0,
raises=NotImplementedError,
reason="argmin/max only implemented for pyarrow version >= 6.0",
)
def test_argmin_argmax(
self, data_for_sorting, data_missing_for_sorting, na_value, request
):
Expand All @@ -1395,8 +1400,50 @@ def test_argmin_argmax(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
elif pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"min_max not supported in pyarrow for {pa_dtype}",
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

@pytest.mark.parametrize(
    "op_name, skipna, expected",
    [
        ("idxmax", True, 0),
        ("idxmin", True, 2),
        ("argmax", True, 0),
        ("argmin", True, 2),
        ("idxmax", False, np.nan),
        ("idxmin", False, np.nan),
        ("argmax", False, -1),
        ("argmin", False, -1),
    ],
)
def test_argreduce_series(
    self, data_missing_for_sorting, op_name, skipna, expected, request
):
    # Run the base-class test, marking the skipna=True cases as expected
    # failures when pyarrow is older than 6.0 or the dtype is a duration.
    pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype

    xfail_marker = None
    if skipna:
        if pa_version_under6p0:
            xfail_marker = pytest.mark.xfail(
                raises=NotImplementedError,
                reason="min_max not supported in pyarrow",
            )
        elif pa.types.is_duration(pa_dtype):
            xfail_marker = pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"min_max not supported in pyarrow for {pa_dtype}",
            )
    if xfail_marker is not None:
        request.node.add_marker(xfail_marker)

    super().test_argreduce_series(
        data_missing_for_sorting, op_name, skipna, expected
    )

@pytest.mark.parametrize("ascending", [True, False])
def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
Expand Down
43 changes: 42 additions & 1 deletion pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,48 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):


class TestMethods(base.BaseMethodsTests):
    # Base method tests, with the arg-reduction tests marked as expected
    # failures for pyarrow-backed strings when pyarrow is older than 6.0.

    @staticmethod
    def _mark_min_max_unsupported(request):
        # Attach the xfail marker shared by both overridden tests below.
        request.node.add_marker(
            pytest.mark.xfail(
                raises=NotImplementedError,
                reason="min_max not supported in pyarrow",
            )
        )

    def test_argmin_argmax(
        self, data_for_sorting, data_missing_for_sorting, na_value, request
    ):
        old_pyarrow_backed = (
            pa_version_under6p0
            and data_missing_for_sorting.dtype.storage == "pyarrow"
        )
        if old_pyarrow_backed:
            self._mark_min_max_unsupported(request)
        super().test_argmin_argmax(
            data_for_sorting, data_missing_for_sorting, na_value
        )

    @pytest.mark.parametrize(
        "op_name, skipna, expected",
        [
            ("idxmax", True, 0),
            ("idxmin", True, 2),
            ("argmax", True, 0),
            ("argmin", True, 2),
            ("idxmax", False, np.nan),
            ("idxmin", False, np.nan),
            ("argmax", False, -1),
            ("argmin", False, -1),
        ],
    )
    def test_argreduce_series(
        self, data_missing_for_sorting, op_name, skipna, expected, request
    ):
        old_pyarrow_backed = (
            pa_version_under6p0
            and data_missing_for_sorting.dtype.storage == "pyarrow"
        )
        # Only the skipna=True cases are marked as expected failures.
        if old_pyarrow_backed and skipna:
            self._mark_min_max_unsupported(request)
        super().test_argreduce_series(
            data_missing_for_sorting, op_name, skipna, expected
        )


class TestCasting(base.BaseCastingTests):
Expand Down
9 changes: 7 additions & 2 deletions pandas/tests/indexes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pandas.compat import (
IS64,
pa_version_under2p0,
pa_version_under7p0,
)

from pandas.core.dtypes.common import is_integer_dtype
Expand Down Expand Up @@ -396,11 +396,16 @@ def test_astype_preserves_name(self, index, dtype):
# imaginary components discarded
warn = np.ComplexWarning

is_pyarrow_str = (
str(index.dtype) == "string[pyarrow]"
and pa_version_under7p0
and dtype == "category"
)
try:
# Some of these conversions cannot succeed so we use a try / except
with tm.assert_produces_warning(
warn,
raise_on_extra_warnings=not pa_version_under2p0,
raise_on_extra_warnings=is_pyarrow_str,
):
result = index.astype(dtype)
except (ValueError, TypeError, NotImplementedError, SystemError):
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/indexes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
import pytest

from pandas.compat import pa_version_under7p0

from pandas.core.dtypes.cast import find_common_type

from pandas import (
Expand Down Expand Up @@ -177,7 +179,8 @@ def test_dunder_inplace_setops_deprecated(index):
with tm.assert_produces_warning(FutureWarning):
index &= index

with tm.assert_produces_warning(FutureWarning):
is_pyarrow = str(index.dtype) == "string[pyarrow]" and pa_version_under7p0
with tm.assert_produces_warning(FutureWarning, raise_on_extra_warnings=is_pyarrow):
index ^= index


Expand Down