Skip to content

Commit 0b8d8bb

Browse files
authored
ENH/TST: Add TestBaseArithmeticOps tests for ArrowExtensionArray #47601 (#47645)
1 parent 4f54bf6 commit 0b8d8bb

File tree

7 files changed

+448
-9
lines changed

7 files changed

+448
-9
lines changed

pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
pa_version_under5p0,
2828
pa_version_under6p0,
2929
pa_version_under7p0,
30+
pa_version_under8p0,
3031
)
3132

3233
if TYPE_CHECKING:
@@ -158,4 +159,5 @@ def get_lzma_file() -> type[lzma.LZMAFile]:
158159
"pa_version_under5p0",
159160
"pa_version_under6p0",
160161
"pa_version_under7p0",
162+
"pa_version_under8p0",
161163
]

pandas/core/arrays/arrow/array.py

+93
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,76 @@
5757
"ge": pc.greater_equal,
5858
}
5959

60+
ARROW_LOGICAL_FUNCS = {
61+
"and": NotImplemented if pa_version_under2p0 else pc.and_kleene,
62+
"rand": NotImplemented
63+
if pa_version_under2p0
64+
else lambda x, y: pc.and_kleene(y, x),
65+
"or": NotImplemented if pa_version_under2p0 else pc.or_kleene,
66+
"ror": NotImplemented
67+
if pa_version_under2p0
68+
else lambda x, y: pc.or_kleene(y, x),
69+
"xor": NotImplemented if pa_version_under2p0 else pc.xor,
70+
"rxor": NotImplemented if pa_version_under2p0 else lambda x, y: pc.xor(y, x),
71+
}
72+
73+
def cast_for_truediv(
74+
arrow_array: pa.ChunkedArray, pa_object: pa.Array | pa.Scalar
75+
) -> pa.ChunkedArray:
76+
# Ensure int / int -> float mirroring Python/Numpy behavior
77+
# as pc.divide_checked(int, int) -> int
78+
if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(
79+
pa_object.type
80+
):
81+
return arrow_array.cast(pa.float64())
82+
return arrow_array
83+
84+
def floordiv_compat(
85+
left: pa.ChunkedArray | pa.Array | pa.Scalar,
86+
right: pa.ChunkedArray | pa.Array | pa.Scalar,
87+
) -> pa.ChunkedArray:
88+
# Ensure int // int -> int mirroring Python/Numpy behavior
89+
# as pc.floor(pc.divide_checked(int, int)) -> float
90+
result = pc.floor(pc.divide_checked(left, right))
91+
if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
92+
result = result.cast(left.type)
93+
return result
94+
95+
ARROW_ARITHMETIC_FUNCS = {
96+
"add": NotImplemented if pa_version_under2p0 else pc.add_checked,
97+
"radd": NotImplemented
98+
if pa_version_under2p0
99+
else lambda x, y: pc.add_checked(y, x),
100+
"sub": NotImplemented if pa_version_under2p0 else pc.subtract_checked,
101+
"rsub": NotImplemented
102+
if pa_version_under2p0
103+
else lambda x, y: pc.subtract_checked(y, x),
104+
"mul": NotImplemented if pa_version_under2p0 else pc.multiply_checked,
105+
"rmul": NotImplemented
106+
if pa_version_under2p0
107+
else lambda x, y: pc.multiply_checked(y, x),
108+
"truediv": NotImplemented
109+
if pa_version_under2p0
110+
else lambda x, y: pc.divide_checked(cast_for_truediv(x, y), y),
111+
"rtruediv": NotImplemented
112+
if pa_version_under2p0
113+
else lambda x, y: pc.divide_checked(y, cast_for_truediv(x, y)),
114+
"floordiv": NotImplemented
115+
if pa_version_under2p0
116+
else lambda x, y: floordiv_compat(x, y),
117+
"rfloordiv": NotImplemented
118+
if pa_version_under2p0
119+
else lambda x, y: floordiv_compat(y, x),
120+
"mod": NotImplemented,
121+
"rmod": NotImplemented,
122+
"divmod": NotImplemented,
123+
"rdivmod": NotImplemented,
124+
"pow": NotImplemented if pa_version_under2p0 else pc.power_checked,
125+
"rpow": NotImplemented
126+
if pa_version_under2p0
127+
else lambda x, y: pc.power_checked(y, x),
128+
}
129+
60130
if TYPE_CHECKING:
61131
from pandas import Series
62132

@@ -74,6 +144,7 @@ def to_pyarrow_type(
74144
elif isinstance(dtype, pa.DataType):
75145
pa_dtype = dtype
76146
elif dtype:
147+
# Accepts python types too
77148
pa_dtype = pa.from_numpy_dtype(dtype)
78149
else:
79150
pa_dtype = None
@@ -263,6 +334,28 @@ def _cmp_method(self, other, op):
263334
result = result.to_numpy()
264335
return BooleanArray._from_sequence(result)
265336

337+
def _evaluate_op_method(self, other, op, arrow_funcs):
338+
pc_func = arrow_funcs[op.__name__]
339+
if pc_func is NotImplemented:
340+
raise NotImplementedError(f"{op.__name__} not implemented.")
341+
if isinstance(other, ArrowExtensionArray):
342+
result = pc_func(self._data, other._data)
343+
elif isinstance(other, (np.ndarray, list)):
344+
result = pc_func(self._data, pa.array(other, from_pandas=True))
345+
elif is_scalar(other):
346+
result = pc_func(self._data, pa.scalar(other))
347+
else:
348+
raise NotImplementedError(
349+
f"{op.__name__} not implemented for {type(other)}"
350+
)
351+
return type(self)(result)
352+
353+
def _logical_method(self, other, op):
354+
return self._evaluate_op_method(other, op, ARROW_LOGICAL_FUNCS)
355+
356+
def _arith_method(self, other, op):
357+
return self._evaluate_op_method(other, op, ARROW_ARITHMETIC_FUNCS)
358+
266359
def equals(self, other) -> bool:
267360
if not isinstance(other, ArrowExtensionArray):
268361
return False

pandas/core/strings/object_array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _str_get_dummies(self, sep="|"):
360360
arr = Series(self).fillna("")
361361
try:
362362
arr = sep + arr + sep
363-
except TypeError:
363+
except (TypeError, NotImplementedError):
364364
arr = sep + arr.astype(str) + sep
365365

366366
tags: set[str] = set()

pandas/tests/arrays/string_/test_string.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_add(dtype, request):
101101
"unsupported operand type(s) for +: 'ArrowStringArray' and "
102102
"'ArrowStringArray'"
103103
)
104-
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
104+
mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason)
105105
request.node.add_marker(mark)
106106

107107
a = pd.Series(["a", "b", "c", None, None], dtype=dtype)
@@ -142,7 +142,7 @@ def test_add_2d(dtype, request):
142142
def test_add_sequence(dtype, request):
143143
if dtype.storage == "pyarrow":
144144
reason = "unsupported operand type(s) for +: 'ArrowStringArray' and 'list'"
145-
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
145+
mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason)
146146
request.node.add_marker(mark)
147147

148148
a = pd.array(["a", "b", None, None], dtype=dtype)
@@ -160,7 +160,7 @@ def test_add_sequence(dtype, request):
160160
def test_mul(dtype, request):
161161
if dtype.storage == "pyarrow":
162162
reason = "unsupported operand type(s) for *: 'ArrowStringArray' and 'int'"
163-
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
163+
mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason)
164164
request.node.add_marker(mark)
165165

166166
a = pd.array(["a", "b", None], dtype=dtype)

pandas/tests/extension/base/ops.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ class BaseArithmeticOpsTests(BaseOpsUtil):
6767
* divmod_exc = TypeError
6868
"""
6969

70-
series_scalar_exc: type[TypeError] | None = TypeError
71-
frame_scalar_exc: type[TypeError] | None = TypeError
72-
series_array_exc: type[TypeError] | None = TypeError
73-
divmod_exc: type[TypeError] | None = TypeError
70+
series_scalar_exc: type[Exception] | None = TypeError
71+
frame_scalar_exc: type[Exception] | None = TypeError
72+
series_array_exc: type[Exception] | None = TypeError
73+
divmod_exc: type[Exception] | None = TypeError
7474

7575
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
7676
# series & scalar

0 commit comments

Comments
 (0)