@@ -12,39 +12,27 @@ _int64_max = np.iinfo(np.int64).max
12
12
# group_nth, group_last, group_rank
13
13
# ----------------------------------------------------------------------
14
14
15
- {{py:
16
-
17
- # name, c_type, nan_val
18
- dtypes = [('float64', 'float64_t', 'NAN'),
19
- ('float32', 'float32_t', 'NAN'),
20
- ('int64', 'int64_t', 'NPY_NAT'),
21
- ('object', 'object', 'NAN')]
22
-
23
- def get_dispatch(dtypes):
24
-
25
- for name, c_type, nan_val in dtypes:
26
-
27
- yield name, c_type, nan_val
28
- }}
29
-
30
-
31
- {{for name, c_type, nan_val in get_dispatch(dtypes)}}
15
+ ctypedef fused rank_t:
16
+ float64_t
17
+ float32_t
18
+ int64_t
19
+ object
32
20
33
21
34
22
@cython.wraparound(False)
35
23
@cython.boundscheck(False)
36
- def group_last_{{name}}({{c_type}} [:, :] out,
37
- int64_t[:] counts,
38
- {{c_type}} [:, :] values,
39
- const int64_t[:] labels,
40
- Py_ssize_t min_count=-1):
24
+ def group_last(rank_t [:, :] out,
25
+ int64_t[:] counts,
26
+ rank_t [:, :] values,
27
+ const int64_t[:] labels,
28
+ Py_ssize_t min_count=-1):
41
29
"""
42
30
Only aggregates on axis=0
43
31
"""
44
32
cdef:
45
33
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
46
- {{c_type}} val
47
- ndarray[{{c_type}} , ndim=2] resx
34
+ rank_t val
35
+ ndarray[rank_t , ndim=2] resx
48
36
ndarray[int64_t, ndim=2] nobs
49
37
50
38
assert min_count == -1, "'min_count' only used in add and prod"
@@ -53,19 +41,15 @@ def group_last_{{name}}({{c_type}}[:, :] out,
53
41
raise AssertionError("len(index) != len(labels)")
54
42
55
43
nobs = np.zeros((<object>out).shape, dtype=np.int64)
56
- {{if name == 'object'}}
57
- resx = np.empty((<object>out).shape, dtype=object)
58
- {{else}}
59
- resx = np.empty_like(out)
60
- {{endif}}
44
+ if rank_t is object:
45
+ resx = np.empty((<object>out).shape, dtype=object)
46
+ else:
47
+ resx = np.empty_like(out)
61
48
62
49
N, K = (<object>values).shape
63
50
64
- {{if name == "object"}}
65
- if True: # make templating happy
66
- {{else}}
67
- with nogil:
68
- {{endif}}
51
+ if rank_t is object:
52
+ # TODO: De-duplicate once conditional-nogil is available
69
53
for i in range(N):
70
54
lab = labels[i]
71
55
if lab < 0:
@@ -76,36 +60,77 @@ def group_last_{{name}}({{c_type}}[:, :] out,
76
60
val = values[i, j]
77
61
78
62
# not nan
79
- if (
80
- {{if not name.startswith("int")}}
81
- val == val and
82
- {{endif}}
83
- val != {{nan_val}}):
84
- nobs[lab, j] += 1
85
- resx[lab, j] = val
63
+ if rank_t is int64_t:
64
+ # need a special notna check
65
+ if val != NPY_NAT:
66
+ nobs[lab, j] += 1
67
+ resx[lab, j] = val
68
+ else:
69
+ if val == val:
70
+ nobs[lab, j] += 1
71
+ resx[lab, j] = val
86
72
87
73
for i in range(ncounts):
88
74
for j in range(K):
89
75
if nobs[i, j] == 0:
90
- out[i, j] = {{nan_val}}
76
+ if rank_t is int64_t:
77
+ out[i, j] = NPY_NAT
78
+ else:
79
+ out[i, j] = NAN
91
80
else:
92
81
out[i, j] = resx[i, j]
82
+ else:
83
+ with nogil:
84
+ for i in range(N):
85
+ lab = labels[i]
86
+ if lab < 0:
87
+ continue
88
+
89
+ counts[lab] += 1
90
+ for j in range(K):
91
+ val = values[i, j]
92
+
93
+ # not nan
94
+ if rank_t is int64_t:
95
+ # need a special notna check
96
+ if val != NPY_NAT:
97
+ nobs[lab, j] += 1
98
+ resx[lab, j] = val
99
+ else:
100
+ if val == val:
101
+ nobs[lab, j] += 1
102
+ resx[lab, j] = val
103
+
104
+ for i in range(ncounts):
105
+ for j in range(K):
106
+ if nobs[i, j] == 0:
107
+ if rank_t is int64_t:
108
+ out[i, j] = NPY_NAT
109
+ else:
110
+ out[i, j] = NAN
111
+ else:
112
+ out[i, j] = resx[i, j]
113
+
114
+ group_last_float64 = group_last["float64_t"]
115
+ group_last_float32 = group_last["float32_t"]
116
+ group_last_int64 = group_last["int64_t"]
117
+ group_last_object = group_last["object"]
93
118
94
119
95
120
@cython.wraparound(False)
96
121
@cython.boundscheck(False)
97
- def group_nth_{{name}}({{c_type}} [:, :] out,
98
- int64_t[:] counts,
99
- {{c_type}} [:, :] values,
100
- const int64_t[:] labels, int64_t rank,
101
- Py_ssize_t min_count=-1):
122
+ def group_nth(rank_t [:, :] out,
123
+ int64_t[:] counts,
124
+ rank_t [:, :] values,
125
+ const int64_t[:] labels, int64_t rank,
126
+ Py_ssize_t min_count=-1):
102
127
"""
103
128
Only aggregates on axis=0
104
129
"""
105
130
cdef:
106
131
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
107
- {{c_type}} val
108
- ndarray[{{c_type}} , ndim=2] resx
132
+ rank_t val
133
+ ndarray[rank_t , ndim=2] resx
109
134
ndarray[int64_t, ndim=2] nobs
110
135
111
136
assert min_count == -1, "'min_count' only used in add and prod"
@@ -114,19 +139,15 @@ def group_nth_{{name}}({{c_type}}[:, :] out,
114
139
raise AssertionError("len(index) != len(labels)")
115
140
116
141
nobs = np.zeros((<object>out).shape, dtype=np.int64)
117
- {{if name=='object'}}
118
- resx = np.empty((<object>out).shape, dtype=object)
119
- {{else}}
120
- resx = np.empty_like(out)
121
- {{endif}}
142
+ if rank_t is object:
143
+ resx = np.empty((<object>out).shape, dtype=object)
144
+ else:
145
+ resx = np.empty_like(out)
122
146
123
147
N, K = (<object>values).shape
124
148
125
- {{if name == "object"}}
126
- if True: # make templating happy
127
- {{else}}
128
- with nogil:
129
- {{endif}}
149
+ if rank_t is object:
150
+ # TODO: De-duplicate once conditional-nogil is available
130
151
for i in range(N):
131
152
lab = labels[i]
132
153
if lab < 0:
@@ -137,40 +158,73 @@ def group_nth_{{name}}({{c_type}}[:, :] out,
137
158
val = values[i, j]
138
159
139
160
# not nan
140
- if (
141
- {{if not name.startswith("int")}}
142
- val == val and
143
- {{endif}}
144
- val != {{nan_val}}):
161
+ if val == val:
145
162
nobs[lab, j] += 1
146
163
if nobs[lab, j] == rank:
147
164
resx[lab, j] = val
148
165
149
166
for i in range(ncounts):
150
167
for j in range(K):
151
168
if nobs[i, j] == 0:
152
- out[i, j] = {{nan_val}}
169
+ out[i, j] = NAN
153
170
else:
154
171
out[i, j] = resx[i, j]
155
172
173
+ else:
174
+ with nogil:
175
+ for i in range(N):
176
+ lab = labels[i]
177
+ if lab < 0:
178
+ continue
179
+
180
+ counts[lab] += 1
181
+ for j in range(K):
182
+ val = values[i, j]
183
+
184
+ # not nan
185
+ if rank_t is int64_t:
186
+ # need a special notna check
187
+ if val != NPY_NAT:
188
+ nobs[lab, j] += 1
189
+ if nobs[lab, j] == rank:
190
+ resx[lab, j] = val
191
+ else:
192
+ if val == val:
193
+ nobs[lab, j] += 1
194
+ if nobs[lab, j] == rank:
195
+ resx[lab, j] = val
196
+
197
+ for i in range(ncounts):
198
+ for j in range(K):
199
+ if nobs[i, j] == 0:
200
+ if rank_t is int64_t:
201
+ out[i, j] = NPY_NAT
202
+ else:
203
+ out[i, j] = NAN
204
+ else:
205
+ out[i, j] = resx[i, j]
206
+
156
207
157
- {{if name != 'object'}}
208
+ group_nth_float64 = group_nth["float64_t"]
209
+ group_nth_float32 = group_nth["float32_t"]
210
+ group_nth_int64 = group_nth["int64_t"]
211
+ group_nth_object = group_nth["object"]
158
212
159
213
160
214
@cython.boundscheck(False)
161
215
@cython.wraparound(False)
162
- def group_rank_{{name}} (float64_t[:, :] out,
163
- {{c_type}} [:, :] values,
164
- const int64_t[:] labels,
165
- bint is_datetimelike, object ties_method,
166
- bint ascending, bint pct, object na_option):
216
+ def group_rank (float64_t[:, :] out,
217
+ rank_t [:, :] values,
218
+ const int64_t[:] labels,
219
+ bint is_datetimelike, object ties_method,
220
+ bint ascending, bint pct, object na_option):
167
221
"""
168
222
Provides the rank of values within each group.
169
223
170
224
Parameters
171
225
----------
172
226
out : array of float64_t values which this method will write its results to
173
- values : array of {{c_type}} values to be ranked
227
+ values : array of rank_t values to be ranked
174
228
labels : array containing unique label for each group, with its ordering
175
229
matching up to the corresponding record in `values`
176
230
is_datetimelike : bool, default False
@@ -203,10 +257,13 @@ def group_rank_{{name}}(float64_t[:, :] out,
203
257
Py_ssize_t grp_vals_seen=1, grp_na_count=0, grp_tie_count=0
204
258
ndarray[int64_t] _as
205
259
ndarray[float64_t, ndim=2] grp_sizes
206
- ndarray[{{c_type}} ] masked_vals
260
+ ndarray[rank_t ] masked_vals
207
261
ndarray[uint8_t] mask
208
262
bint keep_na
209
- {{c_type}} nan_fill_val
263
+ rank_t nan_fill_val
264
+
265
+ if rank_t is object:
266
+ raise NotImplementedError("Cant do nogil")
210
267
211
268
tiebreak = tiebreakers[ties_method]
212
269
keep_na = na_option == 'keep'
@@ -217,25 +274,23 @@ def group_rank_{{name}}(float64_t[:, :] out,
217
274
# with mask, without obfuscating location of missing data
218
275
# in values array
219
276
masked_vals = np.array(values[:, 0], copy=True)
220
- {{if name == 'int64'}}
221
- mask = (masked_vals == {{nan_val}}).astype(np.uint8)
222
- {{else}}
223
- mask = np.isnan(masked_vals).astype(np.uint8)
224
- {{endif}}
277
+ if rank_t is int64_t:
278
+ mask = (masked_vals == NPY_NAT).astype(np.uint8)
279
+ else:
280
+ mask = np.isnan(masked_vals).astype(np.uint8)
225
281
226
282
if ascending ^ (na_option == 'top'):
227
- {{if name == 'int64'}}
228
- nan_fill_val = np.iinfo(np.int64).max
229
- {{else}}
230
- nan_fill_val = np.inf
231
- {{endif}}
283
+ if rank_t is int64_t:
284
+ nan_fill_val = np.iinfo(np.int64).max
285
+ else:
286
+ nan_fill_val = np.inf
232
287
order = (masked_vals, mask, labels)
233
288
else:
234
- {{ if name == 'int64'}}
235
- nan_fill_val = np.iinfo(np.int64).min
236
- {{ else}}
237
- nan_fill_val = -np.inf
238
- {{endif}}
289
+ if rank_t is int64_t:
290
+ nan_fill_val = np.iinfo(np.int64).min
291
+ else:
292
+ nan_fill_val = -np.inf
293
+
239
294
order = (masked_vals, ~mask, labels)
240
295
np.putmask(masked_vals, mask, nan_fill_val)
241
296
@@ -337,8 +392,13 @@ def group_rank_{{name}}(float64_t[:, :] out,
337
392
out[i, 0] = NAN
338
393
elif grp_sizes[i, 0] != 0:
339
394
out[i, 0] = out[i, 0] / grp_sizes[i, 0]
340
- {{endif}}
341
- {{endfor}}
395
+
396
+
397
+ group_rank_float64 = group_rank["float64_t"]
398
+ group_rank_float32 = group_rank["float32_t"]
399
+ group_rank_int64 = group_rank["int64_t"]
400
+ # Note: we do not have a group_rank_object because that would require a
401
+ # not-nogil implementation, see GH#19560
342
402
343
403
344
404
# ----------------------------------------------------------------------
@@ -484,7 +544,8 @@ def group_cummin(groupby_t[:, :] out,
484
544
const int64_t[:] labels,
485
545
int ngroups,
486
546
bint is_datetimelike):
487
- """Cumulative minimum of columns of `values`, in row groups `labels`.
547
+ """
548
+ Cumulative minimum of columns of `values`, in row groups `labels`.
488
549
489
550
Parameters
490
551
----------
@@ -548,9 +609,10 @@ def group_cummin(groupby_t[:, :] out,
548
609
def group_cummax(groupby_t[:, :] out,
549
610
groupby_t[:, :] values,
550
611
const int64_t[:] labels,
551
- int ngroups,
612
+ int ngroups,
552
613
bint is_datetimelike):
553
- """Cumulative maximum of columns of `values`, in row groups `labels`.
614
+ """
615
+ Cumulative maximum of columns of `values`, in row groups `labels`.
554
616
555
617
Parameters
556
618
----------
0 commit comments