Skip to content

Commit b9a4335

Browse files
phoflmroeschke
andauthored
BUG: eval and query not working with ea dtypes (pandas-dev#50764)
* BUG: eval and query not working with ea dtypes * Fix windows build * Fix * Fix another bug * Fix eval * Fix * Fix * Add arrow tests * Fix pyarrow-less ci * Add try except * Add warning * Adjust warning * Fix warning * Fix * Update test_query_eval.py * Update pandas/core/computation/eval.py --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 4c8b2ea commit b9a4335

File tree

4 files changed

+117
-0
lines changed

4 files changed

+117
-0
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,7 @@ Conversion
11851185
- Bug in :meth:`DataFrame.astype` not copying data when converting to pyarrow dtype (:issue:`50984`)
11861186
- Bug in :func:`to_datetime` was not respecting ``exact`` argument when ``format`` was an ISO8601 format (:issue:`12649`)
11871187
- Bug in :meth:`TimedeltaArray.astype` raising ``TypeError`` when converting to a pyarrow duration type (:issue:`49795`)
1188+
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` raising for extension array dtypes (:issue:`29618`, :issue:`50261`, :issue:`31913`)
11881189
-
11891190

11901191
Strings

pandas/core/computation/common.py

+20
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,23 @@ def result_type_many(*arrays_and_dtypes):
2626
except ValueError:
2727
# we have > NPY_MAXARGS terms in our expression
2828
return reduce(np.result_type, arrays_and_dtypes)
29+
except TypeError:
30+
from pandas.core.dtypes.cast import find_common_type
31+
from pandas.core.dtypes.common import is_extension_array_dtype
32+
33+
arr_and_dtypes = list(arrays_and_dtypes)
34+
ea_dtypes, non_ea_dtypes = [], []
35+
for arr_or_dtype in arr_and_dtypes:
36+
if is_extension_array_dtype(arr_or_dtype):
37+
ea_dtypes.append(arr_or_dtype)
38+
else:
39+
non_ea_dtypes.append(arr_or_dtype)
40+
41+
if non_ea_dtypes:
42+
try:
43+
np_dtype = np.result_type(*non_ea_dtypes)
44+
except ValueError:
45+
np_dtype = reduce(np.result_type, arrays_and_dtypes)
46+
return find_common_type(ea_dtypes + [np_dtype])
47+
48+
return find_common_type(ea_dtypes)

pandas/core/computation/eval.py

+19
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
from typing import TYPE_CHECKING
88
import warnings
99

10+
from pandas.util._exceptions import find_stack_level
1011
from pandas.util._validators import validate_bool_kwarg
1112

13+
from pandas.core.dtypes.common import is_extension_array_dtype
14+
1215
from pandas.core.computation.engines import ENGINES
1316
from pandas.core.computation.expr import (
1417
PARSERS,
@@ -333,6 +336,22 @@ def eval(
333336

334337
parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)
335338

339+
if engine == "numexpr" and (
340+
is_extension_array_dtype(parsed_expr.terms.return_type)
341+
or getattr(parsed_expr.terms, "operand_types", None) is not None
342+
and any(
343+
is_extension_array_dtype(elem)
344+
for elem in parsed_expr.terms.operand_types
345+
)
346+
):
347+
warnings.warn(
348+
"Engine has switched to 'python' because numexpr does not support "
349+
"extension array dtypes. Please set your engine to python manually.",
350+
RuntimeWarning,
351+
stacklevel=find_stack_level(),
352+
)
353+
engine = "python"
354+
336355
# construct the engine and evaluate the parsed expression
337356
eng = ENGINES[engine]
338357
eng_inst = eng(parsed_expr)

pandas/tests/frame/test_query_eval.py

+77
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55

6+
from pandas.compat import is_platform_windows
67
from pandas.errors import (
78
NumExprClobberingError,
89
UndefinedVariableError,
@@ -1291,3 +1292,79 @@ def func(*_):
12911292

12921293
with pytest.raises(TypeError, match="Only named functions are supported"):
12931294
df.eval("@funcs[0].__call__()")
1295+
1296+
def test_ea_dtypes(self, any_numeric_ea_and_arrow_dtype):
1297+
# GH#29618
1298+
df = DataFrame(
1299+
[[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype
1300+
)
1301+
warning = RuntimeWarning if NUMEXPR_INSTALLED else None
1302+
with tm.assert_produces_warning(warning):
1303+
result = df.eval("c = b - a")
1304+
expected = DataFrame(
1305+
[[1, 2, 1], [3, 4, 1]],
1306+
columns=["a", "b", "c"],
1307+
dtype=any_numeric_ea_and_arrow_dtype,
1308+
)
1309+
tm.assert_frame_equal(result, expected)
1310+
1311+
def test_ea_dtypes_and_scalar(self):
1312+
# GH#29618
1313+
df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"], dtype="Float64")
1314+
warning = RuntimeWarning if NUMEXPR_INSTALLED else None
1315+
with tm.assert_produces_warning(warning):
1316+
result = df.eval("c = b - 1")
1317+
expected = DataFrame(
1318+
[[1, 2, 1], [3, 4, 3]], columns=["a", "b", "c"], dtype="Float64"
1319+
)
1320+
tm.assert_frame_equal(result, expected)
1321+
1322+
def test_ea_dtypes_and_scalar_operation(self, any_numeric_ea_and_arrow_dtype):
1323+
# GH#29618
1324+
df = DataFrame(
1325+
[[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype
1326+
)
1327+
result = df.eval("c = 2 - 1")
1328+
expected = DataFrame(
1329+
{
1330+
"a": Series([1, 3], dtype=any_numeric_ea_and_arrow_dtype),
1331+
"b": Series([2, 4], dtype=any_numeric_ea_and_arrow_dtype),
1332+
"c": Series(
1333+
[1, 1], dtype="int64" if not is_platform_windows() else "int32"
1334+
),
1335+
}
1336+
)
1337+
tm.assert_frame_equal(result, expected)
1338+
1339+
@pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"])
1340+
def test_query_ea_dtypes(self, dtype):
1341+
if dtype == "int64[pyarrow]":
1342+
pytest.importorskip("pyarrow")
1343+
# GH#50261
1344+
df = DataFrame({"a": Series([1, 2], dtype=dtype)})
1345+
ref = {2} # noqa:F841
1346+
result = df.query("a in @ref")
1347+
expected = DataFrame({"a": Series([2], dtype=dtype, index=[1])})
1348+
tm.assert_frame_equal(result, expected)
1349+
1350+
@pytest.mark.parametrize("engine", ["python", "numexpr"])
1351+
@pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"])
1352+
def test_query_ea_equality_comparison(self, dtype, engine):
1353+
# GH#50261
1354+
warning = RuntimeWarning if engine == "numexpr" else None
1355+
if engine == "numexpr" and not NUMEXPR_INSTALLED:
1356+
pytest.skip("numexpr not installed")
1357+
if dtype == "int64[pyarrow]":
1358+
pytest.importorskip("pyarrow")
1359+
df = DataFrame(
1360+
{"A": Series([1, 1, 2], dtype="Int64"), "B": Series([1, 2, 2], dtype=dtype)}
1361+
)
1362+
with tm.assert_produces_warning(warning):
1363+
result = df.query("A == B", engine=engine)
1364+
expected = DataFrame(
1365+
{
1366+
"A": Series([1, 2], dtype="Int64", index=[0, 2]),
1367+
"B": Series([1, 2], dtype=dtype, index=[0, 2]),
1368+
}
1369+
)
1370+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)