Skip to content

Commit 5a95983

Browse files
authored
Backport PR #52633 on branch 2.0.x (BUG: Logical and comparison ops with ArrowDtype & masked) (#52767)
* Backport PR #52633: BUG: Logical and comparison ops with ArrowDtype & masked * Make runtime import
1 parent c10de3a commit 5a95983

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Bug fixes
3434
- Bug in :meth:`DataFrame.max` and related casting different :class:`Timestamp` resolutions always to nanoseconds (:issue:`52524`)
3535
- Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`)
3636
- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`)
37+
- Bug in logical and comparison operations between :class:`ArrowDtype` and numpy masked types (e.g. ``"boolean"``) (:issue:`52625`)
3738
- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` one one and a NumPy dtype on the other side (:issue:`52406`)
3839
- Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`)
3940

pandas/core/arrays/arrow/array.py

+10
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,16 @@ def __setstate__(self, state) -> None:
434434
self.__dict__.update(state)
435435

436436
def _cmp_method(self, other, op):
437+
from pandas.core.arrays.masked import BaseMaskedArray
438+
437439
pc_func = ARROW_CMP_FUNCS[op.__name__]
438440
if isinstance(other, ArrowExtensionArray):
439441
result = pc_func(self._data, other._data)
440442
elif isinstance(other, (np.ndarray, list)):
441443
result = pc_func(self._data, other)
444+
elif isinstance(other, BaseMaskedArray):
445+
# GH 52625
446+
result = pc_func(self._data, other.__arrow_array__())
442447
elif is_scalar(other):
443448
try:
444449
result = pc_func(self._data, pa.scalar(other))
@@ -456,6 +461,8 @@ def _cmp_method(self, other, op):
456461
return ArrowExtensionArray(result)
457462

458463
def _evaluate_op_method(self, other, op, arrow_funcs):
464+
from pandas.core.arrays.masked import BaseMaskedArray
465+
459466
pa_type = self._data.type
460467
if (pa.types.is_string(pa_type) or pa.types.is_binary(pa_type)) and op in [
461468
operator.add,
@@ -486,6 +493,9 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
486493
result = pc_func(self._data, other._data)
487494
elif isinstance(other, (np.ndarray, list)):
488495
result = pc_func(self._data, pa.array(other, from_pandas=True))
496+
elif isinstance(other, BaseMaskedArray):
497+
# GH 52625
498+
result = pc_func(self._data, other.__arrow_array__())
489499
elif is_scalar(other):
490500
if isna(other) and op.__name__ in ARROW_LOGICAL_FUNCS:
491501
# pyarrow kleene ops require null to be typed

pandas/tests/extension/test_arrow.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
BytesIO,
2222
StringIO,
2323
)
24+
import operator
2425
import pickle
2526
import re
2627

@@ -1218,7 +1219,7 @@ def test_add_series_with_extension_array(self, data, request):
12181219

12191220

12201221
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
1221-
def test_compare_array(self, data, comparison_op, na_value, request):
1222+
def test_compare_array(self, data, comparison_op, na_value):
12221223
ser = pd.Series(data)
12231224
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
12241225
# since ser.iloc[0] is a python scalar
@@ -1257,6 +1258,20 @@ def test_invalid_other_comp(self, data, comparison_op):
12571258
):
12581259
comparison_op(data, object())
12591260

1261+
@pytest.mark.parametrize("masked_dtype", ["boolean", "Int64", "Float64"])
1262+
def test_comp_masked_numpy(self, masked_dtype, comparison_op):
1263+
# GH 52625
1264+
data = [1, 0, None]
1265+
ser_masked = pd.Series(data, dtype=masked_dtype)
1266+
ser_pa = pd.Series(data, dtype=f"{masked_dtype.lower()}[pyarrow]")
1267+
result = comparison_op(ser_pa, ser_masked)
1268+
if comparison_op in [operator.lt, operator.gt, operator.ne]:
1269+
exp = [False, False, None]
1270+
else:
1271+
exp = [True, True, None]
1272+
expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
1273+
tm.assert_series_equal(result, expected)
1274+
12601275

12611276
class TestLogicalOps:
12621277
"""Various Series and DataFrame logical ops methods."""
@@ -1401,6 +1416,23 @@ def test_kleene_xor_scalar(self, other, expected):
14011416
a, pd.Series([True, False, None], dtype="boolean[pyarrow]")
14021417
)
14031418

1419+
@pytest.mark.parametrize(
1420+
"op, exp",
1421+
[
1422+
["__and__", True],
1423+
["__or__", True],
1424+
["__xor__", False],
1425+
],
1426+
)
1427+
def test_logical_masked_numpy(self, op, exp):
1428+
# GH 52625
1429+
data = [True, False, None]
1430+
ser_masked = pd.Series(data, dtype="boolean")
1431+
ser_pa = pd.Series(data, dtype="boolean[pyarrow]")
1432+
result = getattr(ser_pa, op)(ser_masked)
1433+
expected = pd.Series([exp, False, None], dtype=ArrowDtype(pa.bool_()))
1434+
tm.assert_series_equal(result, expected)
1435+
14041436

14051437
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
14061438
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):

0 commit comments

Comments
 (0)