Skip to content

ENH/TST: Add quantile & mode tests for ArrowExtensionArray #47744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4c65875
Add mode
mroeschke Jul 15, 2022
285aee2
implement quantile
mroeschke Jul 15, 2022
218fc80
Make note about rank
mroeschke Jul 15, 2022
ecaf2db
Remove typing
mroeschke Jul 15, 2022
79b85a7
Add NotImplementedError for unsupported versions
mroeschke Jul 16, 2022
53e4d8a
improve xfails
mroeschke Jul 16, 2022
18be328
Change notimmplemented to performancewarning
mroeschke Jul 16, 2022
0ce0b49
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 16, 2022
f787903
Fix scalar case
mroeschke Jul 16, 2022
1fa1cfa
Ignore other warnings
mroeschke Jul 17, 2022
6b018ca
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 17, 2022
f5aa946
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 18, 2022
8a6586f
Print tests to debug min build timeout
mroeschke Jul 18, 2022
2b97a99
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 18, 2022
1b7e575
Raise notimplementederror
mroeschke Jul 18, 2022
c8f7468
Undo warning message
mroeschke Jul 18, 2022
6d6c735
Undo typing
mroeschke Jul 18, 2022
a27c854
reason
mroeschke Jul 19, 2022
f5d1b97
multimode case will be fixed in pa=9
mroeschke Jul 19, 2022
cd412ab
Fix typo
mroeschke Jul 19, 2022
e6246df
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 19, 2022
c99b73d
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 22, 2022
1996c63
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
pa_version_under6p0,
pa_version_under7p0,
pa_version_under8p0,
pa_version_under9p0,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -160,4 +161,5 @@ def get_lzma_file() -> type[lzma.LZMAFile]:
"pa_version_under6p0",
"pa_version_under7p0",
"pa_version_under8p0",
"pa_version_under9p0",
]
2 changes: 2 additions & 0 deletions pandas/compat/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
pa_version_under6p0 = _palv < Version("6.0.0")
pa_version_under7p0 = _palv < Version("7.0.0")
pa_version_under8p0 = _palv < Version("8.0.0")
pa_version_under9p0 = _palv < Version("9.0.0")
except ImportError:
pa_version_under1p01 = True
pa_version_under2p0 = True
Expand All @@ -26,3 +27,4 @@
pa_version_under6p0 = True
pa_version_under7p0 = True
pa_version_under8p0 = True
pa_version_under9p0 = True
51 changes: 51 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,57 @@ def _indexing_key_to_indices(
indices = np.arange(n)[key]
return indices

# TODO: redefine _rank using pc.rank with pyarrow 9.0

def _quantile(
    self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
) -> ArrowExtensionArrayT:
    """
    Evaluate the requested quantiles of this array through pyarrow.

    Parameters
    ----------
    qs : np.ndarray[float64]
        The quantiles to compute.
    interpolation : str
        Interpolation method; forwarded unchanged to ``pc.quantile``.

    Returns
    -------
    same type as self

    Raises
    ------
    NotImplementedError
        If the installed pyarrow version is older than 4.0.
    """
    if not pa_version_under4p0:
        # Delegate the computation to pyarrow and re-wrap in our own type.
        quantiles = pc.quantile(self._data, q=qs, interpolation=interpolation)
        return type(self)(quantiles)
    raise NotImplementedError(
        "quantile only supported for pyarrow version >= 4.0"
    )

def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:
    """
    Returns the mode(s) of the ExtensionArray.

    Always returns `ExtensionArray` even if only one value.

    Parameters
    ----------
    dropna : bool, default True
        Don't consider counts of NA values.
        Not implemented by pyarrow, so this argument is currently ignored.

    Returns
    -------
    same type as self
        Sorted, if possible. Empty when there are no distinct values to
        count (e.g. the array itself is empty).

    Raises
    ------
    NotImplementedError
        If the installed pyarrow version is older than 6.0.
    """
    if pa_version_under6p0:
        raise NotImplementedError("mode only supported for pyarrow version >= 6.0")
    n_distinct = pc.count_distinct(self._data).as_py()
    if n_distinct == 0:
        # Nothing to count: there is no valid number of modes to request
        # from pc.mode and no counts[0] to compare against below, so hand
        # back an empty array of the same pyarrow type instead of erroring.
        return type(self)(self._data.slice(0, 0))
    # Ask for every distinct value so that ties for the top count are kept.
    modes = pc.mode(self._data, n_distinct)
    values = modes.field(0)
    counts = modes.field(1)
    # counts sorted descending i.e counts[0] = max
    mask = pc.equal(counts, counts[0])
    most_common = values.filter(mask)
    return type(self)(most_common)

def _maybe_convert_setitem_value(self, value):
"""Maybe convert value to be pyarrow compatible."""
# TODO: Make more robust like ArrowStringArray._maybe_convert_setitem_value
Expand Down
73 changes: 72 additions & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""

from datetime import (
date,
datetime,
Expand All @@ -24,7 +23,10 @@
from pandas.compat import (
pa_version_under2p0,
pa_version_under3p0,
pa_version_under4p0,
pa_version_under6p0,
pa_version_under8p0,
pa_version_under9p0,
)

import pandas as pd
Expand Down Expand Up @@ -1838,3 +1840,72 @@ def test_compare_array(self, data, comparison_op, na_value, request):
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")


@pytest.mark.xfail(
    pa_version_under4p0,
    raises=NotImplementedError,
    reason="quantile only supported for pyarrow version >= 4.0",
)
@pytest.mark.parametrize(
    "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
)
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
def test_quantile(data, interpolation, quantile, request):
    # Series.quantile on arrow-backed data, for a scalar and a list-like q.
    pa_dtype = data.dtype.pyarrow_dtype
    is_numeric = pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)
    if not is_numeric:
        # Only integer and floating types are expected to pass.
        unsupported = pytest.mark.xfail(
            raises=pa.ArrowNotImplementedError,
            reason=f"quantile not supported by pyarrow for {pa_dtype}",
        )
        request.node.add_marker(unsupported)

    # Three copies of one element: every quantile must equal that element.
    repeated = data.take([0, 0, 0])
    result = pd.Series(repeated).quantile(q=quantile, interpolation=interpolation)

    if quantile == 0.5:
        # Scalar q gives a scalar result.
        assert result == repeated[0]
        return

    # Just check the values
    expected_values = repeated.take([0, 0]).astype("float64[pyarrow]")
    expected = pd.Series(expected_values, index=[0.5, 0.5])
    tm.assert_series_equal(result.astype("float64[pyarrow]"), expected)


@pytest.mark.xfail(
    pa_version_under6p0,
    raises=NotImplementedError,
    reason="mode only supported for pyarrow version >= 6.0",
)
@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.parametrize(
    "take_idx, exp_idx",
    [[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]],
    ids=["multi_mode", "single_mode"],
)
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
    # Series.mode on arrow-backed data, with one mode and with tied modes.
    pa_dtype = data_for_grouping.dtype.pyarrow_dtype

    # Pick the expected-failure marker (if any) first, then attach it once.
    expected_failure = None
    if pa.types.is_temporal(pa_dtype):
        expected_failure = pytest.mark.xfail(
            raises=pa.ArrowNotImplementedError,
            reason=f"mode not supported by pyarrow for {pa_dtype}",
        )
    elif (
        pa_version_under9p0
        and pa.types.is_boolean(pa_dtype)
        and "multi_mode" in request.node.nodeid
    ):
        # Boolean multi-mode is a known upstream bug before pyarrow 9.
        expected_failure = pytest.mark.xfail(
            reason="https://issues.apache.org/jira/browse/ARROW-17096",
        )
    if expected_failure is not None:
        request.node.add_marker(expected_failure)

    expected = pd.Series(data_for_grouping.take(exp_idx))
    sampled = pd.Series(data_for_grouping.take(take_idx))
    result = sampled.mode(dropna=dropna)
    tm.assert_series_equal(result, expected)