Skip to content

Commit 013ac67

Browse files
String dtype: allow string dtype in query/eval with default numexpr engine (#59810)
String dtype: allow string dtype in query/eval with default numexpr engine
1 parent 160b3eb commit 013ac67

File tree

3 files changed

+20
-22
lines changed

3 files changed

+20
-22
lines changed

pandas/core/computation/eval.py

+9-3
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,10 @@
1414
from pandas.util._exceptions import find_stack_level
1515
from pandas.util._validators import validate_bool_kwarg
1616

17-
from pandas.core.dtypes.common import is_extension_array_dtype
17+
from pandas.core.dtypes.common import (
18+
is_extension_array_dtype,
19+
is_string_dtype,
20+
)
1821

1922
from pandas.core.computation.engines import ENGINES
2023
from pandas.core.computation.expr import (
@@ -345,10 +348,13 @@ def eval(
345348
parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)
346349

347350
if engine == "numexpr" and (
348-
is_extension_array_dtype(parsed_expr.terms.return_type)
351+
(
352+
is_extension_array_dtype(parsed_expr.terms.return_type)
353+
and not is_string_dtype(parsed_expr.terms.return_type)
354+
)
349355
or getattr(parsed_expr.terms, "operand_types", None) is not None
350356
and any(
351-
is_extension_array_dtype(elem)
357+
(is_extension_array_dtype(elem) and not is_string_dtype(elem))
352358
for elem in parsed_expr.terms.operand_types
353359
)
354360
):

pandas/core/computation/expr.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@
2121

2222
from pandas.errors import UndefinedVariableError
2323

24+
from pandas.core.dtypes.common import is_string_dtype
25+
2426
import pandas.core.common as com
2527
from pandas.core.computation.ops import (
2628
ARITH_OPS_SYMS,
@@ -524,10 +526,12 @@ def _maybe_evaluate_binop(
524526
elif self.engine != "pytables":
525527
if (
526528
getattr(lhs, "return_type", None) == object
529+
or is_string_dtype(getattr(lhs, "return_type", None))
527530
or getattr(rhs, "return_type", None) == object
531+
or is_string_dtype(getattr(rhs, "return_type", None))
528532
):
529533
# evaluate "==" and "!=" in python if either of our operands
530-
# has an object return type
534+
# has an object or string return type
531535
return self._maybe_eval(res, eval_in_python + maybe_eval_in_python)
532536
return res
533537

pandas/tests/frame/test_query_eval.py

+6-18
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,6 @@
44
import numpy as np
55
import pytest
66

7-
from pandas._config import using_string_dtype
8-
97
from pandas.errors import (
108
NumExprClobberingError,
119
UndefinedVariableError,
@@ -762,7 +760,6 @@ def test_inf(self, op, f, engine, parser):
762760
result = df.query(q, engine=engine, parser=parser)
763761
tm.assert_frame_equal(result, expected)
764762

765-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
766763
def test_check_tz_aware_index_query(self, tz_aware_fixture):
767764
# https://github.com/pandas-dev/pandas/issues/29463
768765
tz = tz_aware_fixture
@@ -775,6 +772,7 @@ def test_check_tz_aware_index_query(self, tz_aware_fixture):
775772
tm.assert_frame_equal(result, expected)
776773

777774
expected = DataFrame(df_index)
775+
expected.columns = expected.columns.astype(object)
778776
result = df.reset_index().query('"2018-01-03 00:00:00+00" < time')
779777
tm.assert_frame_equal(result, expected)
780778

@@ -1072,7 +1070,7 @@ def test_query_with_string_columns(self, parser, engine):
10721070
with pytest.raises(NotImplementedError, match=msg):
10731071
df.query("a in b and c < d", parser=parser, engine=engine)
10741072

1075-
def test_object_array_eq_ne(self, parser, engine, using_infer_string):
1073+
def test_object_array_eq_ne(self, parser, engine):
10761074
df = DataFrame(
10771075
{
10781076
"a": list("aaaabbbbcccc"),
@@ -1081,14 +1079,11 @@ def test_object_array_eq_ne(self, parser, engine, using_infer_string):
10811079
"d": np.random.default_rng(2).integers(9, size=12),
10821080
}
10831081
)
1084-
warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None
1085-
with tm.assert_produces_warning(warning):
1086-
res = df.query("a == b", parser=parser, engine=engine)
1082+
res = df.query("a == b", parser=parser, engine=engine)
10871083
exp = df[df.a == df.b]
10881084
tm.assert_frame_equal(res, exp)
10891085

1090-
with tm.assert_produces_warning(warning):
1091-
res = df.query("a != b", parser=parser, engine=engine)
1086+
res = df.query("a != b", parser=parser, engine=engine)
10921087
exp = df[df.a != df.b]
10931088
tm.assert_frame_equal(res, exp)
10941089

@@ -1128,15 +1123,13 @@ def test_query_with_nested_special_character(self, parser, engine):
11281123
],
11291124
)
11301125
def test_query_lex_compare_strings(
1131-
self, parser, engine, op, func, using_infer_string
1126+
self, parser, engine, op, func
11321127
):
11331128
a = Series(np.random.default_rng(2).choice(list("abcde"), 20))
11341129
b = Series(np.arange(a.size))
11351130
df = DataFrame({"X": a, "Y": b})
11361131

1137-
warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None
1138-
with tm.assert_produces_warning(warning):
1139-
res = df.query(f'X {op} "d"', engine=engine, parser=parser)
1132+
res = df.query(f'X {op} "d"', engine=engine, parser=parser)
11401133
expected = df[func(df.X, "d")]
11411134
tm.assert_frame_equal(res, expected)
11421135

@@ -1400,15 +1393,13 @@ def test_expr_with_column_name_with_backtick(self):
14001393
expected = df[df["a`b"] < 2]
14011394
tm.assert_frame_equal(result, expected)
14021395

1403-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
14041396
def test_expr_with_string_with_backticks(self):
14051397
# GH 59285
14061398
df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"])
14071399
result = df.query("'```' < `#backticks`")
14081400
expected = df["```" < df["#backticks"]]
14091401
tm.assert_frame_equal(result, expected)
14101402

1411-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
14121403
def test_expr_with_string_with_backticked_substring_same_as_column_name(self):
14131404
# GH 59285
14141405
df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"])
@@ -1439,7 +1430,6 @@ def test_expr_with_column_names_with_special_characters(self, col1, col2, expr):
14391430
expected = df[df[col1] < df[col2]]
14401431
tm.assert_frame_equal(result, expected)
14411432

1442-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
14431433
def test_expr_with_no_backticks(self):
14441434
# GH 59285
14451435
df = DataFrame(("aaa", "vvv", "zzz"), columns=["column_name"])
@@ -1483,15 +1473,13 @@ def test_expr_with_quote_opened_before_backtick_and_quote_is_unmatched(self):
14831473
):
14841474
df.query("`column-name` < 'It`s that\\'s \"quote\" #hash")
14851475

1486-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
14871476
def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_at_end(self):
14881477
# GH 59285
14891478
df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"])
14901479
result = df.query("`column-name` < 'It`s that\\'s \"quote\" #hash'")
14911480
expected = df[df["column-name"] < 'It`s that\'s "quote" #hash']
14921481
tm.assert_frame_equal(result, expected)
14931482

1494-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
14951483
def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_in_mid(self):
14961484
# GH 59285
14971485
df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"])

0 commit comments

Comments
 (0)