@@ -45,6 +45,8 @@ def correct_function(values, index, a):
45
45
{"key" : ["a" , "a" , "b" , "b" , "a" ], "data" : [1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]},
46
46
columns = ["key" , "data" ],
47
47
)
48
+ expected = data .groupby ("key" ).sum () * 2.7
49
+
48
50
# py signature binding
49
51
with pytest .raises (TypeError , match = "missing a required argument: 'a'" ):
50
52
data .groupby ("key" ).agg (incorrect_function , engine = "numba" , b = 1 )
@@ -59,11 +61,13 @@ def correct_function(values, index, a):
59
61
# numba signature check after binding
60
62
with pytest .raises (NumbaUtilError , match = "numba does not support" ):
61
63
data .groupby ("key" ).agg (incorrect_function , engine = "numba" , a = 1 )
62
- data .groupby ("key" ).agg (correct_function , engine = "numba" , a = 1 )
64
+ actual = data .groupby ("key" ).agg (correct_function , engine = "numba" , a = 1 )
65
+ tm .assert_frame_equal (expected + 1 , actual )
63
66
64
67
with pytest .raises (NumbaUtilError , match = "numba does not support" ):
65
68
data .groupby ("key" )["data" ].agg (incorrect_function , engine = "numba" , a = 1 )
66
- data .groupby ("key" )["data" ].agg (correct_function , engine = "numba" , a = 1 )
69
+ actual = data .groupby ("key" )["data" ].agg (correct_function , engine = "numba" , a = 1 )
70
+ tm .assert_series_equal (expected ["data" ] + 1 , actual )
67
71
68
72
69
73
@pytest .mark .filterwarnings ("ignore" )
0 commit comments