forked from pandas-dev/pandas
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_numba.py
129 lines (104 loc) · 4.09 KB
/
test_numba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
import pytest
from pandas.compat import is_platform_arm
import pandas.util._test_decorators as td
import pandas as pd
from pandas import (
DataFrame,
Index,
)
import pandas._testing as tm
from pandas.util.version import Version
# Module-wide marks: skip everything when numba is unavailable, and tag the
# module with pandas' ``single_cpu`` marker. A ``pytest.mark.skipif()`` with no
# condition must not be listed here: pytest treats a condition-less skipif as an
# unconditional skip, which would disable every test in this module.
pytestmark = [td.skip_if_no("numba"), pytest.mark.single_cpu]

numba = pytest.importorskip("numba")

# numba 0.61 segfaults on ARM platforms, so skip the module there.
# NOTE(review): ``==`` only matches 0.61 / 0.61.0, not later 0.61.x patch
# releases — confirm whether those are affected too.
pytestmark.append(
    pytest.mark.skipif(
        Version(numba.__version__) == Version("0.61") and is_platform_arm(),
        reason=f"Segfaults on ARM platforms with numba {numba.__version__}",
    )
)
@pytest.fixture(params=[0, 1])
def apply_axis(request):
    """Run the requesting test once per ``DataFrame.apply`` axis (0, then 1)."""
    axis = request.param
    return axis
def test_numba_vs_python_noop(float_frame, apply_axis):
    """An identity UDF yields the same frame under the numba and python engines."""

    def identity(x):
        return x

    # Evaluate numba first, then python, mirroring the comparison order below.
    outputs = {
        engine: float_frame.apply(identity, engine=engine, axis=apply_axis)
        for engine in ("numba", "python")
    }
    tm.assert_frame_equal(outputs["numba"], outputs["python"])
def test_numba_vs_python_string_index():
    # GH#56189
    string_dtype = pd.StringDtype(na_value=np.nan)
    df = DataFrame(
        1,
        index=Index(["a", "b"], dtype=string_dtype),
        columns=Index(["x", "y"], dtype=string_dtype),
    )

    def identity(x):
        return x

    via_numba = df.apply(identity, engine="numba", axis=0)
    via_python = df.apply(identity, engine="python", axis=0)
    # Values and labels must agree; the Index class/dtype of the row and
    # column labels is deliberately not compared.
    tm.assert_frame_equal(
        via_numba,
        via_python,
        check_column_type=False,
        check_index_type=False,
    )
def test_numba_vs_python_indexing():
    """Label lookups inside the UDF agree between engines on both axes."""
    frame = DataFrame(
        {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},
        index=Index(["A", "B", "C"]),
    )
    # With axis=1 the UDF receives a row and selects a column label;
    # with axis=0 it receives a column and selects an index label.
    cases = [
        (lambda row: row["c"], 1),
        (lambda col: col["A"], 0),
    ]
    for udf, axis in cases:
        result = frame.apply(udf, engine="numba", axis=axis)
        expected = frame.apply(udf, engine="python", axis=axis)
        tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "reduction",
    [lambda x: x.mean(), lambda x: x.min(), lambda x: x.max(), lambda x: x.sum()],
)
def test_numba_vs_python_reductions(reduction, apply_axis):
    """Simple reductions give identical Series under both engines."""
    ones = np.ones((4, 4), dtype=np.float64)
    df = DataFrame(ones)
    by_engine = {
        engine: df.apply(reduction, engine=engine, axis=apply_axis)
        for engine in ("numba", "python")
    }
    tm.assert_series_equal(by_engine["numba"], by_engine["python"])
@pytest.mark.parametrize("colnames", [[1, 2, 3], [1.0, 2.0, 3.0]])
def test_numba_numeric_colnames(colnames):
    # Numeric (int or float) column labels must lower correctly under numba
    # and be usable for indexing inside the UDF.
    values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int64)
    df = DataFrame(values, columns=colnames)
    first_label = colnames[0]

    def pick_first(row):
        # Select the first column's value from each row.
        return row[first_label]

    result = df.apply(pick_first, engine="numba", axis=1)
    expected = df.apply(pick_first, engine="python", axis=1)
    tm.assert_series_equal(result, expected)
def test_numba_parallel_unsupported(float_frame):
    """``engine_kwargs={"parallel": True}`` is rejected when raw=False."""
    msg = "Parallel apply is not supported when raw=False and engine='numba'"
    with pytest.raises(NotImplementedError, match=msg):
        float_frame.apply(
            lambda x: x, engine="numba", engine_kwargs={"parallel": True}
        )
def test_numba_nonunique_unsupported(apply_axis):
    """A duplicated index label is rejected by the numba engine on either axis."""
    df = DataFrame({"a": [1, 2]}, index=Index(["a", "a"]))
    msg = "The index/columns must be unique when raw=False and engine='numba'"
    with pytest.raises(NotImplementedError, match=msg):
        df.apply(lambda x: x, engine="numba", axis=apply_axis)
def test_numba_unsupported_dtypes(apply_axis):
    """The numba engine rejects non-numeric and extension-array-backed columns."""
    pytest.importorskip("pyarrow")
    f = lambda x: x
    df = DataFrame({"a": [1, 2], "b": ["a", "b"], "c": [4, 5]})
    df["c"] = df["c"].astype("double[pyarrow]")

    # Column "b" holds strings, and its dtype may be reported as either
    # 'object' or 'str'. The alternation must be grouped: an ungrouped ``|``
    # splits the whole pattern in two, so the bare text "str' instead" would
    # satisfy the match on its own. Literal dots are escaped as well.
    with pytest.raises(
        ValueError,
        match=r"Column b must have a numeric dtype\. Found '(object|str)' instead",
    ):
        df.apply(f, engine="numba", axis=apply_axis)

    with pytest.raises(
        ValueError,
        match=r"Column c is backed by an extension array, "
        r"which is not supported by the numba engine\.",
    ):
        df["c"].to_frame().apply(f, engine="numba", axis=apply_axis)