|
3 | 3 | import numpy as np
|
4 | 4 | import pytest
|
5 | 5 |
|
| 6 | +from pandas.compat import is_platform_windows |
6 | 7 | from pandas.errors import (
|
7 | 8 | NumExprClobberingError,
|
8 | 9 | UndefinedVariableError,
|
@@ -1291,3 +1292,79 @@ def func(*_):
|
1291 | 1292 |
|
1292 | 1293 | with pytest.raises(TypeError, match="Only named functions are supported"):
|
1293 | 1294 | df.eval("@funcs[0].__call__()")
|
| 1295 | + |
| 1296 | + def test_ea_dtypes(self, any_numeric_ea_and_arrow_dtype): |
| 1297 | + # GH#29618 |
| 1298 | + df = DataFrame( |
| 1299 | + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype |
| 1300 | + ) |
| 1301 | + warning = RuntimeWarning if NUMEXPR_INSTALLED else None |
| 1302 | + with tm.assert_produces_warning(warning): |
| 1303 | + result = df.eval("c = b - a") |
| 1304 | + expected = DataFrame( |
| 1305 | + [[1, 2, 1], [3, 4, 1]], |
| 1306 | + columns=["a", "b", "c"], |
| 1307 | + dtype=any_numeric_ea_and_arrow_dtype, |
| 1308 | + ) |
| 1309 | + tm.assert_frame_equal(result, expected) |
| 1310 | + |
| 1311 | + def test_ea_dtypes_and_scalar(self): |
| 1312 | + # GH#29618 |
| 1313 | + df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"], dtype="Float64") |
| 1314 | + warning = RuntimeWarning if NUMEXPR_INSTALLED else None |
| 1315 | + with tm.assert_produces_warning(warning): |
| 1316 | + result = df.eval("c = b - 1") |
| 1317 | + expected = DataFrame( |
| 1318 | + [[1, 2, 1], [3, 4, 3]], columns=["a", "b", "c"], dtype="Float64" |
| 1319 | + ) |
| 1320 | + tm.assert_frame_equal(result, expected) |
| 1321 | + |
| 1322 | + def test_ea_dtypes_and_scalar_operation(self, any_numeric_ea_and_arrow_dtype): |
| 1323 | + # GH#29618 |
| 1324 | + df = DataFrame( |
| 1325 | + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype |
| 1326 | + ) |
| 1327 | + result = df.eval("c = 2 - 1") |
| 1328 | + expected = DataFrame( |
| 1329 | + { |
| 1330 | + "a": Series([1, 3], dtype=any_numeric_ea_and_arrow_dtype), |
| 1331 | + "b": Series([2, 4], dtype=any_numeric_ea_and_arrow_dtype), |
| 1332 | + "c": Series( |
| 1333 | + [1, 1], dtype="int64" if not is_platform_windows() else "int32" |
| 1334 | + ), |
| 1335 | + } |
| 1336 | + ) |
| 1337 | + tm.assert_frame_equal(result, expected) |
| 1338 | + |
@pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"])
def test_query_ea_dtypes(self, dtype):
    """
    Filter with ``"a in @ref"`` for each parametrized dtype (GH#50261).

    Only the row whose value is in ``ref`` is expected to remain, keeping its
    original index label and the column dtype. The pyarrow-backed case is
    skipped when ``pyarrow`` cannot be imported.
    """
    # GH#50261
    if dtype == "int64[pyarrow]":
        pytest.importorskip("pyarrow")

    frame = DataFrame({"a": Series([1, 2], dtype=dtype)})
    # ``ref`` is referenced only through the ``@ref`` name in the query string.
    ref = {2}  # noqa:F841
    filtered = frame.query("a in @ref")

    kept = Series([2], dtype=dtype, index=[1])
    tm.assert_frame_equal(filtered, DataFrame({"a": kept}))
| 1349 | + |
@pytest.mark.parametrize("engine", ["python", "numexpr"])
@pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"])
def test_query_ea_equality_comparison(self, dtype, engine):
    """
    Query ``"A == B"`` with an ``Int64`` column against each dtype (GH#50261).

    Rows at index 0 and 2 are expected to remain, with both columns keeping
    their dtypes. A ``RuntimeWarning`` is expected for the ``numexpr`` engine
    and no warning for the ``python`` engine. The test is skipped when the
    requested engine or ``pyarrow`` is unavailable.
    """
    # GH#50261
    using_numexpr = engine == "numexpr"
    if using_numexpr and not NUMEXPR_INSTALLED:
        pytest.skip("numexpr not installed")
    if dtype == "int64[pyarrow]":
        pytest.importorskip("pyarrow")

    expected_warning = RuntimeWarning if using_numexpr else None

    left = Series([1, 1, 2], dtype="Int64")
    right = Series([1, 2, 2], dtype=dtype)
    frame = DataFrame({"A": left, "B": right})

    with tm.assert_produces_warning(expected_warning):
        matched = frame.query("A == B", engine=engine)

    wanted = DataFrame(
        {
            "A": Series([1, 2], dtype="Int64", index=[0, 2]),
            "B": Series([1, 2], dtype=dtype, index=[0, 2]),
        }
    )
    tm.assert_frame_equal(matched, wanted)
0 commit comments