Skip to content

Commit e9b3fa8

Browse files
lukemanleyphofl
authored andcommitted
API: ArrowExtensionArray._cmp_method to return pyarrow.bool_ type (pandas-dev#51643)
* API: ArrowExtensionArray._cmp_method to return pyarrow.bool_ type * whatsnew * try removing asv parallel build * fix logical func keys * cleanup * fix test * subclass ExtensionArraySupportsAnyAll
1 parent f8571f2 commit e9b3fa8

File tree

5 files changed

+139
-36
lines changed

5 files changed

+139
-36
lines changed

pandas/core/arrays/arrow/array.py

+126-14
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@
5353

5454
from pandas.core import roperator
5555
from pandas.core.arraylike import OpsMixin
56-
from pandas.core.arrays.base import ExtensionArray
56+
from pandas.core.arrays.base import (
57+
ExtensionArray,
58+
ExtensionArraySupportsAnyAll,
59+
)
5760
import pandas.core.common as com
5861
from pandas.core.indexers import (
5962
check_array_indexer,
@@ -171,7 +174,9 @@ def to_pyarrow_type(
171174
return None
172175

173176

174-
class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
177+
class ArrowExtensionArray(
178+
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
179+
):
175180
"""
176181
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
177182
@@ -429,8 +434,6 @@ def __setstate__(self, state) -> None:
429434
self.__dict__.update(state)
430435

431436
def _cmp_method(self, other, op):
432-
from pandas.arrays import BooleanArray
433-
434437
pc_func = ARROW_CMP_FUNCS[op.__name__]
435438
if isinstance(other, ArrowExtensionArray):
436439
result = pc_func(self._data, other._data)
@@ -444,20 +447,13 @@ def _cmp_method(self, other, op):
444447
valid = ~mask
445448
result = np.zeros(len(self), dtype="bool")
446449
result[valid] = op(np.array(self)[valid], other)
447-
return BooleanArray(result, mask)
450+
result = pa.array(result, type=pa.bool_())
451+
result = pc.if_else(valid, result, None)
448452
else:
449453
raise NotImplementedError(
450454
f"{op.__name__} not implemented for {type(other)}"
451455
)
452-
453-
if result.null_count > 0:
454-
# GH50524: avoid conversion to object for better perf
455-
values = pc.fill_null(result, False).to_numpy()
456-
mask = result.is_null().to_numpy()
457-
else:
458-
values = result.to_numpy()
459-
mask = np.zeros(len(values), dtype=np.bool_)
460-
return BooleanArray(values, mask)
456+
return ArrowExtensionArray(result)
461457

462458
def _evaluate_op_method(self, other, op, arrow_funcs):
463459
pa_type = self._data.type
@@ -564,6 +560,122 @@ def isna(self) -> npt.NDArray[np.bool_]:
564560
"""
565561
return self._data.is_null().to_numpy()
566562

563+
def any(self, *, skipna: bool = True, **kwargs):
564+
"""
565+
Return whether any element is truthy.
566+
567+
Returns False unless there is at least one element that is truthy.
568+
By default, NAs are skipped. If ``skipna=False`` is specified and
569+
missing values are present, similar :ref:`Kleene logic <boolean.kleene>`
570+
is used as for logical operations.
571+
572+
Parameters
573+
----------
574+
skipna : bool, default True
575+
Exclude NA values. If the entire array is NA and `skipna` is
576+
True, then the result will be False, as for an empty array.
577+
If `skipna` is False, the result will still be True if there is
578+
at least one element that is truthy, otherwise NA will be returned
579+
if there are NA's present.
580+
581+
Returns
582+
-------
583+
bool or :attr:`pandas.NA`
584+
585+
See Also
586+
--------
587+
ArrowExtensionArray.all : Return whether all elements are truthy.
588+
589+
Examples
590+
--------
591+
The result indicates whether any element is truthy (and by default
592+
skips NAs):
593+
594+
>>> pd.array([True, False, True], dtype="boolean[pyarrow]").any()
595+
True
596+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").any()
597+
True
598+
>>> pd.array([False, False, pd.NA], dtype="boolean[pyarrow]").any()
599+
False
600+
>>> pd.array([], dtype="boolean[pyarrow]").any()
601+
False
602+
>>> pd.array([pd.NA], dtype="boolean[pyarrow]").any()
603+
False
604+
>>> pd.array([pd.NA], dtype="float64[pyarrow]").any()
605+
False
606+
607+
With ``skipna=False``, the result can be NA if this is logically
608+
required (whether ``pd.NA`` is True or False influences the result):
609+
610+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
611+
True
612+
>>> pd.array([1, 0, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
613+
True
614+
>>> pd.array([False, False, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
615+
<NA>
616+
>>> pd.array([0, 0, pd.NA], dtype="boolean[pyarrow]").any(skipna=False)
617+
<NA>
618+
"""
619+
return self._reduce("any", skipna=skipna, **kwargs)
620+
621+
def all(self, *, skipna: bool = True, **kwargs):
622+
"""
623+
Return whether all elements are truthy.
624+
625+
Returns True unless there is at least one element that is falsey.
626+
By default, NAs are skipped. If ``skipna=False`` is specified and
627+
missing values are present, similar :ref:`Kleene logic <boolean.kleene>`
628+
is used as for logical operations.
629+
630+
Parameters
631+
----------
632+
skipna : bool, default True
633+
Exclude NA values. If the entire array is NA and `skipna` is
634+
True, then the result will be True, as for an empty array.
635+
If `skipna` is False, the result will still be False if there is
636+
at least one element that is falsey, otherwise NA will be returned
637+
if there are NA's present.
638+
639+
Returns
640+
-------
641+
bool or :attr:`pandas.NA`
642+
643+
See Also
644+
--------
645+
ArrowExtensionArray.any : Return whether any element is truthy.
646+
647+
Examples
648+
--------
649+
The result indicates whether all elements are truthy (and by default
650+
skips NAs):
651+
652+
>>> pd.array([True, True, pd.NA], dtype="boolean[pyarrow]").all()
653+
True
654+
>>> pd.array([1, 1, pd.NA], dtype="boolean[pyarrow]").all()
655+
True
656+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").all()
657+
False
658+
>>> pd.array([], dtype="boolean[pyarrow]").all()
659+
True
660+
>>> pd.array([pd.NA], dtype="boolean[pyarrow]").all()
661+
True
662+
>>> pd.array([pd.NA], dtype="float64[pyarrow]").all()
663+
True
664+
665+
With ``skipna=False``, the result can be NA if this is logically
666+
required (whether ``pd.NA`` is True or False influences the result):
667+
668+
>>> pd.array([True, True, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
669+
<NA>
670+
>>> pd.array([1, 1, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
671+
<NA>
672+
>>> pd.array([True, False, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
673+
False
674+
>>> pd.array([1, 0, pd.NA], dtype="boolean[pyarrow]").all(skipna=False)
675+
False
676+
"""
677+
return self._reduce("all", skipna=skipna, **kwargs)
678+
567679
def argsort(
568680
self,
569681
*,

pandas/tests/arrays/string_/test_string.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -196,16 +196,18 @@ def test_comparison_methods_scalar(comparison_op, dtype):
196196
a = pd.array(["a", None, "c"], dtype=dtype)
197197
other = "a"
198198
result = getattr(a, op_name)(other)
199+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
199200
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
200-
expected = pd.array(expected, dtype="boolean")
201+
expected = pd.array(expected, dtype=expected_dtype)
201202
tm.assert_extension_array_equal(result, expected)
202203

203204

204205
def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
205206
op_name = f"__{comparison_op.__name__}__"
206207
a = pd.array(["a", None, "c"], dtype=dtype)
207208
result = getattr(a, op_name)(pd.NA)
208-
expected = pd.array([None, None, None], dtype="boolean")
209+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
210+
expected = pd.array([None, None, None], dtype=expected_dtype)
209211
tm.assert_extension_array_equal(result, expected)
210212

211213

@@ -225,7 +227,8 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
225227
expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[
226228
op_name
227229
]
228-
expected = pd.array(expected_data, dtype="boolean")
230+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
231+
expected = pd.array(expected_data, dtype=expected_dtype)
229232
tm.assert_extension_array_equal(result, expected)
230233

231234

@@ -235,13 +238,14 @@ def test_comparison_methods_array(comparison_op, dtype):
235238
a = pd.array(["a", None, "c"], dtype=dtype)
236239
other = [None, None, "c"]
237240
result = getattr(a, op_name)(other)
238-
expected = np.empty_like(a, dtype="object")
241+
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
242+
expected = np.full(len(a), fill_value=None, dtype="object")
239243
expected[-1] = getattr(other[-1], op_name)(a[-1])
240-
expected = pd.array(expected, dtype="boolean")
244+
expected = pd.array(expected, dtype=expected_dtype)
241245
tm.assert_extension_array_equal(result, expected)
242246

243247
result = getattr(a, op_name)(pd.NA)
244-
expected = pd.array([None, None, None], dtype="boolean")
248+
expected = pd.array([None, None, None], dtype=expected_dtype)
245249
tm.assert_extension_array_equal(result, expected)
246250

247251

pandas/tests/arrays/string_/test_string_arrow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def test_eq_all_na():
2525
a = pd.array([pd.NA, pd.NA], dtype=StringDtype("pyarrow"))
2626
result = a == a
27-
expected = pd.array([pd.NA, pd.NA], dtype="boolean")
27+
expected = pd.array([pd.NA, pd.NA], dtype="boolean[pyarrow]")
2828
tm.assert_extension_array_equal(result, expected)
2929

3030

pandas/tests/extension/test_arrow.py

-14
Original file line numberDiff line numberDiff line change
@@ -1217,14 +1217,7 @@ def test_add_series_with_extension_array(self, data, request):
12171217

12181218

12191219
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
1220-
def assert_series_equal(self, left, right, *args, **kwargs):
1221-
# Series.combine for "expected" retains bool[pyarrow] dtype
1222-
# While "result" return "boolean" dtype
1223-
right = pd.Series(right._values.to_numpy(), dtype="boolean")
1224-
super().assert_series_equal(left, right, *args, **kwargs)
1225-
12261220
def test_compare_array(self, data, comparison_op, na_value, request):
1227-
pa_dtype = data.dtype.pyarrow_dtype
12281221
ser = pd.Series(data)
12291222
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
12301223
# since ser.iloc[0] is a python scalar
@@ -1250,13 +1243,6 @@ def test_compare_array(self, data, comparison_op, na_value, request):
12501243

12511244
if exc is None:
12521245
# Didn't error, then should match point-wise behavior
1253-
if pa.types.is_temporal(pa_dtype):
1254-
# point-wise comparison with pd.NA raises TypeError
1255-
assert result[8] is na_value
1256-
assert result[97] is na_value
1257-
result = result.drop([8, 97]).reset_index(drop=True)
1258-
ser = ser.drop([8, 97])
1259-
other = other.drop([8, 97])
12601246
expected = ser.combine(other, comparison_op)
12611247
self.assert_series_equal(result, expected)
12621248
else:

pandas/tests/extension/test_string.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ class TestComparisonOps(base.BaseComparisonOpsTests):
221221
def _compare_other(self, ser, data, op, other):
222222
op_name = f"__{op.__name__}__"
223223
result = getattr(ser, op_name)(other)
224-
expected = getattr(ser.astype(object), op_name)(other).astype("boolean")
224+
dtype = "boolean[pyarrow]" if ser.dtype.storage == "pyarrow" else "boolean"
225+
expected = getattr(ser.astype(object), op_name)(other).astype(dtype)
225226
self.assert_series_equal(result, expected)
226227

227228
def test_compare_scalar(self, data, comparison_op):

0 commit comments

Comments
 (0)