diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 532881fee0892..d8beff81e6c6b 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -1183,6 +1183,7 @@ Conversion - Bug in :meth:`DataFrame.astype` not copying data when converting to pyarrow dtype (:issue:`50984`) - Bug in :func:`to_datetime` was not respecting ``exact`` argument when ``format`` was an ISO8601 format (:issue:`12649`) - Bug in :meth:`TimedeltaArray.astype` raising ``TypeError`` when converting to a pyarrow duration type (:issue:`49795`) +- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` raising for extension array dtypes (:issue:`29618`, :issue:`50261`, :issue:`31913`) - Strings diff --git a/pandas/core/computation/common.py b/pandas/core/computation/common.py index a1ac3dfa06ee0..115191829f044 100644 --- a/pandas/core/computation/common.py +++ b/pandas/core/computation/common.py @@ -26,3 +26,23 @@ def result_type_many(*arrays_and_dtypes): except ValueError: # we have > NPY_MAXARGS terms in our expression return reduce(np.result_type, arrays_and_dtypes) + except TypeError: + from pandas.core.dtypes.cast import find_common_type + from pandas.core.dtypes.common import is_extension_array_dtype + + arr_and_dtypes = list(arrays_and_dtypes) + ea_dtypes, non_ea_dtypes = [], [] + for arr_or_dtype in arr_and_dtypes: + if is_extension_array_dtype(arr_or_dtype): + ea_dtypes.append(arr_or_dtype) + else: + non_ea_dtypes.append(arr_or_dtype) + + if non_ea_dtypes: + try: + np_dtype = np.result_type(*non_ea_dtypes) + except ValueError: + np_dtype = reduce(np.result_type, arrays_and_dtypes) + return find_common_type(ea_dtypes + [np_dtype]) + + return find_common_type(ea_dtypes) diff --git a/pandas/core/computation/eval.py b/pandas/core/computation/eval.py index f0127ae05182a..0326760a1ff24 100644 --- a/pandas/core/computation/eval.py +++ b/pandas/core/computation/eval.py @@ -7,8 +7,11 @@ from typing import TYPE_CHECKING import warnings +from pandas.util._exceptions import find_stack_level from pandas.util._validators import validate_bool_kwarg +from pandas.core.dtypes.common import is_extension_array_dtype + from pandas.core.computation.engines import ENGINES from pandas.core.computation.expr import ( PARSERS, @@ -333,6 +336,22 @@ def eval( parsed_expr = Expr(expr, engine=engine, parser=parser, env=env) + if engine == "numexpr" and ( + is_extension_array_dtype(parsed_expr.terms.return_type) + or getattr(parsed_expr.terms, "operand_types", None) is not None + and any( + is_extension_array_dtype(elem) + for elem in parsed_expr.terms.operand_types + ) + ): + warnings.warn( + "Engine has switched to 'python' because numexpr does not support " + "extension array dtypes. Please set your engine to python manually.", + RuntimeWarning, + stacklevel=find_stack_level(), + ) + engine = "python" + # construct the engine and evaluate the parsed expression eng = ENGINES[engine] eng_inst = eng(parsed_expr) diff --git a/pandas/tests/frame/test_query_eval.py b/pandas/tests/frame/test_query_eval.py index 8a40461f10039..c815b897e6a14 100644 --- a/pandas/tests/frame/test_query_eval.py +++ b/pandas/tests/frame/test_query_eval.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from pandas.compat import is_platform_windows from pandas.errors import ( NumExprClobberingError, UndefinedVariableError, @@ -1291,3 +1292,79 @@ def func(*_): with pytest.raises(TypeError, match="Only named functions are supported"): df.eval("@funcs[0].__call__()") + + def test_ea_dtypes(self, any_numeric_ea_and_arrow_dtype): + # GH#29618 + df = DataFrame( + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype + ) + warning = RuntimeWarning if NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.eval("c = b - a") + expected = DataFrame( + [[1, 2, 1], [3, 4, 1]], + columns=["a", "b", "c"], + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_ea_dtypes_and_scalar(self): + # GH#29618 + df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"], dtype="Float64") + warning = RuntimeWarning if NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.eval("c = b - 1") + expected = DataFrame( + [[1, 2, 1], [3, 4, 3]], columns=["a", "b", "c"], dtype="Float64" + ) + tm.assert_frame_equal(result, expected) + + def test_ea_dtypes_and_scalar_operation(self, any_numeric_ea_and_arrow_dtype): + # GH#29618 + df = DataFrame( + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype + ) + result = df.eval("c = 2 - 1") + expected = DataFrame( + { + "a": Series([1, 3], dtype=any_numeric_ea_and_arrow_dtype), + "b": Series([2, 4], dtype=any_numeric_ea_and_arrow_dtype), + "c": Series( + [1, 1], dtype="int64" if not is_platform_windows() else "int32" + ), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"]) + def test_query_ea_dtypes(self, dtype): + if dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + # GH#50261 + df = DataFrame({"a": Series([1, 2], dtype=dtype)}) + ref = {2} # noqa:F841 + result = df.query("a in @ref") + expected = DataFrame({"a": Series([2], dtype=dtype, index=[1])}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("engine", ["python", "numexpr"]) + @pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"]) + def test_query_ea_equality_comparison(self, dtype, engine): + # GH#50261 + warning = RuntimeWarning if engine == "numexpr" else None + if engine == "numexpr" and not NUMEXPR_INSTALLED: + pytest.skip("numexpr not installed") + if dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + df = DataFrame( + {"A": Series([1, 1, 2], dtype="Int64"), "B": Series([1, 2, 2], dtype=dtype)} + ) + with tm.assert_produces_warning(warning): + result = df.query("A == B", engine=engine) + expected = DataFrame( + { + "A": Series([1, 2], dtype="Int64", index=[0, 2]), + "B": Series([1, 2], dtype=dtype, index=[0, 2]), + } + ) + tm.assert_frame_equal(result, expected)