Skip to content

Commit 0de3224

Browse files
committed
add tests
1 parent aa91722 commit 0de3224

File tree

4 files changed

+60
-5
lines changed

4 files changed

+60
-5
lines changed

pandas/tests/apply/test_frame_apply.py

+10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ def test_apply_args(float_frame, axis, raw, engine, nopython):
9090
tm.assert_frame_equal(result, expected)
9191

9292
if engine == "numba":
93+
# py signature binding
94+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
95+
float_frame.apply(
96+
lambda x, a: x + a,
97+
b=2,
98+
raw=raw,
99+
engine=engine,
100+
engine_kwargs=engine_kwargs,
101+
)
102+
93103
# keyword-only arguments are not supported in numba
94104
with pytest.raises(
95105
pd.errors.NumbaUtilError,

pandas/tests/groupby/aggregate/test_numba.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,35 @@ def incorrect_function(x):
3535
def test_check_nopython_kwargs():
3636
pytest.importorskip("numba")
3737

38-
def incorrect_function(values, index):
39-
return sum(values) * 2.7
38+
def incorrect_function(values, index, *, a):
39+
return sum(values) * 2.7 + a
40+
41+
def correct_function(values, index, a):
42+
return sum(values) * 2.7 + a
4043

4144
data = DataFrame(
4245
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
4346
columns=["key", "data"],
4447
)
48+
# py signature binding
49+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
50+
data.groupby("key").agg(incorrect_function, engine="numba", b=1)
51+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
52+
data.groupby("key").agg(correct_function, engine="numba", b=1)
53+
54+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
55+
data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1)
56+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
57+
data.groupby("key")["data"].agg(correct_function, engine="numba", b=1)
58+
59+
# numba signature check after binding
4560
with pytest.raises(NumbaUtilError, match="numba does not support"):
4661
data.groupby("key").agg(incorrect_function, engine="numba", a=1)
62+
data.groupby("key").agg(correct_function, engine="numba", a=1)
4763

4864
with pytest.raises(NumbaUtilError, match="numba does not support"):
4965
data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
66+
data.groupby("key")["data"].agg(correct_function, engine="numba", a=1)
5067

5168

5269
@pytest.mark.filterwarnings("ignore")

pandas/tests/groupby/transform/test_numba.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,35 @@ def incorrect_function(x):
3333
def test_check_nopython_kwargs():
3434
pytest.importorskip("numba")
3535

36-
def incorrect_function(values, index):
37-
return values + 1
36+
def incorrect_function(values, index, *, a):
37+
return values + a
38+
39+
def correct_function(values, index, a):
40+
return values + a
3841

3942
data = DataFrame(
4043
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
4144
columns=["key", "data"],
4245
)
46+
# py signature binding
47+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
48+
data.groupby("key").transform(incorrect_function, engine="numba", b=1)
49+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
50+
data.groupby("key").transform(correct_function, engine="numba", b=1)
51+
52+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
53+
data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1)
54+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
55+
data.groupby("key")["data"].transform(correct_function, engine="numba", b=1)
56+
57+
# numba signature check after binding
4358
with pytest.raises(NumbaUtilError, match="numba does not support"):
4459
data.groupby("key").transform(incorrect_function, engine="numba", a=1)
60+
data.groupby("key").transform(correct_function, engine="numba", a=1)
4561

4662
with pytest.raises(NumbaUtilError, match="numba does not support"):
4763
data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1)
64+
data.groupby("key")["data"].transform(correct_function, engine="numba", a=1)
4865

4966

5067
@pytest.mark.filterwarnings("ignore")

pandas/tests/window/test_numba.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -319,13 +319,24 @@ def f(x):
319319

320320
@td.skip_if_no("numba")
321321
def test_invalid_kwargs_nopython():
322+
with pytest.raises(TypeError, match="got an unexpected keyword argument 'a'"):
323+
Series(range(1)).rolling(1).apply(
324+
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
325+
)
322326
with pytest.raises(
323327
NumbaUtilError, match="numba does not support keyword-only arguments"
324328
):
325329
Series(range(1)).rolling(1).apply(
326-
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
330+
lambda x, *, a: x, kwargs={"a": 1}, engine="numba", raw=True
327331
)
328332

333+
tm.assert_series_equal(
334+
Series(range(1), dtype=float) + 1,
335+
Series(range(1))
336+
.rolling(1)
337+
.apply(lambda x, a: (x + a).sum(), kwargs={"a": 1}, engine="numba", raw=True),
338+
)
339+
329340

330341
@td.skip_if_no("numba")
331342
@pytest.mark.slow

0 commit comments

Comments
 (0)