Skip to content

Commit 4ff2c68

Browse files
String dtype: allow string dtype in query/eval with default numexpr engine (pandas-dev#59810)
String dtype: allow string dtype in query/eval with default numexpr engine
1 parent 532b9a1 commit 4ff2c68

File tree

3 files changed

+20
-19
lines changed

3 files changed

+20
-19
lines changed

pandas/core/computation/eval.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from pandas.util._exceptions import find_stack_level
1111
from pandas.util._validators import validate_bool_kwarg
1212

13-
from pandas.core.dtypes.common import is_extension_array_dtype
13+
from pandas.core.dtypes.common import (
14+
is_extension_array_dtype,
15+
is_string_dtype,
16+
)
1417

1518
from pandas.core.computation.engines import ENGINES
1619
from pandas.core.computation.expr import (
@@ -336,10 +339,13 @@ def eval(
336339
parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)
337340

338341
if engine == "numexpr" and (
339-
is_extension_array_dtype(parsed_expr.terms.return_type)
342+
(
343+
is_extension_array_dtype(parsed_expr.terms.return_type)
344+
and not is_string_dtype(parsed_expr.terms.return_type)
345+
)
340346
or getattr(parsed_expr.terms, "operand_types", None) is not None
341347
and any(
342-
is_extension_array_dtype(elem)
348+
(is_extension_array_dtype(elem) and not is_string_dtype(elem))
343349
for elem in parsed_expr.terms.operand_types
344350
)
345351
):

pandas/core/computation/expr.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from pandas.errors import UndefinedVariableError
2222

23+
from pandas.core.dtypes.common import is_string_dtype
24+
2325
import pandas.core.common as com
2426
from pandas.core.computation.ops import (
2527
ARITH_OPS_SYMS,
@@ -520,10 +522,12 @@ def _maybe_evaluate_binop(
520522
elif self.engine != "pytables":
521523
if (
522524
getattr(lhs, "return_type", None) == object
525+
or is_string_dtype(getattr(lhs, "return_type", None))
523526
or getattr(rhs, "return_type", None) == object
527+
or is_string_dtype(getattr(rhs, "return_type", None))
524528
):
525529
# evaluate "==" and "!=" in python if either of our operands
526-
# has an object return type
530+
# has an object or string return type
527531
return self._maybe_eval(res, eval_in_python + maybe_eval_in_python)
528532
return res
529533

pandas/tests/frame/test_query_eval.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas.errors import (
97
NumExprClobberingError,
108
UndefinedVariableError,
@@ -747,7 +745,6 @@ def test_inf(self, op, f, engine, parser):
747745
result = df.query(q, engine=engine, parser=parser)
748746
tm.assert_frame_equal(result, expected)
749747

750-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
751748
def test_check_tz_aware_index_query(self, tz_aware_fixture):
752749
# https://github.com/pandas-dev/pandas/issues/29463
753750
tz = tz_aware_fixture
@@ -760,6 +757,7 @@ def test_check_tz_aware_index_query(self, tz_aware_fixture):
760757
tm.assert_frame_equal(result, expected)
761758

762759
expected = DataFrame(df_index)
760+
expected.columns = expected.columns.astype(object)
763761
result = df.reset_index().query('"2018-01-03 00:00:00+00" < time')
764762
tm.assert_frame_equal(result, expected)
765763

@@ -1057,7 +1055,7 @@ def test_query_with_string_columns(self, parser, engine):
10571055
with pytest.raises(NotImplementedError, match=msg):
10581056
df.query("a in b and c < d", parser=parser, engine=engine)
10591057

1060-
def test_object_array_eq_ne(self, parser, engine, using_infer_string):
1058+
def test_object_array_eq_ne(self, parser, engine):
10611059
df = DataFrame(
10621060
{
10631061
"a": list("aaaabbbbcccc"),
@@ -1066,14 +1064,11 @@ def test_object_array_eq_ne(self, parser, engine, using_infer_string):
10661064
"d": np.random.default_rng(2).integers(9, size=12),
10671065
}
10681066
)
1069-
warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None
1070-
with tm.assert_produces_warning(warning):
1071-
res = df.query("a == b", parser=parser, engine=engine)
1067+
res = df.query("a == b", parser=parser, engine=engine)
10721068
exp = df[df.a == df.b]
10731069
tm.assert_frame_equal(res, exp)
10741070

1075-
with tm.assert_produces_warning(warning):
1076-
res = df.query("a != b", parser=parser, engine=engine)
1071+
res = df.query("a != b", parser=parser, engine=engine)
10771072
exp = df[df.a != df.b]
10781073
tm.assert_frame_equal(res, exp)
10791074

@@ -1112,16 +1107,12 @@ def test_query_with_nested_special_character(self, parser, engine):
11121107
[">=", operator.ge],
11131108
],
11141109
)
1115-
def test_query_lex_compare_strings(
1116-
self, parser, engine, op, func, using_infer_string
1117-
):
1110+
def test_query_lex_compare_strings(self, parser, engine, op, func):
11181111
a = Series(np.random.default_rng(2).choice(list("abcde"), 20))
11191112
b = Series(np.arange(a.size))
11201113
df = DataFrame({"X": a, "Y": b})
11211114

1122-
warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None
1123-
with tm.assert_produces_warning(warning):
1124-
res = df.query(f'X {op} "d"', engine=engine, parser=parser)
1115+
res = df.query(f'X {op} "d"', engine=engine, parser=parser)
11251116
expected = df[func(df.X, "d")]
11261117
tm.assert_frame_equal(res, expected)
11271118

0 commit comments

Comments
 (0)