From 37db5c087d92e61285d16234130ccaafb6ca2eb2 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Wed, 20 Jul 2022 13:36:24 -0700 Subject: [PATCH 01/13] add arg functions --- pandas/core/arrays/arrow/array.py | 55 ++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 8957ea493e9ad..7fc9b613ac848 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -22,7 +22,10 @@ pa_version_under5p0, pa_version_under6p0, ) -from pandas.util._decorators import doc +from pandas.util._decorators import ( + deprecate_nonkeyword_arguments, + doc, +) from pandas.core.dtypes.common import ( is_array_like, @@ -399,6 +402,56 @@ def isna(self) -> npt.NDArray[np.bool_]: else: return self._data.is_null().to_numpy() + def _values_for_argsort(self) -> np.ndarray: + if pa_version_under2p0: + return self._data.to_pandas().values + else: + return self._data.to_numpy() + + @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) + def argsort( + self, + ascending: bool = True, + kind: str = "quicksort", + na_position: str = "last", + *args, + **kwargs, + ) -> np.ndarray: + if pa_version_under6p0: + raise NotImplementedError( + "argsort only implemented for pyarrow version >= 6.0" + ) + + order = "ascending" if ascending else "descending" + try: + null_placement = {"last": "at_end", "first": "at_start"}[na_position] + except KeyError as err: + raise ValueError( + f"na_position must be 'last' or 'first'. Got {na_position}" + ) from err + result = pc.array_sort_indices( + self._data, order=order, null_placement=null_placement + ) + if pa_version_under2p0: + return result.to_pandas().values + else: + return result.to_numpy() + + def _argmin_max(self, skipna: bool, method: str) -> int: + if pa_version_under6p0: + raise NotImplementedError( + f"arg{method} only implemented for pyarrow version >= 6.0" + ) + + value = getattr(pc, method)(self._data, skip_nulls=skipna) + return pc.index(self._data, value).as_py() + + def argmin(self, skipna: bool = True) -> int: + return self._argmin_max(skipna, "min") + + def argmax(self, skipna: bool = True) -> int: + return self._argmin_max(skipna, "max") + def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: """ Return a shallow copy of the array. From a40a6b15855770c947be7f6e0b0af7967ae66eab Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Wed, 20 Jul 2022 18:06:51 -0700 Subject: [PATCH 02/13] finish testing --- pandas/core/arrays/arrow/array.py | 27 +++++++---- pandas/tests/extension/test_arrow.py | 70 ++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 7fc9b613ac848..fcb2942e5dd01 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -417,27 +417,34 @@ def argsort( *args, **kwargs, ) -> np.ndarray: + order = "ascending" if ascending else "descending" + null_placement = {"last": "at_end", "first": "at_start"}.get(na_position, None) + if null_placement is None: + raise ValueError( + f"na_position must be 'last' or 'first'. Got {na_position}" + ) if pa_version_under6p0: raise NotImplementedError( "argsort only implemented for pyarrow version >= 6.0" ) - - order = "ascending" if ascending else "descending" - try: - null_placement = {"last": "at_end", "first": "at_start"}[na_position] - except KeyError as err: - raise ValueError( - f"na_position must be 'last' or 'first'. Got {na_position}" - ) from err result = pc.array_sort_indices( self._data, order=order, null_placement=null_placement ) if pa_version_under2p0: - return result.to_pandas().values + np_result = result.to_pandas().values else: - return result.to_numpy() + np_result = result.to_numpy() + return np_result.astype(np.intp, copy=False) def _argmin_max(self, skipna: bool, method: str) -> int: + if self._data.length() in (0, self._data.null_count) or ( + self._hasna and not skipna + ): + # For empty or all null, pyarrow returns -1 but pandas expects TypeError + # For skipna=False and data w/ null, pandas expects NotImplementedError + # let ExtensionArray.arg{max|min} raise + return getattr(super(), f"arg{method}")(skipna) + if pa_version_under6p0: raise NotImplementedError( f"arg{method} only implemented for pyarrow version >= 6.0" diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index ef576692c83b6..ab29b1ddf71a2 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -24,6 +24,7 @@ from pandas.compat import ( pa_version_under2p0, pa_version_under3p0, + pa_version_under6p0, pa_version_under8p0, ) @@ -1277,6 +1278,35 @@ def test_value_counts_with_normalize(self, data, request): ) super().test_value_counts_with_normalize(data) + @pytest.mark.xfail( + pa_version_under6p0, + raises=NotImplementedError, + reason="argsort only implemented for pyarrow version >= 6.0", + ) + def test_argsort(self, data_for_sorting): + super().test_argsort(data_for_sorting) + + @pytest.mark.xfail( + pa_version_under6p0, + raises=NotImplementedError, + reason="argsort only implemented for pyarrow version >= 6.0", + ) + def test_argsort_missing_array(self, data_missing_for_sorting): + super().test_argsort_missing_array(data_missing_for_sorting) + + @pytest.mark.xfail( + pa_version_under6p0, + raises=NotImplementedError, + reason="argsort only implemented for pyarrow version >= 6.0", + ) + def test_argsort_missing(self, data_missing_for_sorting): + super().test_argsort_missing(data_missing_for_sorting) + + @pytest.mark.xfail( + pa_version_under6p0, + raises=NotImplementedError, + reason="argsort only implemented for pyarrow version >= 6.0", + ) def test_argmin_argmax( self, data_for_sorting, data_missing_for_sorting, na_value, request ): @@ -1287,8 +1317,48 @@ def test_argmin_argmax( reason=f"{pa_dtype} only has 2 unique possible values", ) ) + elif pa.types.is_duration(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"min_max not supported in pyarrow for {pa_dtype}", + ) + ) super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) + @pytest.mark.xfail( + pa_version_under6p0, + raises=NotImplementedError, + reason="argsort only implemented for pyarrow version >= 6.0", + ) + @pytest.mark.parametrize( + "op_name, skipna, expected", + [ + ("idxmax", True, 0), + ("idxmin", True, 2), + ("argmax", True, 0), + ("argmin", True, 2), + ("idxmax", False, np.nan), + ("idxmin", False, np.nan), + ("argmax", False, -1), + ("argmin", False, -1), + ], + ) + def test_argreduce_series( + self, data_missing_for_sorting, op_name, skipna, expected, request + ): + pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype + if pa.types.is_duration(pa_dtype) and skipna: + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"min_max not supported in pyarrow for {pa_dtype}", + ) + ) + super().test_argreduce_series( + data_missing_for_sorting, op_name, skipna, expected + ) + @pytest.mark.parametrize("ascending", [True, False]) def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request): pa_dtype = data_for_sorting.dtype.pyarrow_dtype From 789f8161e885ee98dcbad4dbeceb5ec9fe93771a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 21 Jul 2022 11:07:48 -0700 Subject: [PATCH 03/13] Let argsort fallback --- pandas/core/arrays/arrow/array.py | 12 ++++-------- pandas/tests/extension/test_arrow.py | 28 ++-------------------------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index fcb2942e5dd01..c6b01ffafbe01 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -419,14 +419,10 @@ def argsort( ) -> np.ndarray: order = "ascending" if ascending else "descending" null_placement = {"last": "at_end", "first": "at_start"}.get(na_position, None) - if null_placement is None: - raise ValueError( - f"na_position must be 'last' or 'first'. Got {na_position}" - ) - if pa_version_under6p0: - raise NotImplementedError( - "argsort only implemented for pyarrow version >= 6.0" - ) + if null_placement is None or pa_version_under6p0: + fallback_performancewarning("6") + return super().argsort(ascending, kind, na_position, *args, **kwargs) + result = pc.array_sort_indices( self._data, order=order, null_placement=null_placement ) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index ab29b1ddf71a2..be649b7cd401c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1281,31 +1281,7 @@ def test_value_counts_with_normalize(self, data, request): @pytest.mark.xfail( pa_version_under6p0, raises=NotImplementedError, - reason="argsort only implemented for pyarrow version >= 6.0", - ) - def test_argsort(self, data_for_sorting): - super().test_argsort(data_for_sorting) - - @pytest.mark.xfail( - pa_version_under6p0, - raises=NotImplementedError, - reason="argsort only implemented for pyarrow version >= 6.0", - ) - def test_argsort_missing_array(self, data_missing_for_sorting): - super().test_argsort_missing_array(data_missing_for_sorting) - - @pytest.mark.xfail( - pa_version_under6p0, - raises=NotImplementedError, - reason="argsort only implemented for pyarrow version >= 6.0", - ) - def test_argsort_missing(self, data_missing_for_sorting): - super().test_argsort_missing(data_missing_for_sorting) - - @pytest.mark.xfail( - pa_version_under6p0, - raises=NotImplementedError, - reason="argsort only implemented for pyarrow version >= 6.0", + reason="argmin/max only implemented for pyarrow version >= 6.0", ) def test_argmin_argmax( self, data_for_sorting, data_missing_for_sorting, na_value, request @@ -1329,7 +1305,7 @@ def test_argmin_argmax( @pytest.mark.xfail( pa_version_under6p0, raises=NotImplementedError, - reason="argsort only implemented for pyarrow version >= 6.0", + reason="argmin/max only implemented for pyarrow version >= 6.0", ) @pytest.mark.parametrize( "op_name, skipna, expected", From 5e4abbade280eed0289850e7f0a11142f913c2f6 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 21 Jul 2022 13:21:09 -0700 Subject: [PATCH 04/13] Address test failure due to pyarrow bug --- pandas/tests/base/test_value_counts.py | 10 ++- pandas/tests/extension/test_string.py | 96 +++++++++++++++++++++++++- pandas/tests/indexes/test_setops.py | 18 ++++- 3 files changed, 119 insertions(+), 5 deletions(-) diff --git a/pandas/tests/base/test_value_counts.py b/pandas/tests/base/test_value_counts.py index 55a6cc48ebfc8..7eb7ea0f1e3be 100644 --- a/pandas/tests/base/test_value_counts.py +++ b/pandas/tests/base/test_value_counts.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from pandas.compat import pa_version_under7p0 + import pandas as pd from pandas import ( DatetimeIndex, @@ -42,7 +44,7 @@ def test_value_counts(index_or_series_obj): @pytest.mark.parametrize("null_obj", [np.nan, None]) -def test_value_counts_null(null_obj, index_or_series_obj): +def test_value_counts_null(null_obj, index_or_series_obj, request): orig = index_or_series_obj obj = orig.copy() @@ -52,6 +54,12 @@ def test_value_counts_null(null_obj, index_or_series_obj): pytest.skip("Test doesn't make sense on empty data") elif isinstance(orig, pd.MultiIndex): pytest.skip(f"MultiIndex can't hold '{null_obj}'") + elif pa_version_under7p0 and orig.dtype == "string[pyarrow]": + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) values = obj._values values[0:2] = null_obj diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 6cea21b6672d8..794943cc2cd9e 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -18,7 +18,10 @@ import numpy as np import pytest -from pandas.compat import pa_version_under6p0 +from pandas.compat import ( + pa_version_under6p0, + pa_version_under7p0, +) from pandas.errors import PerformanceWarning import pandas as pd @@ -167,7 +170,96 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna): class TestMethods(base.BaseMethodsTests): - pass + def test_argsort(self, data_for_sorting, request): + if ( + pa_version_under7p0 + and data_for_sorting.dtype == "string[pyarrow]" + and "True" in request.node.name + ): + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) + super().test_argsort(data_for_sorting) + + def test_argsort_missing_array(self, data_missing_for_sorting, request): + if ( + pa_version_under7p0 + and data_missing_for_sorting.dtype == "string[pyarrow]" + and "True" in request.node.name + ): + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) + super().test_argsort_missing_array(data_missing_for_sorting) + + def test_argsort_missing(self, data_missing_for_sorting, request): + if ( + pa_version_under7p0 + and data_missing_for_sorting.dtype == "string[pyarrow]" + and "True" in request.node.name + ): + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) + super().test_argsort_missing(data_missing_for_sorting) + + @pytest.mark.parametrize( + "na_position, expected", + [ + ("last", np.array([2, 0, 1], dtype=np.dtype("intp"))), + ("first", np.array([1, 2, 0], dtype=np.dtype("intp"))), + ], + ) + def test_nargsort(self, data_missing_for_sorting, na_position, expected, request): + if ( + pa_version_under7p0 + and data_missing_for_sorting.dtype == "string[pyarrow]" + and "True" in request.node.name + ): + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) + super().test_nargsort(data_missing_for_sorting, na_position, expected) + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request): + if ( + pa_version_under7p0 + and data_for_sorting.dtype == "string[pyarrow]" + and "-True-" in request.node.name + ): + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) + super().test_sort_values(data_for_sorting, ascending, sort_by_key) + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values_missing( + self, data_missing_for_sorting, ascending, sort_by_key, request + ): + if ( + pa_version_under7p0 + and data_missing_for_sorting.dtype == "string[pyarrow]" + and "-True-" in request.node.name + ): + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) + super().test_sort_values_missing( + data_missing_for_sorting, ascending, sort_by_key + ) class TestCasting(base.BaseCastingTests): diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index f38a6c89e1bcb..9dfa8112d7a73 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -8,6 +8,8 @@ import numpy as np import pytest +from pandas.compat import pa_version_under7p0 + from pandas.core.dtypes.cast import find_common_type from pandas import ( @@ -228,7 +230,13 @@ def test_intersection_base(self, index): with pytest.raises(TypeError, match=msg): first.intersection([1, 2, 3]) - def test_union_base(self, index): + def test_union_base(self, index, request): + if pa_version_under7p0 and index.dtype == "string[pyarrow]": + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) first = index[3:] second = index[:5] everything = index @@ -277,7 +285,13 @@ def test_difference_base(self, sort, index): with pytest.raises(TypeError, match=msg): first.difference([1, 2, 3], sort) - def test_symmetric_difference(self, index): + def test_symmetric_difference(self, index, request): + if pa_version_under7p0 and index.dtype == "string[pyarrow]": + request.node.add_marker( + pytest.mark.xfail( + reason="https://issues.apache.org/jira/browse/ARROW-12042" + ) + ) if isinstance(index, CategoricalIndex): return if len(index) < 2: From 8512afd7b23e0872fb81352af4face859325af8d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 21 Jul 2022 18:08:54 -0700 Subject: [PATCH 05/13] Add more conditions --- pandas/tests/base/test_value_counts.py | 11 +++++-- pandas/tests/extension/test_string.py | 45 ++++++++++++++++++++++++++ pandas/tests/indexes/test_setops.py | 26 ++++++++++++--- 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/pandas/tests/base/test_value_counts.py b/pandas/tests/base/test_value_counts.py index 7eb7ea0f1e3be..bf02b1dbdd2b5 100644 --- a/pandas/tests/base/test_value_counts.py +++ b/pandas/tests/base/test_value_counts.py @@ -4,7 +4,10 @@ import numpy as np import pytest -from pandas.compat import pa_version_under7p0 +from pandas.compat import ( + pa_version_under6p0, + pa_version_under7p0, +) import pandas as pd from pandas import ( @@ -54,7 +57,11 @@ def test_value_counts_null(null_obj, index_or_series_obj, request): pytest.skip("Test doesn't make sense on empty data") elif isinstance(orig, pd.MultiIndex): pytest.skip(f"MultiIndex can't hold '{null_obj}'") - elif pa_version_under7p0 and orig.dtype == "string[pyarrow]": + elif ( + pa_version_under7p0 + and orig.dtype == "string[pyarrow]" + and not pa_version_under6p0 + ): request.node.add_marker( pytest.mark.xfail( reason="https://issues.apache.org/jira/browse/ARROW-12042" diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 794943cc2cd9e..49368680908be 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -170,9 +170,49 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna): class TestMethods(base.BaseMethodsTests): + def test_argmin_argmax( + self, data_for_sorting, data_missing_for_sorting, na_value, request + ): + if pa_version_under6p0 and data_for_sorting.dtype == "string[pyarrow]": + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="min_max not supported in pyarrow", + ) + ) + super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) + + @pytest.mark.parametrize( + "op_name, skipna, expected", + [ + ("idxmax", True, 0), + ("idxmin", True, 2), + ("argmax", True, 0), + ("argmin", True, 2), + ("idxmax", False, np.nan), + ("idxmin", False, np.nan), + ("argmax", False, -1), + ("argmin", False, -1), + ], + ) + def test_argreduce_series( + self, data_missing_for_sorting, op_name, skipna, expected, request + ): + if pa_version_under6p0 and data_missing_for_sorting.dtype == "string[pyarrow]": + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="min_max not supported in pyarrow", + ) + ) + super().test_argreduce_series( + data_missing_for_sorting, op_name, skipna, expected + ) + def test_argsort(self, data_for_sorting, request): if ( pa_version_under7p0 + and not pa_version_under6p0 and data_for_sorting.dtype == "string[pyarrow]" and "True" in request.node.name ): @@ -186,6 +226,7 @@ def test_argsort(self, data_for_sorting, request): def test_argsort_missing_array(self, data_missing_for_sorting, request): if ( pa_version_under7p0 + and not pa_version_under6p0 and data_missing_for_sorting.dtype == "string[pyarrow]" and "True" in request.node.name ): @@ -199,6 +240,7 @@ def test_argsort_missing_array(self, data_missing_for_sorting, request): def test_argsort_missing(self, data_missing_for_sorting, request): if ( pa_version_under7p0 + and not pa_version_under6p0 and data_missing_for_sorting.dtype == "string[pyarrow]" and "True" in request.node.name ): @@ -219,6 +261,7 @@ def test_argsort_missing(self, data_missing_for_sorting, request): def test_nargsort(self, data_missing_for_sorting, na_position, expected, request): if ( pa_version_under7p0 + and not pa_version_under6p0 and data_missing_for_sorting.dtype == "string[pyarrow]" and "True" in request.node.name ): @@ -233,6 +276,7 @@ def test_nargsort(self, data_missing_for_sorting, na_position, expected, request def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request): if ( pa_version_under7p0 + and not pa_version_under6p0 and data_for_sorting.dtype == "string[pyarrow]" and "-True-" in request.node.name ): @@ -249,6 +293,7 @@ def test_sort_values_missing( ): if ( pa_version_under7p0 + and not pa_version_under6p0 and data_missing_for_sorting.dtype == "string[pyarrow]" and "-True-" in request.node.name ): diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 9dfa8112d7a73..4f5bc46e999bd 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -8,7 +8,10 @@ import numpy as np import pytest -from pandas.compat import pa_version_under7p0 +from pandas.compat import ( + pa_version_under6p0, + pa_version_under7p0, +) from pandas.core.dtypes.cast import find_common_type @@ -179,7 +182,14 @@ def test_dunder_inplace_setops_deprecated(index): with tm.assert_produces_warning(FutureWarning): index &= index - with tm.assert_produces_warning(FutureWarning): + is_pyarrow = ( + index.dtype == "string[pyarrow]" + and pa_version_under7p0 + and not pa_version_under6p0 + ) + with tm.assert_produces_warning( + FutureWarning, raise_on_extra_warnings=not is_pyarrow + ): index ^= index @@ -231,7 +241,11 @@ def test_intersection_base(self, index): first.intersection([1, 2, 3]) def test_union_base(self, index, request): - if pa_version_under7p0 and index.dtype == "string[pyarrow]": + if ( + pa_version_under7p0 + and index.dtype == "string[pyarrow]" + and pa_version_under6p0 + ): request.node.add_marker( pytest.mark.xfail( reason="https://issues.apache.org/jira/browse/ARROW-12042" @@ -286,7 +300,11 @@ def test_difference_base(self, sort, index): first.difference([1, 2, 3], sort) def test_symmetric_difference(self, index, request): - if pa_version_under7p0 and index.dtype == "string[pyarrow]": + if ( + pa_version_under7p0 + and index.dtype == "string[pyarrow]" + and not pa_version_under6p0 + ): request.node.add_marker( pytest.mark.xfail( reason="https://issues.apache.org/jira/browse/ARROW-12042" From c01a30ef6f190f0b9861022cc344907ffb9816fb Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 11:37:37 -0700 Subject: [PATCH 06/13] Fix pa 6 condition --- pandas/tests/indexes/test_setops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 4f5bc46e999bd..12544428bbdb8 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -244,7 +244,7 @@ def test_union_base(self, index, request): if ( pa_version_under7p0 and index.dtype == "string[pyarrow]" - and pa_version_under6p0 + and not pa_version_under6p0 ): request.node.add_marker( pytest.mark.xfail( From ae43913185fe22f7b1311a5fedc58f293a4b2b75 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 12:14:46 -0700 Subject: [PATCH 07/13] Address more xfail conditions --- pandas/tests/extension/test_arrow.py | 14 ++++++++------ pandas/tests/extension/test_string.py | 6 +++++- pandas/tests/indexes/test_common.py | 9 +++++++-- pandas/tests/indexes/test_setops.py | 6 +----- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 1415db3bd4a2e..2f3482ddc4811 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1409,11 +1409,6 @@ def test_argmin_argmax( ) super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) - @pytest.mark.xfail( - pa_version_under6p0, - raises=NotImplementedError, - reason="argmin/max only implemented for pyarrow version >= 6.0", - ) @pytest.mark.parametrize( "op_name, skipna, expected", [ @@ -1431,7 +1426,14 @@ def test_argreduce_series( self, data_missing_for_sorting, op_name, skipna, expected, request ): pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype - if pa.types.is_duration(pa_dtype) and skipna: + if pa_version_under6p0 and skipna: + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="min_max not supported in pyarrow", + ) + ) + elif not pa_version_under6p0 and pa.types.is_duration(pa_dtype) and skipna: request.node.add_marker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError, diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 49368680908be..0e27533a1c28f 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -198,7 +198,11 @@ def test_argmin_argmax( def test_argreduce_series( self, data_missing_for_sorting, op_name, skipna, expected, request ): - if pa_version_under6p0 and data_missing_for_sorting.dtype == "string[pyarrow]": + if ( + pa_version_under6p0 + and data_missing_for_sorting.dtype == "string[pyarrow]" + and skipna + ): request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index d582a469eaf0e..24d46d4e411ba 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -10,7 +10,7 @@ from pandas.compat import ( IS64, - pa_version_under2p0, + pa_version_under7p0, ) from pandas.core.dtypes.common import is_integer_dtype @@ -396,11 +396,16 @@ def test_astype_preserves_name(self, index, dtype): # imaginary components discarded warn = np.ComplexWarning + is_pyarrow_str = ( + index.dtype == "string[pyarrow]" + and pa_version_under7p0 + and dtype == "category" + ) try: # Some of these conversions cannot succeed so we use a try / except with tm.assert_produces_warning( warn, - raise_on_extra_warnings=not pa_version_under2p0, + raise_on_extra_warnings=not is_pyarrow_str, ): result = index.astype(dtype) except (ValueError, TypeError, NotImplementedError, SystemError): diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 12544428bbdb8..bd0cd1247f775 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -182,11 +182,7 @@ def test_dunder_inplace_setops_deprecated(index): with tm.assert_produces_warning(FutureWarning): index &= index - is_pyarrow = ( - index.dtype == "string[pyarrow]" - and pa_version_under7p0 - and not pa_version_under6p0 - ) + is_pyarrow = index.dtype == "string[pyarrow]" and pa_version_under7p0 with tm.assert_produces_warning( FutureWarning, raise_on_extra_warnings=not is_pyarrow ): From e4c13df1b95cbecc308d3b910c9a82ddd3594b57 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 13:42:16 -0700 Subject: [PATCH 08/13] use str comparison if pyarrow isnt installed --- pandas/tests/base/test_value_counts.py | 2 +- pandas/tests/extension/test_string.py | 4 ++-- pandas/tests/indexes/test_common.py | 2 +- pandas/tests/indexes/test_setops.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas/tests/base/test_value_counts.py b/pandas/tests/base/test_value_counts.py index bf02b1dbdd2b5..5f06f9e4b3fd1 100644 --- a/pandas/tests/base/test_value_counts.py +++ b/pandas/tests/base/test_value_counts.py @@ -59,7 +59,7 @@ def test_value_counts_null(null_obj, index_or_series_obj, request): pytest.skip(f"MultiIndex can't hold '{null_obj}'") elif ( pa_version_under7p0 - and orig.dtype == "string[pyarrow]" + and str(orig.dtype) == "string[pyarrow]" and not pa_version_under6p0 ): request.node.add_marker( diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 0e27533a1c28f..1b2b54493d4de 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -173,7 +173,7 @@ class TestMethods(base.BaseMethodsTests): def test_argmin_argmax( self, data_for_sorting, data_missing_for_sorting, na_value, request ): - if pa_version_under6p0 and data_for_sorting.dtype == "string[pyarrow]": + if pa_version_under6p0 and str(data_for_sorting.dtype) == "string[pyarrow]": request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, @@ -200,7 +200,7 @@ def test_argreduce_series( ): if ( pa_version_under6p0 - and data_missing_for_sorting.dtype == "string[pyarrow]" + and str(data_missing_for_sorting.dtype) == "string[pyarrow]" and skipna ): request.node.add_marker( diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index 24d46d4e411ba..e62d234db9a53 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -397,7 +397,7 @@ def test_astype_preserves_name(self, index, dtype): warn = np.ComplexWarning is_pyarrow_str = ( - index.dtype == "string[pyarrow]" + str(index.dtype) == "string[pyarrow]" and pa_version_under7p0 and dtype == "category" ) diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index bd0cd1247f775..7471f00511d75 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -182,7 +182,7 @@ def test_dunder_inplace_setops_deprecated(index): with tm.assert_produces_warning(FutureWarning): index &= index - is_pyarrow = index.dtype == "string[pyarrow]" and pa_version_under7p0 + is_pyarrow = str(index.dtype) == "string[pyarrow]" and pa_version_under7p0 with tm.assert_produces_warning( FutureWarning, raise_on_extra_warnings=not is_pyarrow ): @@ -239,7 +239,7 @@ def test_intersection_base(self, index): def test_union_base(self, index, request): if ( pa_version_under7p0 - and index.dtype == "string[pyarrow]" + and str(index.dtype) == "string[pyarrow]" and not pa_version_under6p0 ): request.node.add_marker( @@ -298,7 +298,7 @@ def test_difference_base(self, sort, index): def test_symmetric_difference(self, index, request): if ( pa_version_under7p0 - and index.dtype == "string[pyarrow]" + and str(index.dtype) == "string[pyarrow]" and not pa_version_under6p0 ): request.node.add_marker( From c9f857aec1107c65b1ec7d1e89267acd16145bc5 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 17:48:30 -0700 Subject: [PATCH 09/13] Just avoid the bug in pyarrow 6 --- pandas/core/arrays/arrow/array.py | 8 +- pandas/tests/base/test_value_counts.py | 17 +--- pandas/tests/extension/test_string.py | 106 +------------------------ pandas/tests/indexes/test_common.py | 2 +- pandas/tests/indexes/test_setops.py | 33 +------- 5 files changed, 15 insertions(+), 151 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 8bcba477c75d9..c3cabc68ff467 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -22,6 +22,7 @@ pa_version_under4p0, pa_version_under5p0, pa_version_under6p0, + pa_version_under7p0, ) from pandas.util._decorators import ( deprecate_nonkeyword_arguments, @@ -438,8 +439,11 @@ def argsort( ) -> np.ndarray: order = "ascending" if ascending else "descending" null_placement = {"last": "at_end", "first": "at_start"}.get(na_position, None) - if null_placement is None or pa_version_under6p0: - fallback_performancewarning("6") + if null_placement is None or pa_version_under7p0: + # Although pc.array_sort_indices exists in version 6 + # there's a bug that affects the pa.ChunkedArray backing + # https://issues.apache.org/jira/browse/ARROW-12042 + fallback_performancewarning("7") return super().argsort(ascending, kind, na_position, *args, **kwargs) result = pc.array_sort_indices( diff --git a/pandas/tests/base/test_value_counts.py b/pandas/tests/base/test_value_counts.py index 5f06f9e4b3fd1..55a6cc48ebfc8 100644 --- a/pandas/tests/base/test_value_counts.py +++ b/pandas/tests/base/test_value_counts.py @@ -4,11 +4,6 @@ import numpy as np import pytest -from pandas.compat import ( - pa_version_under6p0, - pa_version_under7p0, -) - import pandas as pd from pandas import ( DatetimeIndex, @@ -47,7 +42,7 @@ def test_value_counts(index_or_series_obj): @pytest.mark.parametrize("null_obj", [np.nan, None]) -def test_value_counts_null(null_obj, index_or_series_obj, request): +def test_value_counts_null(null_obj, index_or_series_obj): orig = index_or_series_obj obj = orig.copy() @@ -57,16 +52,6 @@ def test_value_counts_null(null_obj, index_or_series_obj, request): pytest.skip("Test doesn't make sense on empty data") elif isinstance(orig, pd.MultiIndex): pytest.skip(f"MultiIndex can't hold '{null_obj}'") - elif ( - pa_version_under7p0 - and str(orig.dtype) == "string[pyarrow]" - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) values = obj._values values[0:2] = null_obj diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 1b2b54493d4de..e4293d6d70e38 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -18,10 +18,7 @@ import numpy as np import pytest -from pandas.compat import ( - pa_version_under6p0, - pa_version_under7p0, -) +from pandas.compat import pa_version_under6p0 from pandas.errors import PerformanceWarning import pandas as pd @@ -173,7 +170,7 @@ class TestMethods(base.BaseMethodsTests): def test_argmin_argmax( self, data_for_sorting, data_missing_for_sorting, na_value, request ): - if pa_version_under6p0 and str(data_for_sorting.dtype) == "string[pyarrow]": + if pa_version_under6p0 and data_missing_for_sorting.dtype.storage == "pyarrow": request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, @@ -200,7 +197,7 @@ def test_argreduce_series( ): if ( pa_version_under6p0 - and str(data_missing_for_sorting.dtype) == "string[pyarrow]" + and data_missing_for_sorting.dtype.storage == "pyarrow" and skipna ): request.node.add_marker( @@ -213,103 +210,6 @@ def test_argreduce_series( data_missing_for_sorting, op_name, skipna, expected ) - def test_argsort(self, data_for_sorting, request): - if ( - pa_version_under7p0 - and not pa_version_under6p0 - and data_for_sorting.dtype == "string[pyarrow]" - and "True" in request.node.name - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) - super().test_argsort(data_for_sorting) - - def test_argsort_missing_array(self, data_missing_for_sorting, request): - if ( - pa_version_under7p0 - and not pa_version_under6p0 - and data_missing_for_sorting.dtype == "string[pyarrow]" - and "True" in request.node.name - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) - super().test_argsort_missing_array(data_missing_for_sorting) - - def test_argsort_missing(self, data_missing_for_sorting, request): - if ( - pa_version_under7p0 - and not pa_version_under6p0 - and data_missing_for_sorting.dtype == "string[pyarrow]" - and "True" in request.node.name - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) - super().test_argsort_missing(data_missing_for_sorting) - - @pytest.mark.parametrize( - "na_position, expected", - [ - ("last", np.array([2, 0, 1], dtype=np.dtype("intp"))), - ("first", np.array([1, 2, 0], dtype=np.dtype("intp"))), - ], - ) - def test_nargsort(self, data_missing_for_sorting, na_position, expected, request): - if ( - pa_version_under7p0 - and not pa_version_under6p0 - and data_missing_for_sorting.dtype == "string[pyarrow]" - and "True" in request.node.name - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) - super().test_nargsort(data_missing_for_sorting, na_position, expected) - - @pytest.mark.parametrize("ascending", [True, False]) - def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request): - if ( - pa_version_under7p0 - and not pa_version_under6p0 - and data_for_sorting.dtype == "string[pyarrow]" - and "-True-" in request.node.name - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) - super().test_sort_values(data_for_sorting, ascending, sort_by_key) - - @pytest.mark.parametrize("ascending", [True, False]) - def test_sort_values_missing( - self, data_missing_for_sorting, ascending, sort_by_key, request - ): - if ( - pa_version_under7p0 - and not pa_version_under6p0 - and data_missing_for_sorting.dtype == "string[pyarrow]" - and "-True-" in request.node.name - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) - super().test_sort_values_missing( - data_missing_for_sorting, ascending, sort_by_key - ) - class TestCasting(base.BaseCastingTests): pass diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index e62d234db9a53..e7e971f957e48 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -405,7 +405,7 @@ def test_astype_preserves_name(self, index, dtype): # Some of these conversions cannot succeed so we use a try / except with tm.assert_produces_warning( warn, - raise_on_extra_warnings=not is_pyarrow_str, + raise_on_extra_warnings=is_pyarrow_str, ): result = index.astype(dtype) except (ValueError, TypeError, NotImplementedError, SystemError): diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 7471f00511d75..45ecd09e550d0 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -8,10 +8,7 @@ import numpy as np import pytest -from pandas.compat import ( - pa_version_under6p0, - pa_version_under7p0, -) +from pandas.compat import pa_version_under7p0 from pandas.core.dtypes.cast import find_common_type @@ -183,9 +180,7 @@ def test_dunder_inplace_setops_deprecated(index): index &= index is_pyarrow = str(index.dtype) == "string[pyarrow]" and pa_version_under7p0 - with tm.assert_produces_warning( - FutureWarning, raise_on_extra_warnings=not is_pyarrow - ): + with tm.assert_produces_warning(FutureWarning, raise_on_extra_warnings=is_pyarrow): index ^= index @@ -236,17 +231,7 @@ def test_intersection_base(self, index): with pytest.raises(TypeError, match=msg): first.intersection([1, 2, 3]) - def test_union_base(self, index, request): - if ( - pa_version_under7p0 - and str(index.dtype) == "string[pyarrow]" - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) + def test_union_base(self, index): first = index[3:] second = index[:5] everything = index @@ -295,17 +280,7 @@ def test_difference_base(self, sort, index): with pytest.raises(TypeError, match=msg): first.difference([1, 2, 3], sort) - def test_symmetric_difference(self, index, request): - if ( - pa_version_under7p0 - and str(index.dtype) == "string[pyarrow]" - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason="https://issues.apache.org/jira/browse/ARROW-12042" - ) - ) + def test_symmetric_difference(self, index): if isinstance(index, CategoricalIndex): return if len(index) < 2: From 711d44a46a8beabfb63a286ef711447d77bd105d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Mon, 25 Jul 2022 11:48:09 -0700 Subject: [PATCH 10/13] Use keywords --- pandas/core/arrays/arrow/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 7a872f8b01a9f..baa5b4ab8c1a7 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -444,7 +444,9 @@ def argsort( # there's a bug that affects the pa.ChunkedArray backing # https://issues.apache.org/jira/browse/ARROW-12042 fallback_performancewarning("7") - return super().argsort(ascending, kind, na_position, *args, **kwargs) + return super().argsort( + ascending=ascending, kind=kind, na_position=na_position + ) result = pc.array_sort_indices( self._data, order=order, null_placement=null_placement From c4d52518ed13816d6d259f59396c0ff78f77c00a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Mon, 25 Jul 2022 14:10:14 -0700 Subject: [PATCH 11/13] Add another keyword, try commenting something out --- pandas/core/arrays/arrow/array.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index baa5b4ab8c1a7..d78de0c6f689d 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -422,11 +422,11 @@ def isna(self) -> npt.NDArray[np.bool_]: else: return self._data.is_null().to_numpy() - def _values_for_argsort(self) -> np.ndarray: - if pa_version_under2p0: - return self._data.to_pandas().values - else: - return self._data.to_numpy() + # def _values_for_argsort(self) -> np.ndarray: + # if pa_version_under2p0: + # return self._data.to_pandas().values + # else: + # return self._data.to_numpy() @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) def argsort( @@ -464,7 +464,7 @@ def _argmin_max(self, skipna: bool, method: str) -> int: # For empty or all null, pyarrow returns -1 but pandas expects TypeError # For skipna=False and data w/ null, pandas expects NotImplementedError # let ExtensionArray.arg{max|min} raise - return getattr(super(), f"arg{method}")(skipna) + return getattr(super(), f"arg{method}")(skipna=skipna) if pa_version_under6p0: raise NotImplementedError( From f4d0c4b6b9b269c7d62b1bb9af89f5aab0af71f0 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Mon, 25 Jul 2022 18:18:59 -0700 Subject: [PATCH 12/13] Was it values for argsort? --- pandas/core/arrays/arrow/array.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d78de0c6f689d..0ddfa3016721e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -422,11 +422,11 @@ def isna(self) -> npt.NDArray[np.bool_]: else: return self._data.is_null().to_numpy() - # def _values_for_argsort(self) -> np.ndarray: - # if pa_version_under2p0: - # return self._data.to_pandas().values - # else: - # return self._data.to_numpy() + def _values_for_argsort(self) -> np.ndarray: + if pa_version_under2p0: + return self._data.to_pandas().values + else: + return self._data.to_numpy() @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) def argsort( From 86233e738f9446904d8126e4b917f4ac7ec97778 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Tue, 26 Jul 2022 09:35:09 -0700 Subject: [PATCH 13/13] Just remove values for argsort --- pandas/core/arrays/arrow/array.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 0ddfa3016721e..841275e54e3d6 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -422,12 +422,6 @@ def isna(self) -> npt.NDArray[np.bool_]: else: return self._data.is_null().to_numpy() - def _values_for_argsort(self) -> np.ndarray: - if pa_version_under2p0: - return self._data.to_pandas().values - else: - return self._data.to_numpy() - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) def argsort( self,