diff --git a/pandas/compat/__init__.py b/pandas/compat/__init__.py
index 5db859897b663..147134afd70c3 100644
--- a/pandas/compat/__init__.py
+++ b/pandas/compat/__init__.py
@@ -28,6 +28,7 @@
     pa_version_under6p0,
     pa_version_under7p0,
     pa_version_under8p0,
+    pa_version_under9p0,
 )
 
 if TYPE_CHECKING:
@@ -160,4 +161,5 @@ def get_lzma_file() -> type[lzma.LZMAFile]:
     "pa_version_under6p0",
     "pa_version_under7p0",
     "pa_version_under8p0",
+    "pa_version_under9p0",
 ]
diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py
index 833cda20368a2..6965865acb5da 100644
--- a/pandas/compat/pyarrow.py
+++ b/pandas/compat/pyarrow.py
@@ -17,6 +17,7 @@
     pa_version_under6p0 = _palv < Version("6.0.0")
     pa_version_under7p0 = _palv < Version("7.0.0")
     pa_version_under8p0 = _palv < Version("8.0.0")
+    pa_version_under9p0 = _palv < Version("9.0.0")
 except ImportError:
     pa_version_under1p01 = True
     pa_version_under2p0 = True
@@ -26,3 +27,4 @@
     pa_version_under6p0 = True
     pa_version_under7p0 = True
     pa_version_under8p0 = True
+    pa_version_under9p0 = True
diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index b0e4d46564ba4..2c4859061998b 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -825,6 +825,72 @@ def _indexing_key_to_indices(
             indices = np.arange(n)[key]
         return indices
 
+    # TODO: redefine _rank using pc.rank with pyarrow 9.0
+
+    def _quantile(
+        self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
+    ) -> ArrowExtensionArrayT:
+        """
+        Compute the quantiles of self for each quantile in `qs`.
+
+        Parameters
+        ----------
+        qs : np.ndarray[float64]
+        interpolation : str
+
+        Returns
+        -------
+        same type as self
+
+        Raises
+        ------
+        NotImplementedError
+            If the installed pyarrow version is older than 4.0.
+        """
+        if pa_version_under4p0:
+            raise NotImplementedError(
+                "quantile only supported for pyarrow version >= 4.0"
+            )
+        result = pc.quantile(self._data, q=qs, interpolation=interpolation)
+        return type(self)(result)
+
+    def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:
+        """
+        Returns the mode(s) of the ExtensionArray.
+
+        Always returns `ExtensionArray` even if only one value.
+
+        Parameters
+        ----------
+        dropna : bool, default True
+            Don't consider counts of NA values.
+            Not implemented by pyarrow.
+
+        Returns
+        -------
+        same type as self
+            Sorted, if possible. Empty if there are no values to count.
+
+        Raises
+        ------
+        NotImplementedError
+            If the installed pyarrow version is older than 6.0.
+        """
+        if pa_version_under6p0:
+            raise NotImplementedError("mode only supported for pyarrow version >= 6.0")
+        n_distinct = pc.count_distinct(self._data).as_py()
+        if n_distinct == 0:
+            # Nothing to count, so there is no counts[0] to compare against
+            # below; return an empty array of the same type instead of failing.
+            return type(self)(self._data.slice(0, 0))
+        modes = pc.mode(self._data, n_distinct)
+        values = modes.field(0)
+        counts = modes.field(1)
+        # counts sorted descending, i.e. counts[0] = max
+        mask = pc.equal(counts, counts[0])
+        most_common = values.filter(mask)
+        return type(self)(most_common)
+
     def _maybe_convert_setitem_value(self, value):
         """Maybe convert value to be pyarrow compatible."""
         # TODO: Make more robust like ArrowStringArray._maybe_convert_setitem_value
diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py
index a2a96da02b2a6..136c147c07f2e 100644
--- a/pandas/tests/extension/test_arrow.py
+++ b/pandas/tests/extension/test_arrow.py
@@ -10,7 +10,6 @@
 classes (if they are relevant for the extension interface for all dtypes), or
 be added to the array-specific tests in `pandas/tests/arrays/`.
 """
-
 from datetime import (
     date,
     datetime,
@@ -24,8 +23,10 @@
 from pandas.compat import (
     pa_version_under2p0,
     pa_version_under3p0,
+    pa_version_under4p0,
     pa_version_under6p0,
     pa_version_under8p0,
+    pa_version_under9p0,
 )
 
 import pandas as pd
@@ -1946,3 +1947,72 @@ def test_compare_array(self, data, comparison_op, na_value, request):
 def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
     with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
         ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
+
+
+@pytest.mark.xfail(
+    pa_version_under4p0,
+    raises=NotImplementedError,
+    reason="quantile only supported for pyarrow version >= 4.0",
+)
+@pytest.mark.parametrize(
+    "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
+)
+@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
+def test_quantile(data, interpolation, quantile, request):
+    pa_dtype = data.dtype.pyarrow_dtype
+    if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)):
+        request.node.add_marker(
+            pytest.mark.xfail(
+                raises=pa.ArrowNotImplementedError,
+                reason=f"quantile not supported by pyarrow for {pa_dtype}",
+            )
+        )
+    data = data.take([0, 0, 0])
+    ser = pd.Series(data)
+    result = ser.quantile(q=quantile, interpolation=interpolation)
+    if quantile == 0.5:
+        assert result == data[0]
+    else:
+        # Just check the values
+        result = result.astype("float64[pyarrow]")
+        expected = pd.Series(
+            data.take([0, 0]).astype("float64[pyarrow]"), index=[0.5, 0.5]
+        )
+        tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.xfail(
+    pa_version_under6p0,
+    raises=NotImplementedError,
+    reason="mode only supported for pyarrow version >= 6.0",
+)
+@pytest.mark.parametrize("dropna", [True, False])
+@pytest.mark.parametrize(
+    "take_idx, exp_idx",
+    [[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]],
+    ids=["multi_mode", "single_mode"],
+)
+def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
+    pa_dtype = data_for_grouping.dtype.pyarrow_dtype
+    if pa.types.is_temporal(pa_dtype):
+        request.node.add_marker(
+            pytest.mark.xfail(
+                raises=pa.ArrowNotImplementedError,
+                reason=f"mode not supported by pyarrow for {pa_dtype}",
+            )
+        )
+    elif (
+        pa.types.is_boolean(pa_dtype)
+        and "multi_mode" in request.node.nodeid
+        and pa_version_under9p0
+    ):
+        request.node.add_marker(
+            pytest.mark.xfail(
+                reason="https://issues.apache.org/jira/browse/ARROW-17096",
+            )
+        )
+    data = data_for_grouping.take(take_idx)
+    ser = pd.Series(data)
+    result = ser.mode(dropna=dropna)
+    expected = pd.Series(data_for_grouping.take(exp_idx))
+    tm.assert_series_equal(result, expected)