Skip to content

Commit 8417b4b

Browse files
authored
BUG: Allow plain bools in ExtensionArray.equals (#34661)
1 parent 7e31210 commit 8417b4b

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

pandas/core/arrays/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def equals(self, other: "ExtensionArray") -> bool:
738738
# boolean array with NA -> fill with False
739739
equal_values = equal_values.fillna(False)
740740
equal_na = self.isna() & other.isna()
741-
return (equal_values | equal_na).all().item()
741+
return bool((equal_values | equal_na).all())
742742

743743
def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
744744
"""

pandas/tests/extension/arrow/arrays.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99
import copy
1010
import itertools
11+
import operator
1112
from typing import Type
1213

1314
import numpy as np
@@ -106,6 +107,27 @@ def astype(self, dtype, copy=True):
106107
def dtype(self):
107108
return self._dtype
108109

110+
def _boolean_op(self, other, op):
111+
if not isinstance(other, type(self)):
112+
raise NotImplementedError()
113+
114+
result = op(np.array(self._data), np.array(other._data))
115+
return ArrowBoolArray(
116+
pa.chunked_array([pa.array(result, mask=pd.isna(self._data.to_pandas()))])
117+
)
118+
119+
def __eq__(self, other):
120+
if not isinstance(other, type(self)):
121+
return False
122+
123+
return self._boolean_op(other, operator.eq)
124+
125+
def __and__(self, other):
126+
return self._boolean_op(other, operator.and_)
127+
128+
def __or__(self, other):
129+
return self._boolean_op(other, operator.or_)
130+
109131
@property
110132
def nbytes(self):
111133
return sum(
@@ -153,10 +175,12 @@ def _reduce(self, method, skipna=True, **kwargs):
153175
return op(**kwargs)
154176

155177
def any(self, axis=0, out=None):
156-
return self._data.to_pandas().any()
178+
# Explicitly return a plain bool to reproduce GH-34660
179+
return bool(self._data.to_pandas().any())
157180

158181
def all(self, axis=0, out=None):
159-
return self._data.to_pandas().all()
182+
# Explicitly return a plain bool to reproduce GH-34660
183+
return bool(self._data.to_pandas().all())
160184

161185

162186
class ArrowBoolArray(ArrowExtensionArray):

pandas/tests/extension/arrow/test_bool.py

+5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ def data_missing():
2929
return ArrowBoolArray.from_scalars([None, True])
3030

3131

32+
def test_basic_equals(data):
33+
# https://github.com/pandas-dev/pandas/issues/34660
34+
assert pd.Series(data).equals(pd.Series(data))
35+
36+
3237
class BaseArrowTests:
3338
pass
3439

0 commit comments

Comments
 (0)