diff --git a/pandas/core/computation/eval.py b/pandas/core/computation/eval.py index aad768d31483a..485c7f87d6f33 100644 --- a/pandas/core/computation/eval.py +++ b/pandas/core/computation/eval.py @@ -14,7 +14,10 @@ from pandas.util._exceptions import find_stack_level from pandas.util._validators import validate_bool_kwarg -from pandas.core.dtypes.common import is_extension_array_dtype +from pandas.core.dtypes.common import ( + is_extension_array_dtype, + is_string_dtype, +) from pandas.core.computation.engines import ENGINES from pandas.core.computation.expr import ( @@ -345,10 +348,13 @@ def eval( parsed_expr = Expr(expr, engine=engine, parser=parser, env=env) if engine == "numexpr" and ( - is_extension_array_dtype(parsed_expr.terms.return_type) + ( + is_extension_array_dtype(parsed_expr.terms.return_type) + and not is_string_dtype(parsed_expr.terms.return_type) + ) or getattr(parsed_expr.terms, "operand_types", None) is not None and any( - is_extension_array_dtype(elem) + (is_extension_array_dtype(elem) and not is_string_dtype(elem)) for elem in parsed_expr.terms.operand_types ) ): diff --git a/pandas/core/computation/expr.py b/pandas/core/computation/expr.py index b074e768e0842..f45bc453d2541 100644 --- a/pandas/core/computation/expr.py +++ b/pandas/core/computation/expr.py @@ -21,6 +21,8 @@ from pandas.errors import UndefinedVariableError +from pandas.core.dtypes.common import is_string_dtype + import pandas.core.common as com from pandas.core.computation.ops import ( ARITH_OPS_SYMS, @@ -524,10 +526,12 @@ def _maybe_evaluate_binop( elif self.engine != "pytables": if ( getattr(lhs, "return_type", None) == object + or is_string_dtype(getattr(lhs, "return_type", None)) or getattr(rhs, "return_type", None) == object + or is_string_dtype(getattr(rhs, "return_type", None)) ): # evaluate "==" and "!=" in python if either of our operands - # has an object return type + # has an object or string return type return self._maybe_eval(res, eval_in_python + maybe_eval_in_python) return res diff --git a/pandas/tests/frame/test_query_eval.py b/pandas/tests/frame/test_query_eval.py index fa71153d01157..a574989860957 100644 --- a/pandas/tests/frame/test_query_eval.py +++ b/pandas/tests/frame/test_query_eval.py @@ -4,8 +4,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas.errors import ( NumExprClobberingError, UndefinedVariableError, @@ -762,7 +760,6 @@ def test_inf(self, op, f, engine, parser): result = df.query(q, engine=engine, parser=parser) tm.assert_frame_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_check_tz_aware_index_query(self, tz_aware_fixture): # https://github.com/pandas-dev/pandas/issues/29463 tz = tz_aware_fixture @@ -775,6 +772,7 @@ def test_check_tz_aware_index_query(self, tz_aware_fixture): tm.assert_frame_equal(result, expected) expected = DataFrame(df_index) + expected.columns = expected.columns.astype(object) result = df.reset_index().query('"2018-01-03 00:00:00+00" < time') tm.assert_frame_equal(result, expected) @@ -1072,7 +1070,7 @@ def test_query_with_string_columns(self, parser, engine): with pytest.raises(NotImplementedError, match=msg): df.query("a in b and c < d", parser=parser, engine=engine) - def test_object_array_eq_ne(self, parser, engine, using_infer_string): + def test_object_array_eq_ne(self, parser, engine): df = DataFrame( { "a": list("aaaabbbbcccc"), @@ -1081,14 +1079,11 @@ def test_object_array_eq_ne(self, parser, engine, using_infer_string): "d": np.random.default_rng(2).integers(9, size=12), } ) - warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None - with tm.assert_produces_warning(warning): - res = df.query("a == b", parser=parser, engine=engine) + res = df.query("a == b", parser=parser, engine=engine) exp = df[df.a == df.b] tm.assert_frame_equal(res, exp) - with tm.assert_produces_warning(warning): - res = df.query("a != b", parser=parser, engine=engine) + res = df.query("a != b", parser=parser, engine=engine) exp = df[df.a != df.b] tm.assert_frame_equal(res, exp) @@ -1128,15 +1123,13 @@ def test_query_with_nested_special_character(self, parser, engine): ], ) def test_query_lex_compare_strings( - self, parser, engine, op, func, using_infer_string + self, parser, engine, op, func ): a = Series(np.random.default_rng(2).choice(list("abcde"), 20)) b = Series(np.arange(a.size)) df = DataFrame({"X": a, "Y": b}) - warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None - with tm.assert_produces_warning(warning): - res = df.query(f'X {op} "d"', engine=engine, parser=parser) + res = df.query(f'X {op} "d"', engine=engine, parser=parser) expected = df[func(df.X, "d")] tm.assert_frame_equal(res, expected) @@ -1400,7 +1393,6 @@ def test_expr_with_column_name_with_backtick(self): expected = df[df["a`b"] < 2] tm.assert_frame_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_expr_with_string_with_backticks(self): # GH 59285 df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"]) @@ -1408,7 +1400,6 @@ def test_expr_with_string_with_backticks(self): expected = df["```" < df["#backticks"]] tm.assert_frame_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_expr_with_string_with_backticked_substring_same_as_column_name(self): # GH 59285 df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"]) @@ -1439,7 +1430,6 @@ def test_expr_with_column_names_with_special_characters(self, col1, col2, expr): expected = df[df[col1] < df[col2]] tm.assert_frame_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_expr_with_no_backticks(self): # GH 59285 df = DataFrame(("aaa", "vvv", "zzz"), columns=["column_name"]) @@ -1483,7 +1473,6 @@ def test_expr_with_quote_opened_before_backtick_and_quote_is_unmatched(self): ): df.query("`column-name` < 'It`s that\\'s \"quote\" #hash") - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_at_end(self): # GH 59285 df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"]) @@ -1491,7 +1480,6 @@ def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_at_end(self expected = df[df["column-name"] < 'It`s that\'s "quote" #hash'] tm.assert_frame_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_in_mid(self): # GH 59285 df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"])