Skip to content

Commit b3d18af

Browse files
authored
API: ArrowExtensionArray._cmp_method to return pyarrow.bool_ type (#51643)
* API: ArrowExtensionArray._cmp_method to return pyarrow.bool_ type * whatsnew * try removing asv parallel build * fix logical func keys * cleanup * fix test * subclass ExtensionArraySupportsAnyAll
1 parent dd39140 commit b3d18af

File tree

6 files changed

+140
-37
lines changed

6 files changed

+140
-37
lines changed

doc/source/whatsnew/v2.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for mor
8686

8787
Other API changes
8888
^^^^^^^^^^^^^^^^^
89-
-
89+
- :class:`~arrays.ArrowExtensionArray` comparison methods now return data with :class:`ArrowDtype` with ``pyarrow.bool_`` type instead of ``"boolean"`` dtype (:issue:`51643`)
9090
-
9191

9292
.. ---------------------------------------------------------------------------

pandas/core/arrays/arrow/array.py

+126-14
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@
5252

5353
from pandas.core import roperator
5454
from pandas.core.arraylike import OpsMixin
55-
from pandas.core.arrays.base import ExtensionArray
55+
from pandas.core.arrays.base import (
56+
ExtensionArray,
57+
ExtensionArraySupportsAnyAll,
58+
)
5659
import pandas.core.common as com
5760
from pandas.core.indexers import (
5861
check_array_indexer,
@@ -170,7 +173,9 @@ def to_pyarrow_type(
170173
return None
171174

172175

173-
class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
176+
class ArrowExtensionArray(
177+
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
178+
):
174179
"""
175180
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
176181
@@ -438,8 +443,6 @@ def __setstate__(self, state) -> None:
438443
self.__dict__.update(state)
439444

440445
def _cmp_method(self, other, op):
441-
from pandas.arrays import BooleanArray
442-
443446
pc_func = ARROW_CMP_FUNCS[op.__name__]
444447
if isinstance(other, ArrowExtensionArray):
445448
result = pc_func(self._data, other._data)
@@ -453,20 +456,13 @@ def _cmp_method(self, other, op):
453456
valid = ~mask
454457
result = np.zeros(len(self), dtype="bool")
455458
result[valid] = op(np.array(self)[valid], other)
456-
return BooleanArray(result, mask)
459+
result = pa.array(result, type=pa.bool_())
460+
result = pc.if_else(valid, result, None)
457461
else:
458462
raise NotImplementedError(
459463
f"{op.__name__} not implemented for {type(other)}"
460464
)
461-
462-
if result.null_count > 0:
463-
# GH50524: avoid conversion to object for better perf
464-
values = pc.fill_null(result, False).to_numpy()
465-
mask = result.is_null().to_numpy()
466-
else:
467-
values = result.to_numpy()
468-
mask = np.zeros(len(values), dtype=np.bool_)
469-
return BooleanArray(values, mask)
465+
return ArrowExtensionArray(result)
470466

471467
def _evaluate_op_method(self, other, op, arrow_funcs):
472468
pa_type = self._data.type
@@ -580,6 +576,122 @@ def isna(self) -> npt.NDArray[np.bool_]:
580576

581577
return self._data.is_null().to_numpy()
582578

579+
def any(self, *, skipna: bool = True, **kwargs):
580+
"""
581+
Return whether any element is truthy.
582+
583+
Returns False unless there is at least one element that is truthy.
584+
By default, NAs are skipped. If ``skipna=False`` is specified and
585+
missing values are present, similar :ref:`Kleene logic <boolean.kleene>`
586+
is used as for logical operations.
587+
588+
Parameters
589+
----------
590+
skipna : bool, default True
591+
Exclude NA values. If the entire array is NA and `skipna` is
592+
True, then the result will be False, as for an empty array.
593+
If `skipna` is False, the result will still be True if there is
594+
at least one element that is truthy, otherwise NA will be returned
595+
if there are NA's present.
596+
597+
Returns
598+
-------
599+
bool or :attr:`pandas.NA`
600+
601+
See Also
602+
--------
603+
ArrowExtensionArray.all : Return whether all elements are truthy.
604+
605+
Examples
606+
--------
607+
The result indicates whether any element is truthy (and by default
608+
skips NAs):
609+
610+
>>> pd.array([True, False, True], dtype="boolean[pyarrow]").any()
611+
True
612+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").any()
613+
True
614+
>>> pd.array([False, False, pd.NA], dtype="boolean[pyarrow]").any()
615+
False
616+
>>> pd.array([], dtype="boolean[pyarrow]").any()
617+
False
618+
>>> pd.array([pd.NA], dtype="boolean[pyarrow]").any()
619+
False
620+
>>> pd.array([pd.NA], dtype="float64[pyarrow]").any()
621+
False
622+
623+
With ``skipna=False``, the result can be NA if this is logically
624+
required (whether ``pd.NA`` is True or False influences the result):
625+
626+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
627+
True
628+
>>> pd.array([1, 0, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
629+
True
630+
>>> pd.array([False, False, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
631+
<NA>
632+
>>> pd.array([0, 0, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
633+
<NA>
634+
"""
635+
return self._reduce("any", skipna=skipna, **kwargs)
636+
637+
def all(self, *, skipna: bool = True, **kwargs):
638+
"""
639+
Return whether all elements are truthy.
640+
641+
Returns True unless there is at least one element that is falsey.
642+
By default, NAs are skipped. If ``skipna=False`` is specified and
643+
missing values are present, similar :ref:`Kleene logic <boolean.kleene>`
644+
is used as for logical operations.
645+
646+
Parameters
647+
----------
648+
skipna : bool, default True
649+
Exclude NA values. If the entire array is NA and `skipna` is
650+
True, then the result will be True, as for an empty array.
651+
If `skipna` is False, the result will still be False if there is
652+
at least one element that is falsey, otherwise NA will be returned
653+
if there are NA's present.
654+
655+
Returns
656+
-------
657+
bool or :attr:`pandas.NA`
658+
659+
See Also
660+
--------
661+
ArrowExtensionArray.any : Return whether any element is truthy.
662+
663+
Examples
664+
--------
665+
The result indicates whether all elements are truthy (and by default
666+
skips NAs):
667+
668+
>>> pd.array([True, True, pd.NA], dtype="boolean[pyarrow]").all()
669+
True
670+
>>> pd.array([1, 1, pd.NA], dtype="boolean[pyarrow]").all()
671+
True
672+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").all()
673+
False
674+
>>> pd.array([], dtype="boolean[pyarrow]").all()
675+
True
676+
>>> pd.array([pd.NA], dtype="boolean[pyarrow]").all()
677+
True
678+
>>> pd.array([pd.NA], dtype="float64[pyarrow]").all()
679+
True
680+
681+
With ``skipna=False``, the result can be NA if this is logically
682+
required (whether ``pd.NA`` is True or False influences the result):
683+
684+
>>> pd.array([True, True, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
685+
<NA>
686+
>>> pd.array([1, 1, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
687+
<NA>
688+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
689+
False
690+
>>> pd.array([1, 0, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
691+
False
692+
"""
693+
return self._reduce("all", skipna=skipna, **kwargs)
694+
583695
def argsort(
584696
self,
585697
*,

pandas/tests/arrays/string_/test_string.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -196,16 +196,18 @@ def test_comparison_methods_scalar(comparison_op, dtype):
196196
a = pd.array(["a", None, "c"], dtype=dtype)
197197
other = "a"
198198
result = getattr(a, op_name)(other)
199+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
199200
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
200-
expected = pd.array(expected, dtype="boolean")
201+
expected = pd.array(expected, dtype=expected_dtype)
201202
tm.assert_extension_array_equal(result, expected)
202203

203204

204205
def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
205206
op_name = f"__{comparison_op.__name__}__"
206207
a = pd.array(["a", None, "c"], dtype=dtype)
207208
result = getattr(a, op_name)(pd.NA)
208-
expected = pd.array([None, None, None], dtype="boolean")
209+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
210+
expected = pd.array([None, None, None], dtype=expected_dtype)
209211
tm.assert_extension_array_equal(result, expected)
210212

211213

@@ -225,7 +227,8 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
225227
expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[
226228
op_name
227229
]
228-
expected = pd.array(expected_data, dtype="boolean")
230+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
231+
expected = pd.array(expected_data, dtype=expected_dtype)
229232
tm.assert_extension_array_equal(result, expected)
230233

231234

@@ -235,13 +238,14 @@ def test_comparison_methods_array(comparison_op, dtype):
235238
a = pd.array(["a", None, "c"], dtype=dtype)
236239
other = [None, None, "c"]
237240
result = getattr(a, op_name)(other)
238-
expected = np.empty_like(a, dtype="object")
241+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
242+
expected = np.full(len(a), fill_value=None, dtype="object")
239243
expected[-1] = getattr(other[-1], op_name)(a[-1])
240-
expected = pd.array(expected, dtype="boolean")
244+
expected = pd.array(expected, dtype=expected_dtype)
241245
tm.assert_extension_array_equal(result, expected)
242246

243247
result = getattr(a, op_name)(pd.NA)
244-
expected = pd.array([None, None, None], dtype="boolean")
248+
expected = pd.array([None, None, None], dtype=expected_dtype)
245249
tm.assert_extension_array_equal(result, expected)
246250

247251

pandas/tests/arrays/string_/test_string_arrow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def test_eq_all_na():
2525
a = pd.array([pd.NA, pd.NA], dtype=StringDtype("pyarrow"))
2626
result = a == a
27-
expected = pd.array([pd.NA, pd.NA], dtype="boolean")
27+
expected = pd.array([pd.NA, pd.NA], dtype="boolean[pyarrow]")
2828
tm.assert_extension_array_equal(result, expected)
2929

3030

pandas/tests/extension/test_arrow.py

-14
Original file line numberDiff line numberDiff line change
@@ -1215,14 +1215,7 @@ def test_add_series_with_extension_array(self, data, request):
12151215

12161216

12171217
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
1218-
def assert_series_equal(self, left, right, *args, **kwargs):
1219-
# Series.combine for "expected" retains bool[pyarrow] dtype
1220-
# While "result" return "boolean" dtype
1221-
right = pd.Series(right._values.to_numpy(), dtype="boolean")
1222-
super().assert_series_equal(left, right, *args, **kwargs)
1223-
12241218
def test_compare_array(self, data, comparison_op, na_value, request):
1225-
pa_dtype = data.dtype.pyarrow_dtype
12261219
ser = pd.Series(data)
12271220
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
12281221
# since ser.iloc[0] is a python scalar
@@ -1248,13 +1241,6 @@ def test_compare_array(self, data, comparison_op, na_value, request):
12481241

12491242
if exc is None:
12501243
# Didn't error, then should match point-wise behavior
1251-
if pa.types.is_temporal(pa_dtype):
1252-
# point-wise comparison with pd.NA raises TypeError
1253-
assert result[8] is na_value
1254-
assert result[97] is na_value
1255-
result = result.drop([8, 97]).reset_index(drop=True)
1256-
ser = ser.drop([8, 97])
1257-
other = other.drop([8, 97])
12581244
expected = ser.combine(other, comparison_op)
12591245
self.assert_series_equal(result, expected)
12601246
else:

pandas/tests/extension/test_string.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ class TestComparisonOps(base.BaseComparisonOpsTests):
221221
def _compare_other(self, ser, data, op, other):
222222
op_name = f"__{op.__name__}__"
223223
result = getattr(ser, op_name)(other)
224-
expected = getattr(ser.astype(object), op_name)(other).astype("boolean")
224+
dtype = "boolean[pyarrow]" if ser.dtype.storage == "pyarrow" else "boolean"
225+
expected = getattr(ser.astype(object), op_name)(other).astype(dtype)
225226
self.assert_series_equal(result, expected)
226227

227228
def test_compare_scalar(self, data, comparison_op):

0 commit comments

Comments
 (0)