Skip to content

Commit 83016f3

Browse files
TST/REF: arithmetic tests for BooleanArray + consolidate with integer masked tests (#34623)
1 parent 624a1be commit 83016f3

File tree

6 files changed: +285 / -177 lines changed

pandas/_testing.py

+26
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from datetime import datetime
55
from functools import wraps
66
import gzip
7+
import operator
78
import os
89
from shutil import rmtree
910
import string
@@ -2758,3 +2759,28 @@ def get_cython_table_params(ndframe, func_names_and_expected):
27582759
if name == func_name
27592760
]
27602761
return results
2762+
2763+
2764+
def get_op_from_name(op_name: str) -> Callable:
2765+
"""
2766+
The operator function for a given op name.
2767+
2768+
Parameters
2769+
----------
2770+
op_name : string
2771+
The op name, in form of "add" or "__add__".
2772+
2773+
Returns
2774+
-------
2775+
function
2776+
A function performing the operation.
2777+
"""
2778+
short_opname = op_name.strip("_")
2779+
try:
2780+
op = getattr(operator, short_opname)
2781+
except AttributeError:
2782+
# Assume it is the reverse operator
2783+
rop = getattr(operator, short_opname[1:])
2784+
op = lambda x, y: rop(y, x)
2785+
2786+
return op

pandas/core/arrays/boolean.py

+13 -2
Original file line number | Diff line number | Diff line change
@@ -717,11 +717,22 @@ def boolean_arithmetic_method(self, other):
717717
# nans propagate
718718
if mask is None:
719719
mask = self._mask
720+
if other is libmissing.NA:
721+
mask |= True
720722
else:
721723
mask = self._mask | mask
722724

723-
with np.errstate(all="ignore"):
724-
result = op(self._data, other)
725+
if other is libmissing.NA:
726+
# if other is NA, the result will be all NA and we can't run the
727+
# actual op, so we need to choose the resulting dtype manually
728+
if op_name in {"floordiv", "rfloordiv", "mod", "rmod", "pow", "rpow"}:
729+
dtype = "int8"
730+
else:
731+
dtype = "bool"
732+
result = np.zeros(len(self._data), dtype=dtype)
733+
else:
734+
with np.errstate(all="ignore"):
735+
result = op(self._data, other)
725736

726737
# divmod returns a tuple
727738
if op_name == "divmod":
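[file path missing from this capture; presumably the BooleanArray arithmetic test module, e.g. pandas/tests/arrays/boolean/test_arithmetic.py (to be confirmed against the commit)]

+82 -23
Original file line number | Diff line number | Diff line change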
@@ -1,8 +1,10 @@
1+
import operator
2+
13
import numpy as np
24
import pytest
35

46
import pandas as pd
5-
from pandas.tests.extension.base import BaseOpsUtil
7+
import pandas._testing as tm
68

79

810
@pytest.fixture
@@ -13,30 +15,87 @@ def data():
1315
)
1416

1517

16-
class TestArithmeticOps(BaseOpsUtil):
17-
def test_error(self, data, all_arithmetic_operators):
18-
# invalid ops
18+
@pytest.fixture
19+
def left_array():
20+
return pd.array([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean")
1921

20-
op = all_arithmetic_operators
21-
s = pd.Series(data)
22-
ops = getattr(s, op)
23-
opa = getattr(data, op)
2422

25-
# invalid scalars
26-
with pytest.raises(TypeError):
27-
ops("foo")
28-
with pytest.raises(TypeError):
29-
ops(pd.Timestamp("20180101"))
23+
@pytest.fixture
24+
def right_array():
25+
return pd.array([True, False, None] * 3, dtype="boolean")
26+
3027

31-
# invalid array-likes
32-
if op not in ("__mul__", "__rmul__"):
33-
# TODO(extension) numpy's mul with object array sees booleans as numbers
34-
with pytest.raises(TypeError):
35-
ops(pd.Series("foo", index=s.index))
28+
# Basic test for the arithmetic array ops
29+
# -----------------------------------------------------------------------------
3630

37-
# 2d
38-
result = opa(pd.DataFrame({"A": s}))
39-
assert result is NotImplemented
4031

41-
with pytest.raises(NotImplementedError):
42-
opa(np.arange(len(s)).reshape(-1, len(s)))
32+
@pytest.mark.parametrize(
33+
"opname, exp",
34+
[
35+
("add", [True, True, None, True, False, None, None, None, None]),
36+
("mul", [True, False, None, False, False, None, None, None, None]),
37+
],
38+
ids=["add", "mul"],
39+
)
40+
def test_add_mul(left_array, right_array, opname, exp):
41+
op = getattr(operator, opname)
42+
result = op(left_array, right_array)
43+
expected = pd.array(exp, dtype="boolean")
44+
tm.assert_extension_array_equal(result, expected)
45+
46+
47+
def test_sub(left_array, right_array):
48+
with pytest.raises(TypeError):
49+
# numpy points to ^ operator or logical_xor function instead
50+
left_array - right_array
51+
52+
53+
def test_div(left_array, right_array):
54+
# for now division gives a float numpy array
55+
result = left_array / right_array
56+
expected = np.array(
57+
[1.0, np.inf, np.nan, 0.0, np.nan, np.nan, np.nan, np.nan, np.nan],
58+
dtype="float64",
59+
)
60+
tm.assert_numpy_array_equal(result, expected)
61+
62+
63+
@pytest.mark.parametrize(
64+
"opname",
65+
[
66+
"floordiv",
67+
"mod",
68+
pytest.param(
69+
"pow", marks=pytest.mark.xfail(reason="TODO follow int8 behaviour? GH34686")
70+
),
71+
],
72+
)
73+
def test_op_int8(left_array, right_array, opname):
74+
op = getattr(operator, opname)
75+
result = op(left_array, right_array)
76+
expected = op(left_array.astype("Int8"), right_array.astype("Int8"))
77+
tm.assert_extension_array_equal(result, expected)
78+
79+
80+
# Test generic characteristics / errors
81+
# -----------------------------------------------------------------------------
82+
83+
84+
def test_error_invalid_values(data, all_arithmetic_operators):
85+
# invalid ops
86+
87+
op = all_arithmetic_operators
88+
s = pd.Series(data)
89+
ops = getattr(s, op)
90+
91+
# invalid scalars
92+
with pytest.raises(TypeError):
93+
ops("foo")
94+
with pytest.raises(TypeError):
95+
ops(pd.Timestamp("20180101"))
96+
97+
# invalid array-likes
98+
if op not in ("__mul__", "__rmul__"):
99+
# TODO(extension) numpy's mul with object array sees booleans as numbers
100+
with pytest.raises(TypeError):
101+
ops(pd.Series("foo", index=s.index))

pandas/tests/arrays/integer/test_arithmetic.py

+4 -142
Original file line number | Diff line number | Diff line change
@@ -5,23 +5,9 @@
55

66
import pandas as pd
77
import pandas._testing as tm
8-
from pandas.core.arrays import ExtensionArray, integer_array
8+
from pandas.core.arrays import integer_array
99
import pandas.core.ops as ops
1010

11-
12-
# TODO need to use existing utility function or move this somewhere central
13-
def get_op_from_name(op_name):
14-
short_opname = op_name.strip("_")
15-
try:
16-
op = getattr(operator, short_opname)
17-
except AttributeError:
18-
# Assume it is the reverse operator
19-
rop = getattr(operator, short_opname[1:])
20-
op = lambda x, y: rop(y, x)
21-
22-
return op
23-
24-
2511
# Basic test for the arithmetic array ops
2612
# -----------------------------------------------------------------------------
2713

@@ -151,55 +137,6 @@ def test_rpow_one_to_na():
151137
tm.assert_numpy_array_equal(result, expected)
152138

153139

154-
# Test equivalence of scalars, numpy arrays with array ops
155-
# -----------------------------------------------------------------------------
156-
157-
158-
def test_array_scalar_like_equivalence(data, all_arithmetic_operators):
159-
op = get_op_from_name(all_arithmetic_operators)
160-
161-
scalar = 2
162-
scalar_array = pd.array([2] * len(data), dtype=data.dtype)
163-
164-
# TODO also add len-1 array (np.array([2], dtype=data.dtype.numpy_dtype))
165-
for scalar in [2, data.dtype.type(2)]:
166-
result = op(data, scalar)
167-
expected = op(data, scalar_array)
168-
if isinstance(expected, ExtensionArray):
169-
tm.assert_extension_array_equal(result, expected)
170-
else:
171-
# TODO div still gives float ndarray -> remove this once we have Float EA
172-
tm.assert_numpy_array_equal(result, expected)
173-
174-
175-
def test_array_NA(data, all_arithmetic_operators):
176-
if "truediv" in all_arithmetic_operators:
177-
pytest.skip("division with pd.NA raises")
178-
op = get_op_from_name(all_arithmetic_operators)
179-
180-
scalar = pd.NA
181-
scalar_array = pd.array([pd.NA] * len(data), dtype=data.dtype)
182-
183-
result = op(data, scalar)
184-
expected = op(data, scalar_array)
185-
tm.assert_extension_array_equal(result, expected)
186-
187-
188-
def test_numpy_array_equivalence(data, all_arithmetic_operators):
189-
op = get_op_from_name(all_arithmetic_operators)
190-
191-
numpy_array = np.array([2] * len(data), dtype=data.dtype.numpy_dtype)
192-
pd_array = pd.array(numpy_array, dtype=data.dtype)
193-
194-
result = op(data, numpy_array)
195-
expected = op(data, pd_array)
196-
if isinstance(expected, ExtensionArray):
197-
tm.assert_extension_array_equal(result, expected)
198-
else:
199-
# TODO div still gives float ndarray -> remove this once we have Float EA
200-
tm.assert_numpy_array_equal(result, expected)
201-
202-
203140
@pytest.mark.parametrize("other", [0, 0.5])
204141
def test_numpy_zero_dim_ndarray(other):
205142
arr = integer_array([1, None, 2])
@@ -208,53 +145,7 @@ def test_numpy_zero_dim_ndarray(other):
208145
tm.assert_equal(result, expected)
209146

210147

211-
# Test equivalence with Series and DataFrame ops
212-
# -----------------------------------------------------------------------------
213-
214-
215-
def test_frame(data, all_arithmetic_operators):
216-
op = get_op_from_name(all_arithmetic_operators)
217-
218-
# DataFrame with scalar
219-
df = pd.DataFrame({"A": data})
220-
scalar = 2
221-
222-
result = op(df, scalar)
223-
expected = pd.DataFrame({"A": op(data, scalar)})
224-
tm.assert_frame_equal(result, expected)
225-
226-
227-
def test_series(data, all_arithmetic_operators):
228-
op = get_op_from_name(all_arithmetic_operators)
229-
230-
s = pd.Series(data)
231-
232-
# Series with scalar
233-
scalar = 2
234-
result = op(s, scalar)
235-
expected = pd.Series(op(data, scalar))
236-
tm.assert_series_equal(result, expected)
237-
238-
# Series with np.ndarray
239-
other = np.ones(len(data), dtype=data.dtype.type)
240-
result = op(s, other)
241-
expected = pd.Series(op(data, other))
242-
tm.assert_series_equal(result, expected)
243-
244-
# Series with pd.array
245-
other = pd.array(np.ones(len(data)), dtype=data.dtype)
246-
result = op(s, other)
247-
expected = pd.Series(op(data, other))
248-
tm.assert_series_equal(result, expected)
249-
250-
# Series with Series
251-
other = pd.Series(np.ones(len(data)), dtype=data.dtype)
252-
result = op(s, other)
253-
expected = pd.Series(op(data, other.array))
254-
tm.assert_series_equal(result, expected)
255-
256-
257-
# Test generic charachteristics / errors
148+
# Test generic characteristics / errors
258149
# -----------------------------------------------------------------------------
259150

260151

@@ -291,35 +182,6 @@ def test_error_invalid_values(data, all_arithmetic_operators):
291182
ops(pd.Series(pd.date_range("20180101", periods=len(s))))
292183

293184

294-
def test_error_invalid_object(data, all_arithmetic_operators):
295-
296-
op = all_arithmetic_operators
297-
opa = getattr(data, op)
298-
299-
# 2d -> return NotImplemented
300-
result = opa(pd.DataFrame({"A": data}))
301-
assert result is NotImplemented
302-
303-
msg = r"can only perform ops with 1-d structures"
304-
with pytest.raises(NotImplementedError, match=msg):
305-
opa(np.arange(len(data)).reshape(-1, len(data)))
306-
307-
308-
def test_error_len_mismatch(all_arithmetic_operators):
309-
# operating with a list-like with non-matching length raises
310-
op = get_op_from_name(all_arithmetic_operators)
311-
312-
data = pd.array([1, 2, 3], dtype="Int64")
313-
314-
for other in [[1, 2], np.array([1.0, 2.0])]:
315-
with pytest.raises(ValueError, match="Lengths must match"):
316-
op(data, other)
317-
318-
s = pd.Series(data)
319-
with pytest.raises(ValueError, match="Lengths must match"):
320-
op(s, other)
321-
322-
323185
# Various
324186
# -----------------------------------------------------------------------------
325187

@@ -328,7 +190,7 @@ def test_error_len_mismatch(all_arithmetic_operators):
328190

329191

330192
def test_arith_coerce_scalar(data, all_arithmetic_operators):
331-
op = get_op_from_name(all_arithmetic_operators)
193+
op = tm.get_op_from_name(all_arithmetic_operators)
332194
s = pd.Series(data)
333195
other = 0.01
334196

@@ -345,7 +207,7 @@ def test_arith_coerce_scalar(data, all_arithmetic_operators):
345207
def test_arithmetic_conversion(all_arithmetic_operators, other):
346208
# if we have a float operand we should have a float result
347209
# if that is equal to an integer
348-
op = get_op_from_name(all_arithmetic_operators)
210+
op = tm.get_op_from_name(all_arithmetic_operators)
349211

350212
s = pd.Series([1, 2, 3], dtype="Int64")
351213
result = op(s, other)

0 commit comments

Comments
 (0)