Skip to content

Commit e6bd49f

Browse files
authored
use numexpr for Series comparisons (#32047)
1 parent e39cd30 commit e6bd49f

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

pandas/core/ops/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ def _comp_method_SERIES(cls, op, special):
510510
Wrapper function for Series arithmetic operations, to avoid
511511
code duplication.
512512
"""
513+
str_rep = _get_opstr(op)
513514
op_name = _get_op_name(op, special)
514515

515516
@unpack_zerodim_and_defer(op_name)
@@ -523,7 +524,7 @@ def wrapper(self, other):
523524
lvalues = extract_array(self, extract_numpy=True)
524525
rvalues = extract_array(other, extract_numpy=True)
525526

526-
res_values = comparison_op(lvalues, rvalues, op)
527+
res_values = comparison_op(lvalues, rvalues, op, str_rep)
527528

528529
return _construct_result(self, res_values, index=self.index, name=res_name)
529530

pandas/core/ops/array_ops.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def na_op(x, y):
126126
return na_op
127127

128128

129-
def na_arithmetic_op(left, right, op, str_rep: str):
129+
def na_arithmetic_op(left, right, op, str_rep: Optional[str], is_cmp: bool = False):
130130
"""
131131
Return the result of evaluating op on the passed in values.
132132
@@ -137,6 +137,8 @@ def na_arithmetic_op(left, right, op, str_rep: str):
137137
left : np.ndarray
138138
right : np.ndarray or scalar
139139
str_rep : str or None
140+
is_cmp : bool, default False
141+
If this a comparison operation.
140142
141143
Returns
142144
-------
@@ -151,8 +153,18 @@ def na_arithmetic_op(left, right, op, str_rep: str):
151153
try:
152154
result = expressions.evaluate(op, str_rep, left, right)
153155
except TypeError:
156+
if is_cmp:
157+
# numexpr failed on comparison op, e.g. ndarray[float] > datetime
158+
# In this case we do not fall back to the masked op, as that
159+
# will handle complex numbers incorrectly, see GH#32047
160+
raise
154161
result = masked_arith_op(left, right, op)
155162

163+
if is_cmp and (is_scalar(result) or result is NotImplemented):
164+
# numpy returned a scalar instead of operating element-wise
165+
# e.g. numeric array vs str
166+
return invalid_comparison(left, right, op)
167+
156168
return missing.dispatch_fill_zeros(op, left, right, result)
157169

158170

@@ -199,7 +211,9 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
199211
return res_values
200212

201213

202-
def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
214+
def comparison_op(
215+
left: ArrayLike, right: Any, op, str_rep: Optional[str] = None,
216+
) -> ArrayLike:
203217
"""
204218
Evaluate a comparison operation `=`, `!=`, `>=`, `>`, `<=`, or `<`.
205219
@@ -244,16 +258,8 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
244258
res_values = comp_method_OBJECT_ARRAY(op, lvalues, rvalues)
245259

246260
else:
247-
op_name = f"__{op.__name__}__"
248-
method = getattr(lvalues, op_name)
249261
with np.errstate(all="ignore"):
250-
res_values = method(rvalues)
251-
252-
if res_values is NotImplemented:
253-
res_values = invalid_comparison(lvalues, rvalues, op)
254-
if is_scalar(res_values):
255-
typ = type(rvalues)
256-
raise TypeError(f"Could not compare {typ} type with Series")
262+
res_values = na_arithmetic_op(lvalues, rvalues, op, str_rep, is_cmp=True)
257263

258264
return res_values
259265

@@ -380,7 +386,7 @@ def get_array_op(op, str_rep: Optional[str] = None):
380386
"""
381387
op_name = op.__name__.strip("_")
382388
if op_name in {"eq", "ne", "lt", "le", "gt", "ge"}:
383-
return partial(comparison_op, op=op)
389+
return partial(comparison_op, op=op, str_rep=str_rep)
384390
elif op_name in {"and", "or", "xor", "rand", "ror", "rxor"}:
385391
return partial(logical_op, op=op)
386392
else:

pandas/tests/arithmetic/test_numeric.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def test_df_numeric_cmp_dt64_raises(self):
6666
ts = pd.Timestamp.now()
6767
df = pd.DataFrame({"x": range(5)})
6868

69-
msg = "Invalid comparison between dtype=int64 and Timestamp"
70-
69+
msg = "'[<>]' not supported between instances of 'Timestamp' and 'int'"
7170
with pytest.raises(TypeError, match=msg):
7271
df > ts
7372
with pytest.raises(TypeError, match=msg):

0 commit comments

Comments
 (0)