Skip to content

ENH/TST: Add quantile & mode tests for ArrowExtensionArray #47744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4c65875
Add mode
mroeschke Jul 15, 2022
285aee2
implement quantile
mroeschke Jul 15, 2022
218fc80
Make note about rank
mroeschke Jul 15, 2022
ecaf2db
Remove typing
mroeschke Jul 15, 2022
79b85a7
Add NotImplementedError for unsupported versions
mroeschke Jul 16, 2022
53e4d8a
improve xfails
mroeschke Jul 16, 2022
18be328
Change notimmplemented to performancewarning
mroeschke Jul 16, 2022
0ce0b49
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 16, 2022
f787903
Fix scalar case
mroeschke Jul 16, 2022
1fa1cfa
Ignore other warnings
mroeschke Jul 17, 2022
6b018ca
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 17, 2022
f5aa946
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 18, 2022
8a6586f
Print tests to debug min build timeout
mroeschke Jul 18, 2022
2b97a99
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 18, 2022
1b7e575
Raise notimplementederror
mroeschke Jul 18, 2022
c8f7468
Undo warning message
mroeschke Jul 18, 2022
6d6c735
Undo typing
mroeschke Jul 18, 2022
a27c854
reason
mroeschke Jul 19, 2022
f5d1b97
multimode case will be fixed in pa=9
mroeschke Jul 19, 2022
cd412ab
Fix typo
mroeschke Jul 19, 2022
e6246df
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 19, 2022
c99b73d
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 22, 2022
1996c63
Merge remote-tracking branch 'upstream/main' into arrow/misc_methods
mroeschke Jul 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pandas/core/arrays/arrow/_arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ def fallback_performancewarning(version: str | None = None) -> None:
Raise a PerformanceWarning for falling back to ExtensionArray's
non-pyarrow method
"""
msg = "Falling back on a non-pyarrow code path which may decrease performance."
msg = (
"Falling back on a non-pyarrow code path which may decrease performance or "
"not be fully compatible with pyarrow."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason for adding this additional sentence?
(since it is so generic, it's also not that useful / potentially confusing I think)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also wanted to convey in this message that a user may get incorrect results if falling back.

e.g. A user's code could fall back to using ExtensionArray._mode/_quantile which may have some numpy-based assumptions that may fail when passing arrow (or arrow transformed) data.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we know (think?) that a certain fallback can give incorrect results, can we rather have it error in that case? (for that specific function)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's a good point. Okay I'll leave this message alone, and raise a NotImplementedError for the associated case (mode/quantile)

)
if version is not None:
msg += f" Upgrade to pyarrow >={version} to possibly suppress this warning."
msg += f" Upgrade to pyarrow >={version} to suppress this warning."
warnings.warn(msg, PerformanceWarning, stacklevel=find_stack_level())


Expand Down
58 changes: 57 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@

from pandas.core.algorithms import resolve_na_sentinel
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionArrayT,
)
from pandas.core.indexers import (
check_array_indexer,
unpack_tuple_and_ellipses,
Expand Down Expand Up @@ -702,6 +705,59 @@ def _indexing_key_to_indices(
indices = np.arange(n)[key]
return indices

# TODO: redefine _rank using pc.rank with pyarrow 9.0

def _quantile(
self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
) -> ArrowExtensionArrayT | ExtensionArrayT:
"""
Compute the quantiles of self for each quantile in `qs`.

Parameters
----------
qs : np.ndarray[float64]
interpolation: str

Returns
-------
same type as self
"""
if pa_version_under4p0:
fallback_performancewarning("4")
return super()._quantile(qs, interpolation)
result = pc.quantile(self._data, q=qs, interpolation=interpolation)
return type(self)(result)

def _mode(
self: ArrowExtensionArrayT, dropna: bool = True
) -> ArrowExtensionArrayT | ExtensionArrayT:
"""
Returns the mode(s) of the ExtensionArray.

Always returns `ExtensionArray` even if only one value.

Parameters
----------
dropna : bool, default True
Don't consider counts of NA values.
Not implemented by pyarrow.

Returns
-------
same type as self
Sorted, if possible.
"""
if pa_version_under6p0:
fallback_performancewarning("6")
return super()._mode(dropna)
modes = pc.mode(self._data, pc.count_distinct(self._data).as_py())
values = modes.field(0)
counts = modes.field(1)
# counts sorted descending i.e counts[0] = max
mask = pc.equal(counts, counts[0])
most_common = values.filter(mask)
return type(self)(most_common)

def _maybe_convert_setitem_value(self, value):
"""Maybe convert value to be pyarrow compatible."""
# TODO: Make more robust like ArrowStringArray._maybe_convert_setitem_value
Expand Down
77 changes: 76 additions & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""

import contextlib
from datetime import (
date,
datetime,
Expand All @@ -24,8 +24,11 @@
from pandas.compat import (
pa_version_under2p0,
pa_version_under3p0,
pa_version_under4p0,
pa_version_under6p0,
pa_version_under8p0,
)
from pandas.errors import PerformanceWarning

import pandas as pd
import pandas._testing as tm
Expand Down Expand Up @@ -1838,3 +1841,75 @@ def test_compare_array(self, data, comparison_op, na_value, request):
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")


@pytest.mark.parametrize(
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
)
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
def test_quantile(data, interpolation, quantile, request):
data = data.take([0, 0, 0])
ser = pd.Series(data)
if pa_version_under4p0:
with tm.assert_produces_warning(PerformanceWarning):
# Just validate the PerformanceWarning
# ExtensionArray._quantile may not support all pyarrow types
with contextlib.suppress(Exception):
ser.quantile(q=quantile, interpolation=interpolation)
else:
pa_dtype = data.dtype.pyarrow_dtype
if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"quantile not supported by pyarrow for {pa_dtype}",
)
)
result = ser.quantile(q=quantile, interpolation=interpolation)
if quantile == 0.5:
assert result == data[0]
else:
# Just check the values
result = result.astype("float64[pyarrow]")
expected = pd.Series(
data.take([0, 0]).astype("float64[pyarrow]"), index=[0.5, 0.5]
)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.parametrize(
"take_idx, exp_idx",
[[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]],
ids=["multi_mode", "single_mode"],
)
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
data = data_for_grouping.take(take_idx)
ser = pd.Series(data)
if pa_version_under6p0:
with tm.assert_produces_warning(
PerformanceWarning, raise_on_extra_warnings=False
):
# Just validate the PerformanceWarning
# ExtensionArray._mode may not support all pyarrow types
with contextlib.suppress(Exception):
ser.mode(dropna=dropna)
else:
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_temporal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"mode not supported by pyarrow for {pa_dtype}",
)
)
elif pa.types.is_boolean(pa_dtype) and "multi_mode" in request.node.nodeid:
# https://issues.apache.org/jira/browse/ARROW-17096
request.node.add_marker(
pytest.mark.xfail(
reason="https://issues.apache.org/jira/browse/ARROW-17096",
)
)
result = ser.mode(dropna=dropna)
expected = pd.Series(data_for_grouping.take(exp_idx))
tm.assert_series_equal(result, expected)