Skip to content

Commit 82252be

Browse files
committed
add tests
1 parent 0de3224 commit 82252be

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pandas/tests/groupby/aggregate/test_numba.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def correct_function(values, index, a):
4545
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
4646
columns=["key", "data"],
4747
)
48+
expected = data.groupby("key").sum() * 2.7
49+
4850
# py signature binding
4951
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
5052
data.groupby("key").agg(incorrect_function, engine="numba", b=1)
@@ -59,11 +61,13 @@ def correct_function(values, index, a):
5961
# numba signature check after binding
6062
with pytest.raises(NumbaUtilError, match="numba does not support"):
6163
data.groupby("key").agg(incorrect_function, engine="numba", a=1)
62-
data.groupby("key").agg(correct_function, engine="numba", a=1)
64+
actual = data.groupby("key").agg(correct_function, engine="numba", a=1)
65+
tm.assert_frame_equal(expected + 1, actual)
6366

6467
with pytest.raises(NumbaUtilError, match="numba does not support"):
6568
data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
66-
data.groupby("key")["data"].agg(correct_function, engine="numba", a=1)
69+
actual = data.groupby("key")["data"].agg(correct_function, engine="numba", a=1)
70+
tm.assert_series_equal(expected["data"] + 1, actual)
6771

6872

6973
@pytest.mark.filterwarnings("ignore")

pandas/tests/groupby/transform/test_numba.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,15 @@ def correct_function(values, index, a):
5757
# numba signature check after binding
5858
with pytest.raises(NumbaUtilError, match="numba does not support"):
5959
data.groupby("key").transform(incorrect_function, engine="numba", a=1)
60-
data.groupby("key").transform(correct_function, engine="numba", a=1)
60+
actual = data.groupby("key").transform(correct_function, engine="numba", a=1)
61+
tm.assert_frame_equal(data[["data"]] + 1, actual)
6162

6263
with pytest.raises(NumbaUtilError, match="numba does not support"):
6364
data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1)
64-
data.groupby("key")["data"].transform(correct_function, engine="numba", a=1)
65+
actual = data.groupby("key")["data"].transform(
66+
correct_function, engine="numba", a=1
67+
)
68+
tm.assert_series_equal(data["data"] + 1, actual)
6569

6670

6771
@pytest.mark.filterwarnings("ignore")

0 commit comments

Comments
 (0)