Skip to content

API: ArrowExtensionArray._cmp_method to return pyarrow.bool_ type #51643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 9, 2023
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for mor

Other API changes
^^^^^^^^^^^^^^^^^
-
- :class:`~arrays.ArrowExtensionArray` comparison methods now return an array with :class:`ArrowDtype` of ``pyarrow.bool_`` type instead of the ``"boolean"`` dtype (:issue:`51643`)
-

.. ---------------------------------------------------------------------------
Expand Down
140 changes: 126 additions & 14 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@

from pandas.core import roperator
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionArraySupportsAnyAll,
)
import pandas.core.common as com
from pandas.core.indexers import (
check_array_indexer,
Expand Down Expand Up @@ -170,7 +173,9 @@ def to_pyarrow_type(
return None


class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
class ArrowExtensionArray(
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
):
"""
Pandas ExtensionArray backed by a PyArrow ChunkedArray.

Expand Down Expand Up @@ -438,8 +443,6 @@ def __setstate__(self, state) -> None:
self.__dict__.update(state)

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, ArrowExtensionArray):
result = pc_func(self._data, other._data)
Expand All @@ -453,20 +456,13 @@ def _cmp_method(self, other, op):
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
return BooleanArray(result, mask)
result = pa.array(result, type=pa.bool_())
result = pc.if_else(valid, result, None)
else:
raise NotImplementedError(
f"{op.__name__} not implemented for {type(other)}"
)

if result.null_count > 0:
# GH50524: avoid conversion to object for better perf
values = pc.fill_null(result, False).to_numpy()
mask = result.is_null().to_numpy()
else:
values = result.to_numpy()
mask = np.zeros(len(values), dtype=np.bool_)
return BooleanArray(values, mask)
return ArrowExtensionArray(result)

def _evaluate_op_method(self, other, op, arrow_funcs):
pa_type = self._data.type
Expand Down Expand Up @@ -580,6 +576,122 @@ def isna(self) -> npt.NDArray[np.bool_]:

return self._data.is_null().to_numpy()

def any(self, *, skipna: bool = True, **kwargs):
    """
    Report whether at least one element of the array is truthy.

    The result is False unless some element is truthy. Missing values
    are skipped by default. With ``skipna=False`` and missing values
    present, :ref:`Kleene logic <boolean.kleene>` is applied, in the
    same way as for logical operations.

    Parameters
    ----------
    skipna : bool, default True
        Whether to exclude NA values. When True, an all-NA array gives
        False, exactly like an empty array. When False, the result is
        still True if any element is truthy; otherwise NA is returned
        if NAs are present.
    **kwargs
        Additional keyword arguments, forwarded unchanged to
        ``self._reduce``.

    Returns
    -------
    bool or :attr:`pandas.NA`

    See Also
    --------
    ArrowExtensionArray.all : Return whether all elements are truthy.

    Examples
    --------
    By default NAs are skipped and the result tells whether any
    element is truthy:

    >>> pd.array([True, False, True], dtype="boolean[pyarrow]").any()
    True
    >>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").any()
    True
    >>> pd.array([False, False, pd.NA], dtype="boolean[pyarrow]").any()
    False
    >>> pd.array([], dtype="boolean[pyarrow]").any()
    False
    >>> pd.array([pd.NA], dtype="boolean[pyarrow]").any()
    False
    >>> pd.array([pd.NA], dtype="float64[pyarrow]").any()
    False

    With ``skipna=False`` the result is NA whenever the answer depends
    on whether ``pd.NA`` stands for True or False:

    >>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
    True
    >>> pd.array([1, 0, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
    True
    >>> pd.array([False, False, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
    <NA>
    >>> pd.array([0, 0, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
    <NA>
    """
    # Pure delegation: the generic reduction entry point does the work.
    reduce_options = {"skipna": skipna, **kwargs}
    return self._reduce("any", **reduce_options)

def all(self, *, skipna: bool = True, **kwargs):
    """
    Report whether every element of the array is truthy.

    The result is True unless some element is falsey. Missing values
    are skipped by default. With ``skipna=False`` and missing values
    present, :ref:`Kleene logic <boolean.kleene>` is applied, in the
    same way as for logical operations.

    Parameters
    ----------
    skipna : bool, default True
        Whether to exclude NA values. When True, an all-NA array gives
        True, exactly like an empty array. When False, the result is
        still False if any element is falsey; otherwise NA is returned
        if NAs are present.
    **kwargs
        Additional keyword arguments, forwarded unchanged to
        ``self._reduce``.

    Returns
    -------
    bool or :attr:`pandas.NA`

    See Also
    --------
    ArrowExtensionArray.any : Return whether any element is truthy.

    Examples
    --------
    By default NAs are skipped and the result tells whether all
    elements are truthy:

    >>> pd.array([True, True, pd.NA], dtype="boolean[pyarrow]").all()
    True
    >>> pd.array([1, 1, pd.NA], dtype="boolean[pyarrow]").all()
    True
    >>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").all()
    False
    >>> pd.array([], dtype="boolean[pyarrow]").all()
    True
    >>> pd.array([pd.NA], dtype="boolean[pyarrow]").all()
    True
    >>> pd.array([pd.NA], dtype="float64[pyarrow]").all()
    True

    With ``skipna=False`` the result is NA whenever the answer depends
    on whether ``pd.NA`` stands for True or False:

    >>> pd.array([True, True, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
    <NA>
    >>> pd.array([1, 1, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
    <NA>
    >>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
    False
    >>> pd.array([1, 0, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
    False
    """
    # Pure delegation: the generic reduction entry point does the work.
    reducer = self._reduce
    return reducer("all", skipna=skipna, **kwargs)

def argsort(
self,
*,
Expand Down
16 changes: 10 additions & 6 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,18 @@ def test_comparison_methods_scalar(comparison_op, dtype):
a = pd.array(["a", None, "c"], dtype=dtype)
other = "a"
result = getattr(a, op_name)(other)
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
expected = pd.array(expected, dtype="boolean")
expected = pd.array(expected, dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
op_name = f"__{comparison_op.__name__}__"
a = pd.array(["a", None, "c"], dtype=dtype)
result = getattr(a, op_name)(pd.NA)
expected = pd.array([None, None, None], dtype="boolean")
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
expected = pd.array([None, None, None], dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)


Expand All @@ -225,7 +227,8 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[
op_name
]
expected = pd.array(expected_data, dtype="boolean")
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
expected = pd.array(expected_data, dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)


Expand All @@ -235,13 +238,14 @@ def test_comparison_methods_array(comparison_op, dtype):
a = pd.array(["a", None, "c"], dtype=dtype)
other = [None, None, "c"]
result = getattr(a, op_name)(other)
expected = np.empty_like(a, dtype="object")
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
expected = np.full(len(a), fill_value=None, dtype="object")
expected[-1] = getattr(other[-1], op_name)(a[-1])
expected = pd.array(expected, dtype="boolean")
expected = pd.array(expected, dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)

result = getattr(a, op_name)(pd.NA)
expected = pd.array([None, None, None], dtype="boolean")
expected = pd.array([None, None, None], dtype=expected_dtype)
tm.assert_extension_array_equal(result, expected)


Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/string_/test_string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def test_eq_all_na():
a = pd.array([pd.NA, pd.NA], dtype=StringDtype("pyarrow"))
result = a == a
expected = pd.array([pd.NA, pd.NA], dtype="boolean")
expected = pd.array([pd.NA, pd.NA], dtype="boolean[pyarrow]")
tm.assert_extension_array_equal(result, expected)


Expand Down
14 changes: 0 additions & 14 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,14 +1215,7 @@ def test_add_series_with_extension_array(self, data, request):


class TestBaseComparisonOps(base.BaseComparisonOpsTests):
def assert_series_equal(self, left, right, *args, **kwargs):
# Series.combine for "expected" retains bool[pyarrow] dtype
# while "result" returns "boolean" dtype
right = pd.Series(right._values.to_numpy(), dtype="boolean")
super().assert_series_equal(left, right, *args, **kwargs)

def test_compare_array(self, data, comparison_op, na_value, request):
pa_dtype = data.dtype.pyarrow_dtype
ser = pd.Series(data)
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
# since ser.iloc[0] is a python scalar
Expand All @@ -1248,13 +1241,6 @@ def test_compare_array(self, data, comparison_op, na_value, request):

if exc is None:
# Didn't error, then should match point-wise behavior
if pa.types.is_temporal(pa_dtype):
# point-wise comparison with pd.NA raises TypeError
assert result[8] is na_value
assert result[97] is na_value
result = result.drop([8, 97]).reset_index(drop=True)
ser = ser.drop([8, 97])
other = other.drop([8, 97])
expected = ser.combine(other, comparison_op)
self.assert_series_equal(result, expected)
else:
Expand Down
3 changes: 2 additions & 1 deletion pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ class TestComparisonOps(base.BaseComparisonOpsTests):
def _compare_other(self, ser, data, op, other):
op_name = f"__{op.__name__}__"
result = getattr(ser, op_name)(other)
expected = getattr(ser.astype(object), op_name)(other).astype("boolean")
dtype = "boolean[pyarrow]" if ser.dtype.storage == "pyarrow" else "boolean"
expected = getattr(ser.astype(object), op_name)(other).astype(dtype)
self.assert_series_equal(result, expected)

def test_compare_scalar(self, data, comparison_op):
Expand Down