Skip to content

use numexpr for Series comparisons #32047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 26, 2020
3 changes: 2 additions & 1 deletion pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def _comp_method_SERIES(cls, op, special):
Wrapper function for Series arithmetic operations, to avoid
code duplication.
"""
str_rep = _get_opstr(op)
op_name = _get_op_name(op, special)

@unpack_zerodim_and_defer(op_name)
Expand All @@ -523,7 +524,7 @@ def wrapper(self, other):
lvalues = extract_array(self, extract_numpy=True)
rvalues = extract_array(other, extract_numpy=True)

res_values = comparison_op(lvalues, rvalues, op)
res_values = comparison_op(lvalues, rvalues, op, str_rep)

return _construct_result(self, res_values, index=self.index, name=res_name)

Expand Down
30 changes: 18 additions & 12 deletions pandas/core/ops/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def na_op(x, y):
return na_op


def na_arithmetic_op(left, right, op, str_rep: str):
def na_arithmetic_op(left, right, op, str_rep: Optional[str], is_cmp: bool = False):
"""
Return the result of evaluating op on the passed in values.
Expand All @@ -137,6 +137,8 @@ def na_arithmetic_op(left, right, op, str_rep: str):
left : np.ndarray
right : np.ndarray or scalar
str_rep : str or None
is_cmp : bool, default False
If this is a comparison operation.
Returns
-------
Expand All @@ -151,8 +153,18 @@ def na_arithmetic_op(left, right, op, str_rep: str):
try:
result = expressions.evaluate(op, str_rep, left, right)
except TypeError:
if is_cmp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what hits this AND is an is_cmp case? can you add a comment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we have 7 cases that get here. 4 of these are something like ndarray[float] > datetime, and would pass if we removed this is_cmp check. Of the remaining three:

  • 2 involve passing a DatetimeArray to masked_arith_op, which expects ndarray. this should probably be prevented.
  • one involves comparisons of complex numbers, which numpy does differently than python (which is its own PITA). falling through to the masked op would break test_bool_flex_frame_complex_dtype.

# numexpr failed on comparison op, e.g. ndarray[float] > datetime
# In this case we do not fall back to the masked op, as that
# will handle complex numbers incorrectly, see GH#32047
raise
result = masked_arith_op(left, right, op)

if is_cmp and (is_scalar(result) or result is NotImplemented):
# numpy returned a scalar instead of operating element-wise
# e.g. numeric array vs str
return invalid_comparison(left, right, op)

return missing.dispatch_fill_zeros(op, left, right, result)


Expand Down Expand Up @@ -199,7 +211,9 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
return res_values


def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
def comparison_op(
left: ArrayLike, right: Any, op, str_rep: Optional[str] = None,
) -> ArrayLike:
"""
Evaluate a comparison operation `==`, `!=`, `>=`, `>`, `<=`, or `<`.
Expand Down Expand Up @@ -244,16 +258,8 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
res_values = comp_method_OBJECT_ARRAY(op, lvalues, rvalues)

else:
op_name = f"__{op.__name__}__"
method = getattr(lvalues, op_name)
with np.errstate(all="ignore"):
res_values = method(rvalues)

if res_values is NotImplemented:
res_values = invalid_comparison(lvalues, rvalues, op)
if is_scalar(res_values):
typ = type(rvalues)
raise TypeError(f"Could not compare {typ} type with Series")
res_values = na_arithmetic_op(lvalues, rvalues, op, str_rep, is_cmp=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this not handled in na_arithmetic_op? seems odd to handle here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure


return res_values

Expand Down Expand Up @@ -380,7 +386,7 @@ def get_array_op(op, str_rep: Optional[str] = None):
"""
op_name = op.__name__.strip("_")
if op_name in {"eq", "ne", "lt", "le", "gt", "ge"}:
return partial(comparison_op, op=op)
return partial(comparison_op, op=op, str_rep=str_rep)
elif op_name in {"and", "or", "xor", "rand", "ror", "rxor"}:
return partial(logical_op, op=op)
else:
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def test_df_numeric_cmp_dt64_raises(self):
ts = pd.Timestamp.now()
df = pd.DataFrame({"x": range(5)})

msg = "Invalid comparison between dtype=int64 and Timestamp"

msg = "'[<>]' not supported between instances of 'Timestamp' and 'int'"
with pytest.raises(TypeError, match=msg):
df > ts
with pytest.raises(TypeError, match=msg):
Expand Down