Skip to content

Commit 69b08d3

Browse files
authored
REF: share NumericArray._arith_method with BooleanArray (#45849)
1 parent 1699197 commit 69b08d3

File tree

5 files changed

+93
-108
lines changed

5 files changed

+93
-108
lines changed

pandas/core/arrays/boolean.py

-37
Original file line numberDiff line numberDiff line change
@@ -382,42 +382,5 @@ def _logical_method(self, other, op):
382382
# expected "ndarray"
383383
return BooleanArray(result, mask) # type: ignore[arg-type]
384384

385-
def _arith_method(self, other, op):
386-
mask = None
387-
op_name = op.__name__
388-
389-
if isinstance(other, BaseMaskedArray):
390-
other, mask = other._data, other._mask
391-
392-
elif is_list_like(other):
393-
other = np.asarray(other)
394-
if other.ndim > 1:
395-
raise NotImplementedError("can only perform ops with 1-d structures")
396-
if len(self) != len(other):
397-
raise ValueError("Lengths must match")
398-
399-
mask = self._propagate_mask(mask, other)
400-
401-
if other is libmissing.NA:
402-
# if other is NA, the result will be all NA and we can't run the
403-
# actual op, so we need to choose the resulting dtype manually
404-
if op_name in {"floordiv", "rfloordiv", "mod", "rmod", "pow", "rpow"}:
405-
dtype = "int8"
406-
elif op_name in {"truediv", "rtruediv"}:
407-
dtype = "float64"
408-
else:
409-
dtype = "bool"
410-
result = np.zeros(len(self._data), dtype=dtype)
411-
else:
412-
if op_name in {"pow", "rpow"} and isinstance(other, np.bool_):
413-
# Avoid DeprecationWarning: In future, it will be an error
414-
# for 'np.bool_' scalars to be interpreted as an index
415-
other = bool(other)
416-
417-
with np.errstate(all="ignore"):
418-
result = op(self._data, other)
419-
420-
return self._maybe_mask_result(result, mask, other, op_name)
421-
422385
def __abs__(self):
423386
return self.copy()

pandas/core/arrays/masked.py

+80
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from pandas.core.array_algos.quantile import quantile_with_mask
7777
from pandas.core.arraylike import OpsMixin
7878
from pandas.core.arrays import ExtensionArray
79+
from pandas.core.construction import ensure_wrapped_if_datetimelike
7980
from pandas.core.indexers import check_array_indexer
8081
from pandas.core.ops import invalid_comparison
8182

@@ -593,6 +594,85 @@ def _propagate_mask(
593594
mask = self._mask | mask
594595
return mask
595596

597+
def _arith_method(self, other, op):
598+
op_name = op.__name__
599+
omask = None
600+
601+
if isinstance(other, BaseMaskedArray):
602+
other, omask = other._data, other._mask
603+
604+
elif is_list_like(other):
605+
if not isinstance(other, ExtensionArray):
606+
other = np.asarray(other)
607+
if other.ndim > 1:
608+
raise NotImplementedError("can only perform ops with 1-d structures")
609+
610+
# We wrap the non-masked arithmetic logic used for numpy dtypes
611+
# in Series/Index arithmetic ops.
612+
other = ops.maybe_prepare_scalar_for_op(other, (len(self),))
613+
pd_op = ops.get_array_op(op)
614+
other = ensure_wrapped_if_datetimelike(other)
615+
616+
if op_name in {"pow", "rpow"} and isinstance(other, np.bool_):
617+
# Avoid DeprecationWarning: In future, it will be an error
618+
# for 'np.bool_' scalars to be interpreted as an index
619+
# e.g. test_array_scalar_like_equivalence
620+
other = bool(other)
621+
622+
mask = self._propagate_mask(omask, other)
623+
624+
if other is libmissing.NA:
625+
result = np.ones_like(self._data)
626+
if self.dtype.kind == "b":
627+
if op_name in {"floordiv", "rfloordiv", "mod", "rmod", "pow", "rpow"}:
628+
dtype = "int8"
629+
elif op_name in {"truediv", "rtruediv"}:
630+
dtype = "float64"
631+
else:
632+
dtype = "bool"
633+
result = result.astype(dtype)
634+
elif "truediv" in op_name and self.dtype.kind != "f":
635+
# The actual data here doesn't matter since the mask
636+
# will be all-True, but since this is division, we want
637+
# to end up with floating dtype.
638+
result = result.astype(np.float64)
639+
else:
640+
# Make sure we do this before the "pow" mask checks
641+
# to get an expected exception message on shape mismatch.
642+
if self.dtype.kind in ["i", "u"] and op_name in ["floordiv", "mod"]:
643+
# TODO(GH#30188) ATM we don't match the behavior of non-masked
644+
# types with respect to floordiv-by-zero
645+
pd_op = op
646+
647+
elif self.dtype.kind == "b" and (
648+
"div" in op_name or "pow" in op_name or "mod" in op_name
649+
):
650+
# TODO(GH#41165): should these be disallowed?
651+
pd_op = op
652+
653+
with np.errstate(all="ignore"):
654+
result = pd_op(self._data, other)
655+
656+
if op_name == "pow":
657+
# 1 ** x is 1.
658+
mask = np.where((self._data == 1) & ~self._mask, False, mask)
659+
# x ** 0 is 1.
660+
if omask is not None:
661+
mask = np.where((other == 0) & ~omask, False, mask)
662+
elif other is not libmissing.NA:
663+
mask = np.where(other == 0, False, mask)
664+
665+
elif op_name == "rpow":
666+
# 1 ** x is 1.
667+
if omask is not None:
668+
mask = np.where((other == 1) & ~omask, False, mask)
669+
elif other is not libmissing.NA:
670+
mask = np.where(other == 1, False, mask)
671+
# x ** 0 is 1.
672+
mask = np.where((self._data == 0) & ~self._mask, False, mask)
673+
674+
return self._maybe_mask_result(result, mask, other, op_name)
675+
596676
def _cmp_method(self, other, op) -> BooleanArray:
597677
from pandas.core.arrays import BooleanArray
598678

pandas/core/arrays/numeric.py

-63
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,15 @@
2323
is_bool_dtype,
2424
is_float_dtype,
2525
is_integer_dtype,
26-
is_list_like,
2726
is_object_dtype,
2827
is_string_dtype,
2928
pandas_dtype,
3029
)
3130

32-
from pandas.core import ops
33-
from pandas.core.arrays.base import ExtensionArray
3431
from pandas.core.arrays.masked import (
3532
BaseMaskedArray,
3633
BaseMaskedDtype,
3734
)
38-
from pandas.core.construction import ensure_wrapped_if_datetimelike
3935

4036
if TYPE_CHECKING:
4137
import pyarrow
@@ -214,65 +210,6 @@ def _from_sequence_of_strings(
214210
scalars = to_numeric(strings, errors="raise")
215211
return cls._from_sequence(scalars, dtype=dtype, copy=copy)
216212

217-
def _arith_method(self, other, op):
218-
op_name = op.__name__
219-
omask = None
220-
221-
if isinstance(other, BaseMaskedArray):
222-
other, omask = other._data, other._mask
223-
224-
elif is_list_like(other):
225-
if not isinstance(other, ExtensionArray):
226-
other = np.asarray(other)
227-
if other.ndim > 1:
228-
raise NotImplementedError("can only perform ops with 1-d structures")
229-
230-
# We wrap the non-masked arithmetic logic used for numpy dtypes
231-
# in Series/Index arithmetic ops.
232-
other = ops.maybe_prepare_scalar_for_op(other, (len(self),))
233-
pd_op = ops.get_array_op(op)
234-
other = ensure_wrapped_if_datetimelike(other)
235-
236-
mask = self._propagate_mask(omask, other)
237-
238-
if other is libmissing.NA:
239-
result = np.ones_like(self._data)
240-
if "truediv" in op_name and self.dtype.kind != "f":
241-
# The actual data here doesn't matter since the mask
242-
# will be all-True, but since this is division, we want
243-
# to end up with floating dtype.
244-
result = result.astype(np.float64)
245-
else:
246-
# Make sure we do this before the "pow" mask checks
247-
# to get an expected exception message on shape mismatch.
248-
if self.dtype.kind in ["i", "u"] and op_name in ["floordiv", "mod"]:
249-
# ATM we don't match the behavior of non-masked types with
250-
# respect to floordiv-by-zero
251-
pd_op = op
252-
253-
with np.errstate(all="ignore"):
254-
result = pd_op(self._data, other)
255-
256-
if op_name == "pow":
257-
# 1 ** x is 1.
258-
mask = np.where((self._data == 1) & ~self._mask, False, mask)
259-
# x ** 0 is 1.
260-
if omask is not None:
261-
mask = np.where((other == 0) & ~omask, False, mask)
262-
elif other is not libmissing.NA:
263-
mask = np.where(other == 0, False, mask)
264-
265-
elif op_name == "rpow":
266-
# 1 ** x is 1.
267-
if omask is not None:
268-
mask = np.where((other == 1) & ~omask, False, mask)
269-
elif other is not libmissing.NA:
270-
mask = np.where(other == 1, False, mask)
271-
# x ** 0 is 1.
272-
mask = np.where((self._data == 0) & ~self._mask, False, mask)
273-
274-
return self._maybe_mask_result(result, mask, other, op_name)
275-
276213
_HANDLED_TYPES = (np.ndarray, numbers.Number)
277214

278215
def __neg__(self):

pandas/tests/arrays/boolean/test_arithmetic.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def test_div(left_array, right_array):
7171
[
7272
"floordiv",
7373
"mod",
74-
pytest.param(
75-
"pow", marks=pytest.mark.xfail(reason="TODO follow int8 behaviour? GH34686")
76-
),
74+
"pow",
7775
],
7876
)
7977
def test_op_int8(left_array, right_array, opname):

pandas/tests/arrays/masked/test_arithmetic.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -153,22 +153,29 @@ def test_error_len_mismatch(data, all_arithmetic_operators):
153153

154154
other = [scalar] * (len(data) - 1)
155155

156+
err = ValueError
156157
msg = "|".join(
157158
[
158159
r"operands could not be broadcast together with shapes \(3,\) \(4,\)",
159160
r"operands could not be broadcast together with shapes \(4,\) \(3,\)",
160161
]
161162
)
162-
163-
if data.dtype.kind == "b":
164-
msg = "Lengths must match"
163+
if data.dtype.kind == "b" and all_arithmetic_operators.strip("_") in [
164+
"sub",
165+
"rsub",
166+
]:
167+
err = TypeError
168+
msg = (
169+
r"numpy boolean subtract, the `\-` operator, is not supported, use "
170+
r"the bitwise_xor, the `\^` operator, or the logical_xor function instead"
171+
)
165172

166173
for other in [other, np.array(other)]:
167-
with pytest.raises(ValueError, match=msg):
174+
with pytest.raises(err, match=msg):
168175
op(data, other)
169176

170177
s = pd.Series(data)
171-
with pytest.raises(ValueError, match=msg):
178+
with pytest.raises(err, match=msg):
172179
op(s, other)
173180

174181

0 commit comments

Comments
 (0)