Skip to content

Commit 0158382

Browse files
BUG/TST: run and fix all arithmetic tests with+without numexpr (#40463)
1 parent a82f8e2 commit 0158382

File tree

4 files changed

+65
-12
lines changed

4 files changed

+65
-12
lines changed

pandas/core/ops/array_ops.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22
Functions for arithmetic and comparison operations on NumPy arrays and
33
ExtensionArrays.
44
"""
5-
from datetime import timedelta
5+
import datetime
66
from functools import partial
77
import operator
88
from typing import Any
99

1010
import numpy as np
1111

1212
from pandas._libs import (
13+
NaT,
1314
Timedelta,
1415
Timestamp,
1516
lib,
1617
ops as libops,
1718
)
19+
from pandas._libs.tslibs import BaseOffset
1820
from pandas._typing import (
1921
ArrayLike,
2022
Shape,
@@ -154,8 +156,14 @@ def _na_arithmetic_op(left, right, op, is_cmp: bool = False):
154156
------
155157
TypeError : invalid operation
156158
"""
159+
if isinstance(right, str):
160+
# can never use numexpr
161+
func = op
162+
else:
163+
func = partial(expressions.evaluate, op)
164+
157165
try:
158-
result = expressions.evaluate(op, left, right)
166+
result = func(left, right)
159167
except TypeError:
160168
if is_object_dtype(left) or is_object_dtype(right) and not is_cmp:
161169
# For object dtype, fallback to a masked operation (only operating
@@ -201,8 +209,13 @@ def arithmetic_op(left: ArrayLike, right: Any, op):
201209
# casts integer dtypes to timedelta64 when operating with timedelta64 - GH#22390)
202210
right = _maybe_upcast_for_op(right, left.shape)
203211

204-
if should_extension_dispatch(left, right) or isinstance(right, Timedelta):
205-
# Timedelta is included because numexpr will fail on it, see GH#31457
212+
if (
213+
should_extension_dispatch(left, right)
214+
or isinstance(right, (Timedelta, BaseOffset, Timestamp))
215+
or right is NaT
216+
):
217+
# Timedelta/Timestamp and other custom scalars are included in the check
218+
# because numexpr will fail on it, see GH#31457
206219
res_values = op(left, right)
207220
else:
208221
res_values = _na_arithmetic_op(left, right, op)
@@ -246,7 +259,10 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
246259
"Lengths must match to compare", lvalues.shape, rvalues.shape
247260
)
248261

249-
if should_extension_dispatch(lvalues, rvalues):
262+
if should_extension_dispatch(lvalues, rvalues) or (
263+
(isinstance(rvalues, (Timedelta, BaseOffset, Timestamp)) or right is NaT)
264+
and not is_object_dtype(lvalues.dtype)
265+
):
250266
# Call the method on lvalues
251267
res_values = op(lvalues, rvalues)
252268

@@ -261,7 +277,7 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
261277
# GH#36377 going through the numexpr path would incorrectly raise
262278
return invalid_comparison(lvalues, rvalues, op)
263279

264-
elif is_object_dtype(lvalues.dtype):
280+
elif is_object_dtype(lvalues.dtype) or isinstance(rvalues, str):
265281
res_values = comp_method_OBJECT_ARRAY(op, lvalues, rvalues)
266282

267283
else:
@@ -438,11 +454,14 @@ def _maybe_upcast_for_op(obj, shape: Shape):
438454
Be careful to call this *after* determining the `name` attribute to be
439455
attached to the result of the arithmetic operation.
440456
"""
441-
if type(obj) is timedelta:
457+
if type(obj) is datetime.timedelta:
442458
# GH#22390 cast up to Timedelta to rely on Timedelta
443459
# implementation; otherwise operation against numeric-dtype
444460
# raises TypeError
445461
return Timedelta(obj)
462+
elif type(obj) is datetime.datetime:
463+
# cast up to Timestamp to rely on Timestamp implementation, see Timedelta above
464+
return Timestamp(obj)
446465
elif isinstance(obj, np.datetime64):
447466
# GH#28080 numpy casts integer-dtype to datetime64 when doing
448467
# array[int] + datetime64, which we do not allow

pandas/tests/arithmetic/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@
99
UInt64Index,
1010
)
1111
import pandas._testing as tm
12+
from pandas.core.computation import expressions as expr
13+
14+
15+
@pytest.fixture(
16+
autouse=True, scope="module", params=[0, 1000000], ids=["numexpr", "python"]
17+
)
18+
def switch_numexpr_min_elements(request):
19+
_MIN_ELEMENTS = expr._MIN_ELEMENTS
20+
expr._MIN_ELEMENTS = request.param
21+
yield request.param
22+
expr._MIN_ELEMENTS = _MIN_ELEMENTS
23+
1224

1325
# ------------------------------------------------------------------
1426
# Helper Functions

pandas/tests/arithmetic/test_numeric.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
import pandas._testing as tm
2929
from pandas.core import ops
30+
from pandas.core.computation import expressions as expr
3031

3132

3233
@pytest.fixture(params=[Index, Series, tm.to_array])
@@ -391,7 +392,7 @@ def test_div_negative_zero(self, zero, numeric_idx, op):
391392
# ------------------------------------------------------------------
392393

393394
@pytest.mark.parametrize("dtype1", [np.int64, np.float64, np.uint64])
394-
def test_ser_div_ser(self, dtype1, any_real_dtype):
395+
def test_ser_div_ser(self, switch_numexpr_min_elements, dtype1, any_real_dtype):
395396
# no longer do integer div for any ops, but deal with the 0's
396397
dtype2 = any_real_dtype
397398

@@ -405,6 +406,11 @@ def test_ser_div_ser(self, dtype1, any_real_dtype):
405406
name=None,
406407
)
407408
expected.iloc[0:3] = np.inf
409+
if first.dtype == "int64" and second.dtype == "float32":
410+
# when using numexpr, the casting rules are slightly different
411+
# and int64/float32 combo results in float32 instead of float64
412+
if expr.USE_NUMEXPR and switch_numexpr_min_elements == 0:
413+
expected = expected.astype("float32")
408414

409415
result = first / second
410416
tm.assert_series_equal(result, expected)
@@ -890,7 +896,13 @@ def test_series_frame_radd_bug(self):
890896

891897
# really raise this time
892898
now = pd.Timestamp.now().to_pydatetime()
893-
msg = "unsupported operand type"
899+
msg = "|".join(
900+
[
901+
"unsupported operand type",
902+
# wrong error message, see https://github.com/numpy/numpy/issues/18832
903+
"Concatenation operation",
904+
]
905+
)
894906
with pytest.raises(TypeError, match=msg):
895907
now + ts
896908

pandas/tests/frame/test_arithmetic.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,19 @@ def test_timestamp_compare(self):
174174
with pytest.raises(TypeError, match=msg):
175175
right_f(pd.Timestamp("20010109"), df)
176176
# nats
177-
expected = left_f(df, pd.Timestamp("nat"))
178-
result = right_f(pd.Timestamp("nat"), df)
179-
tm.assert_frame_equal(result, expected)
177+
if left in ["eq", "ne"]:
178+
expected = left_f(df, pd.Timestamp("nat"))
179+
result = right_f(pd.Timestamp("nat"), df)
180+
tm.assert_frame_equal(result, expected)
181+
else:
182+
msg = (
183+
"'(<|>)=?' not supported between "
184+
"instances of 'numpy.ndarray' and 'NaTType'"
185+
)
186+
with pytest.raises(TypeError, match=msg):
187+
left_f(df, pd.Timestamp("nat"))
188+
with pytest.raises(TypeError, match=msg):
189+
right_f(pd.Timestamp("nat"), df)
180190

181191
def test_mixed_comparison(self):
182192
# GH#13128, GH#22163 != datetime64 vs non-dt64 should be False,

0 commit comments

Comments
 (0)