Skip to content

Commit ffc111c

Browse files
authored
ENH/TST: Add quantile & mode tests for ArrowExtensionArray (#47744)
1 parent 1497bf2 commit ffc111c

File tree

4 files changed

+126
-1
lines changed

4 files changed

+126
-1
lines changed

pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
pa_version_under6p0,
2929
pa_version_under7p0,
3030
pa_version_under8p0,
31+
pa_version_under9p0,
3132
)
3233

3334
if TYPE_CHECKING:
@@ -160,4 +161,5 @@ def get_lzma_file() -> type[lzma.LZMAFile]:
160161
"pa_version_under6p0",
161162
"pa_version_under7p0",
162163
"pa_version_under8p0",
164+
"pa_version_under9p0",
163165
]

pandas/compat/pyarrow.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
pa_version_under6p0 = _palv < Version("6.0.0")
1818
pa_version_under7p0 = _palv < Version("7.0.0")
1919
pa_version_under8p0 = _palv < Version("8.0.0")
20+
pa_version_under9p0 = _palv < Version("9.0.0")
2021
except ImportError:
2122
pa_version_under1p01 = True
2223
pa_version_under2p0 = True
@@ -26,3 +27,4 @@
2627
pa_version_under6p0 = True
2728
pa_version_under7p0 = True
2829
pa_version_under8p0 = True
30+
pa_version_under9p0 = True

pandas/core/arrays/arrow/array.py

+51
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,57 @@ def _indexing_key_to_indices(
881881
indices = np.arange(n)[key]
882882
return indices
883883

884+
# TODO: redefine _rank using pc.rank with pyarrow 9.0
885+
886+
def _quantile(
887+
self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
888+
) -> ArrowExtensionArrayT:
889+
"""
890+
Compute the quantiles of self for each quantile in `qs`.
891+
892+
Parameters
893+
----------
894+
qs : np.ndarray[float64]
895+
interpolation: str
896+
897+
Returns
898+
-------
899+
same type as self
900+
"""
901+
if pa_version_under4p0:
902+
raise NotImplementedError(
903+
"quantile only supported for pyarrow version >= 4.0"
904+
)
905+
result = pc.quantile(self._data, q=qs, interpolation=interpolation)
906+
return type(self)(result)
907+
908+
def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:
909+
"""
910+
Returns the mode(s) of the ExtensionArray.
911+
912+
Always returns `ExtensionArray` even if only one value.
913+
914+
Parameters
915+
----------
916+
dropna : bool, default True
917+
Don't consider counts of NA values.
918+
Not implemented by pyarrow.
919+
920+
Returns
921+
-------
922+
same type as self
923+
Sorted, if possible.
924+
"""
925+
if pa_version_under6p0:
926+
raise NotImplementedError("mode only supported for pyarrow version >= 6.0")
927+
modes = pc.mode(self._data, pc.count_distinct(self._data).as_py())
928+
values = modes.field(0)
929+
counts = modes.field(1)
930+
# counts sorted descending i.e counts[0] = max
931+
mask = pc.equal(counts, counts[0])
932+
most_common = values.filter(mask)
933+
return type(self)(most_common)
934+
884935
def _maybe_convert_setitem_value(self, value):
885936
"""Maybe convert value to be pyarrow compatible."""
886937
# TODO: Make more robust like ArrowStringArray._maybe_convert_setitem_value

pandas/tests/extension/test_arrow.py

+71-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
classes (if they are relevant for the extension interface for all dtypes), or
1111
be added to the array-specific tests in `pandas/tests/arrays/`.
1212
"""
13-
1413
from datetime import (
1514
date,
1615
datetime,
@@ -24,8 +23,10 @@
2423
from pandas.compat import (
2524
pa_version_under2p0,
2625
pa_version_under3p0,
26+
pa_version_under4p0,
2727
pa_version_under6p0,
2828
pa_version_under8p0,
29+
pa_version_under9p0,
2930
)
3031

3132
import pandas as pd
@@ -1993,3 +1994,72 @@ def test_compare_array(self, data, comparison_op, na_value, request):
19931994
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
19941995
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
19951996
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
1997+
1998+
1999+
@pytest.mark.xfail(
2000+
pa_version_under4p0,
2001+
raises=NotImplementedError,
2002+
reason="quantile only supported for pyarrow version >= 4.0",
2003+
)
2004+
@pytest.mark.parametrize(
2005+
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
2006+
)
2007+
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
2008+
def test_quantile(data, interpolation, quantile, request):
2009+
pa_dtype = data.dtype.pyarrow_dtype
2010+
if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)):
2011+
request.node.add_marker(
2012+
pytest.mark.xfail(
2013+
raises=pa.ArrowNotImplementedError,
2014+
reason=f"quantile not supported by pyarrow for {pa_dtype}",
2015+
)
2016+
)
2017+
data = data.take([0, 0, 0])
2018+
ser = pd.Series(data)
2019+
result = ser.quantile(q=quantile, interpolation=interpolation)
2020+
if quantile == 0.5:
2021+
assert result == data[0]
2022+
else:
2023+
# Just check the values
2024+
result = result.astype("float64[pyarrow]")
2025+
expected = pd.Series(
2026+
data.take([0, 0]).astype("float64[pyarrow]"), index=[0.5, 0.5]
2027+
)
2028+
tm.assert_series_equal(result, expected)
2029+
2030+
2031+
@pytest.mark.xfail(
2032+
pa_version_under6p0,
2033+
raises=NotImplementedError,
2034+
reason="mode only supported for pyarrow version >= 6.0",
2035+
)
2036+
@pytest.mark.parametrize("dropna", [True, False])
2037+
@pytest.mark.parametrize(
2038+
"take_idx, exp_idx",
2039+
[[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]],
2040+
ids=["multi_mode", "single_mode"],
2041+
)
2042+
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
2043+
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
2044+
if pa.types.is_temporal(pa_dtype):
2045+
request.node.add_marker(
2046+
pytest.mark.xfail(
2047+
raises=pa.ArrowNotImplementedError,
2048+
reason=f"mode not supported by pyarrow for {pa_dtype}",
2049+
)
2050+
)
2051+
elif (
2052+
pa.types.is_boolean(pa_dtype)
2053+
and "multi_mode" in request.node.nodeid
2054+
and pa_version_under9p0
2055+
):
2056+
request.node.add_marker(
2057+
pytest.mark.xfail(
2058+
reason="https://issues.apache.org/jira/browse/ARROW-17096",
2059+
)
2060+
)
2061+
data = data_for_grouping.take(take_idx)
2062+
ser = pd.Series(data)
2063+
result = ser.mode(dropna=dropna)
2064+
expected = pd.Series(data_for_grouping.take(exp_idx))
2065+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)