16
16
17
17
18
18
@functools.cache
def make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel):
    """
    Build a numba-jitted function that applies ``func`` along the first axis
    of ``values``.

    The result is memoized with ``functools.cache``, so a repeated call with
    the same arguments returns the previously built looper.

    Parameters
    ----------
    func : callable
        Kernel called once per ``values[i]``. It must return a pair
        ``(output, na_pos)``; ``output`` is stored in ``result[i]``.
    result_dtype : np.dtype
        dtype of the allocated result array; also forwarded to ``func``.
    is_grouped_kernel : bool
        If True, the looper has the signature
        ``(values, labels, ngroups, min_periods, *args)``.
        Otherwise it has the signature
        ``(values, start, end, min_periods, *args)``.
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    callable
        The jitted ``column_looper``. It returns ``(result, na_positions)``,
        where ``na_positions`` maps ``i`` to an array built from the
        ``na_pos`` that ``func`` returned for ``values[i]`` (only entries
        with a non-empty ``na_pos`` are stored).
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    # The two variants differ only in the arguments forwarded to ``func`` and
    # in how the second result dimension is sized (``ngroups`` vs
    # ``len(start)``).
    # NOTE(review): na_positions is written inside a prange loop in both
    # variants -- confirm this is safe when parallel=True.
    if is_grouped_kernel:

        @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
        def column_looper(
            values: np.ndarray,
            labels: np.ndarray,
            ngroups: int,
            min_periods: int,
            *args,
        ):
            result = np.empty((values.shape[0], ngroups), dtype=result_dtype)
            na_positions = {}
            for i in numba.prange(values.shape[0]):
                output, na_pos = func(
                    values[i], result_dtype, labels, ngroups, min_periods, *args
                )
                result[i] = output
                if len(na_pos) > 0:
                    na_positions[i] = np.array(na_pos)
            return result, na_positions

    else:

        @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
        def column_looper(
            values: np.ndarray,
            start: np.ndarray,
            end: np.ndarray,
            min_periods: int,
            *args,
        ):
            result = np.empty((values.shape[0], len(start)), dtype=result_dtype)
            na_positions = {}
            for i in numba.prange(values.shape[0]):
                output, na_pos = func(
                    values[i], result_dtype, start, end, min_periods, *args
                )
                result[i] = output
                if len(na_pos) > 0:
                    na_positions[i] = np.array(na_pos)
            return result, na_positions

    return column_looper
45
65
@@ -96,6 +116,7 @@ def column_looper(
96
116
def generate_shared_aggregator (
97
117
func : Callable [..., Scalar ],
98
118
dtype_mapping : dict [np .dtype , np .dtype ],
119
+ is_grouped_kernel : bool ,
99
120
nopython : bool ,
100
121
nogil : bool ,
101
122
parallel : bool ,
@@ -111,6 +132,11 @@ def generate_shared_aggregator(
111
132
dtype_mapping: dict or None
112
133
If not None, maps a dtype to a result dtype.
113
134
Otherwise, will fall back to default mapping.
135
+ is_grouped_kernel: bool
136
+ Whether func operates using the group labels (True)
137
+ or using starts/ends arrays
138
+
139
+ If True, the generated looper must be called with ``labels`` and ``ngroups`` instead of ``start`` and ``end``.
114
140
nopython : bool
115
141
nopython to be passed into numba.jit
116
142
nogil : bool
@@ -130,13 +156,18 @@ def generate_shared_aggregator(
130
156
# is less than min_periods
131
157
# Cannot do this in numba nopython mode
132
158
# (you'll run into type-unification error when you cast int -> float)
133
- def looper_wrapper (values , start , end , min_periods , ** kwargs ):
159
+ def looper_wrapper (values , start = None , end = None , labels = None , ngroups = None , min_periods = 0 , ** kwargs ):
134
160
result_dtype = dtype_mapping [values .dtype ]
135
- column_looper = make_looper (func , result_dtype , nopython , nogil , parallel )
161
+ column_looper = make_looper (func , result_dtype , is_grouped_kernel , nopython , nogil , parallel )
136
162
# Need to unpack kwargs since numba only supports *args
137
- result , na_positions = column_looper (
138
- values , start , end , min_periods , * kwargs .values ()
139
- )
163
+ if is_grouped_kernel :
164
+ result , na_positions = column_looper (
165
+ values , labels , ngroups , min_periods , * kwargs .values ()
166
+ )
167
+ else :
168
+ result , na_positions = column_looper (
169
+ values , start , end , min_periods , * kwargs .values ()
170
+ )
140
171
if result .dtype .kind == "i" :
141
172
# Look if na_positions is not empty
142
173
# If so, convert the whole block
0 commit comments