@@ -59,9 +59,7 @@ def generate_numba_agg_func(
59
59
kwargs : dict [str , Any ],
60
60
func : Callable [..., Scalar ],
61
61
engine_kwargs : dict [str , bool ] | None ,
62
- ) -> Callable [
63
- [np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , int , Any ], np .ndarray
64
- ]:
62
+ ) -> Callable [[np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , Any ], np .ndarray ]:
65
63
"""
66
64
Generate a numba jitted agg function specified by values from engine_kwargs.
67
65
@@ -100,10 +98,13 @@ def group_agg(
100
98
index : np .ndarray ,
101
99
begin : np .ndarray ,
102
100
end : np .ndarray ,
103
- num_groups : int ,
104
101
num_columns : int ,
105
102
* args : Any ,
106
103
) -> np .ndarray :
104
+
105
+ assert len (begin ) == len (end )
106
+ num_groups = len (begin )
107
+
107
108
result = np .empty ((num_groups , num_columns ))
108
109
for i in numba .prange (num_groups ):
109
110
group_index = index [begin [i ] : end [i ]]
@@ -119,9 +120,7 @@ def generate_numba_transform_func(
119
120
kwargs : dict [str , Any ],
120
121
func : Callable [..., np .ndarray ],
121
122
engine_kwargs : dict [str , bool ] | None ,
122
- ) -> Callable [
123
- [np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , int , Any ], np .ndarray
124
- ]:
123
+ ) -> Callable [[np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , Any ], np .ndarray ]:
125
124
"""
126
125
Generate a numba jitted transform function specified by values from engine_kwargs.
127
126
@@ -160,10 +159,13 @@ def group_transform(
160
159
index : np .ndarray ,
161
160
begin : np .ndarray ,
162
161
end : np .ndarray ,
163
- num_groups : int ,
164
162
num_columns : int ,
165
163
* args : Any ,
166
164
) -> np .ndarray :
165
+
166
+ assert len (begin ) == len (end )
167
+ num_groups = len (begin )
168
+
167
169
result = np .empty ((len (values ), num_columns ))
168
170
for i in numba .prange (num_groups ):
169
171
group_index = index [begin [i ] : end [i ]]
0 commit comments