Skip to content

Commit 4938568

Browse files
jbrockmendelproost
authored andcommitted
REF: use fused types for groupby_helper (pandas-dev#28886)
1 parent 12720b7 commit 4938568

File tree

1 file changed

+156
-94
lines changed

1 file changed

+156
-94
lines changed

pandas/_libs/groupby_helper.pxi.in

+156-94
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,27 @@ _int64_max = np.iinfo(np.int64).max
1212
# group_nth, group_last, group_rank
1313
# ----------------------------------------------------------------------
1414

15-
{{py:
16-
17-
# name, c_type, nan_val
18-
dtypes = [('float64', 'float64_t', 'NAN'),
19-
('float32', 'float32_t', 'NAN'),
20-
('int64', 'int64_t', 'NPY_NAT'),
21-
('object', 'object', 'NAN')]
22-
23-
def get_dispatch(dtypes):
24-
25-
for name, c_type, nan_val in dtypes:
26-
27-
yield name, c_type, nan_val
28-
}}
29-
30-
31-
{{for name, c_type, nan_val in get_dispatch(dtypes)}}
15+
ctypedef fused rank_t:
16+
float64_t
17+
float32_t
18+
int64_t
19+
object
3220

3321

3422
@cython.wraparound(False)
3523
@cython.boundscheck(False)
36-
def group_last_{{name}}({{c_type}}[:, :] out,
37-
int64_t[:] counts,
38-
{{c_type}}[:, :] values,
39-
const int64_t[:] labels,
40-
Py_ssize_t min_count=-1):
24+
def group_last(rank_t[:, :] out,
25+
int64_t[:] counts,
26+
rank_t[:, :] values,
27+
const int64_t[:] labels,
28+
Py_ssize_t min_count=-1):
4129
"""
4230
Only aggregates on axis=0
4331
"""
4432
cdef:
4533
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
46-
{{c_type}} val
47-
ndarray[{{c_type}}, ndim=2] resx
34+
rank_t val
35+
ndarray[rank_t, ndim=2] resx
4836
ndarray[int64_t, ndim=2] nobs
4937

5038
assert min_count == -1, "'min_count' only used in add and prod"
@@ -53,19 +41,15 @@ def group_last_{{name}}({{c_type}}[:, :] out,
5341
raise AssertionError("len(index) != len(labels)")
5442

5543
nobs = np.zeros((<object>out).shape, dtype=np.int64)
56-
{{if name == 'object'}}
57-
resx = np.empty((<object>out).shape, dtype=object)
58-
{{else}}
59-
resx = np.empty_like(out)
60-
{{endif}}
44+
if rank_t is object:
45+
resx = np.empty((<object>out).shape, dtype=object)
46+
else:
47+
resx = np.empty_like(out)
6148

6249
N, K = (<object>values).shape
6350

64-
{{if name == "object"}}
65-
if True: # make templating happy
66-
{{else}}
67-
with nogil:
68-
{{endif}}
51+
if rank_t is object:
52+
# TODO: De-duplicate once conditional-nogil is available
6953
for i in range(N):
7054
lab = labels[i]
7155
if lab < 0:
@@ -76,36 +60,77 @@ def group_last_{{name}}({{c_type}}[:, :] out,
7660
val = values[i, j]
7761

7862
# not nan
79-
if (
80-
{{if not name.startswith("int")}}
81-
val == val and
82-
{{endif}}
83-
val != {{nan_val}}):
84-
nobs[lab, j] += 1
85-
resx[lab, j] = val
63+
if rank_t is int64_t:
64+
# need a special notna check
65+
if val != NPY_NAT:
66+
nobs[lab, j] += 1
67+
resx[lab, j] = val
68+
else:
69+
if val == val:
70+
nobs[lab, j] += 1
71+
resx[lab, j] = val
8672

8773
for i in range(ncounts):
8874
for j in range(K):
8975
if nobs[i, j] == 0:
90-
out[i, j] = {{nan_val}}
76+
if rank_t is int64_t:
77+
out[i, j] = NPY_NAT
78+
else:
79+
out[i, j] = NAN
9180
else:
9281
out[i, j] = resx[i, j]
82+
else:
83+
with nogil:
84+
for i in range(N):
85+
lab = labels[i]
86+
if lab < 0:
87+
continue
88+
89+
counts[lab] += 1
90+
for j in range(K):
91+
val = values[i, j]
92+
93+
# not nan
94+
if rank_t is int64_t:
95+
# need a special notna check
96+
if val != NPY_NAT:
97+
nobs[lab, j] += 1
98+
resx[lab, j] = val
99+
else:
100+
if val == val:
101+
nobs[lab, j] += 1
102+
resx[lab, j] = val
103+
104+
for i in range(ncounts):
105+
for j in range(K):
106+
if nobs[i, j] == 0:
107+
if rank_t is int64_t:
108+
out[i, j] = NPY_NAT
109+
else:
110+
out[i, j] = NAN
111+
else:
112+
out[i, j] = resx[i, j]
113+
114+
group_last_float64 = group_last["float64_t"]
115+
group_last_float32 = group_last["float32_t"]
116+
group_last_int64 = group_last["int64_t"]
117+
group_last_object = group_last["object"]
93118

94119

95120
@cython.wraparound(False)
96121
@cython.boundscheck(False)
97-
def group_nth_{{name}}({{c_type}}[:, :] out,
98-
int64_t[:] counts,
99-
{{c_type}}[:, :] values,
100-
const int64_t[:] labels, int64_t rank,
101-
Py_ssize_t min_count=-1):
122+
def group_nth(rank_t[:, :] out,
123+
int64_t[:] counts,
124+
rank_t[:, :] values,
125+
const int64_t[:] labels, int64_t rank,
126+
Py_ssize_t min_count=-1):
102127
"""
103128
Only aggregates on axis=0
104129
"""
105130
cdef:
106131
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
107-
{{c_type}} val
108-
ndarray[{{c_type}}, ndim=2] resx
132+
rank_t val
133+
ndarray[rank_t, ndim=2] resx
109134
ndarray[int64_t, ndim=2] nobs
110135

111136
assert min_count == -1, "'min_count' only used in add and prod"
@@ -114,19 +139,15 @@ def group_nth_{{name}}({{c_type}}[:, :] out,
114139
raise AssertionError("len(index) != len(labels)")
115140

116141
nobs = np.zeros((<object>out).shape, dtype=np.int64)
117-
{{if name=='object'}}
118-
resx = np.empty((<object>out).shape, dtype=object)
119-
{{else}}
120-
resx = np.empty_like(out)
121-
{{endif}}
142+
if rank_t is object:
143+
resx = np.empty((<object>out).shape, dtype=object)
144+
else:
145+
resx = np.empty_like(out)
122146

123147
N, K = (<object>values).shape
124148

125-
{{if name == "object"}}
126-
if True: # make templating happy
127-
{{else}}
128-
with nogil:
129-
{{endif}}
149+
if rank_t is object:
150+
# TODO: De-duplicate once conditional-nogil is available
130151
for i in range(N):
131152
lab = labels[i]
132153
if lab < 0:
@@ -137,40 +158,73 @@ def group_nth_{{name}}({{c_type}}[:, :] out,
137158
val = values[i, j]
138159

139160
# not nan
140-
if (
141-
{{if not name.startswith("int")}}
142-
val == val and
143-
{{endif}}
144-
val != {{nan_val}}):
161+
if val == val:
145162
nobs[lab, j] += 1
146163
if nobs[lab, j] == rank:
147164
resx[lab, j] = val
148165

149166
for i in range(ncounts):
150167
for j in range(K):
151168
if nobs[i, j] == 0:
152-
out[i, j] = {{nan_val}}
169+
out[i, j] = NAN
153170
else:
154171
out[i, j] = resx[i, j]
155172

173+
else:
174+
with nogil:
175+
for i in range(N):
176+
lab = labels[i]
177+
if lab < 0:
178+
continue
179+
180+
counts[lab] += 1
181+
for j in range(K):
182+
val = values[i, j]
183+
184+
# not nan
185+
if rank_t is int64_t:
186+
# need a special notna check
187+
if val != NPY_NAT:
188+
nobs[lab, j] += 1
189+
if nobs[lab, j] == rank:
190+
resx[lab, j] = val
191+
else:
192+
if val == val:
193+
nobs[lab, j] += 1
194+
if nobs[lab, j] == rank:
195+
resx[lab, j] = val
196+
197+
for i in range(ncounts):
198+
for j in range(K):
199+
if nobs[i, j] == 0:
200+
if rank_t is int64_t:
201+
out[i, j] = NPY_NAT
202+
else:
203+
out[i, j] = NAN
204+
else:
205+
out[i, j] = resx[i, j]
206+
156207

157-
{{if name != 'object'}}
208+
group_nth_float64 = group_nth["float64_t"]
209+
group_nth_float32 = group_nth["float32_t"]
210+
group_nth_int64 = group_nth["int64_t"]
211+
group_nth_object = group_nth["object"]
158212

159213

160214
@cython.boundscheck(False)
161215
@cython.wraparound(False)
162-
def group_rank_{{name}}(float64_t[:, :] out,
163-
{{c_type}}[:, :] values,
164-
const int64_t[:] labels,
165-
bint is_datetimelike, object ties_method,
166-
bint ascending, bint pct, object na_option):
216+
def group_rank(float64_t[:, :] out,
217+
rank_t[:, :] values,
218+
const int64_t[:] labels,
219+
bint is_datetimelike, object ties_method,
220+
bint ascending, bint pct, object na_option):
167221
"""
168222
Provides the rank of values within each group.
169223

170224
Parameters
171225
----------
172226
out : array of float64_t values which this method will write its results to
173-
values : array of {{c_type}} values to be ranked
227+
values : array of rank_t values to be ranked
174228
labels : array containing unique label for each group, with its ordering
175229
matching up to the corresponding record in `values`
176230
is_datetimelike : bool, default False
@@ -203,10 +257,13 @@ def group_rank_{{name}}(float64_t[:, :] out,
203257
Py_ssize_t grp_vals_seen=1, grp_na_count=0, grp_tie_count=0
204258
ndarray[int64_t] _as
205259
ndarray[float64_t, ndim=2] grp_sizes
206-
ndarray[{{c_type}}] masked_vals
260+
ndarray[rank_t] masked_vals
207261
ndarray[uint8_t] mask
208262
bint keep_na
209-
{{c_type}} nan_fill_val
263+
rank_t nan_fill_val
264+
265+
if rank_t is object:
266+
raise NotImplementedError("Cant do nogil")
210267

211268
tiebreak = tiebreakers[ties_method]
212269
keep_na = na_option == 'keep'
@@ -217,25 +274,23 @@ def group_rank_{{name}}(float64_t[:, :] out,
217274
# with mask, without obfuscating location of missing data
218275
# in values array
219276
masked_vals = np.array(values[:, 0], copy=True)
220-
{{if name == 'int64'}}
221-
mask = (masked_vals == {{nan_val}}).astype(np.uint8)
222-
{{else}}
223-
mask = np.isnan(masked_vals).astype(np.uint8)
224-
{{endif}}
277+
if rank_t is int64_t:
278+
mask = (masked_vals == NPY_NAT).astype(np.uint8)
279+
else:
280+
mask = np.isnan(masked_vals).astype(np.uint8)
225281

226282
if ascending ^ (na_option == 'top'):
227-
{{if name == 'int64'}}
228-
nan_fill_val = np.iinfo(np.int64).max
229-
{{else}}
230-
nan_fill_val = np.inf
231-
{{endif}}
283+
if rank_t is int64_t:
284+
nan_fill_val = np.iinfo(np.int64).max
285+
else:
286+
nan_fill_val = np.inf
232287
order = (masked_vals, mask, labels)
233288
else:
234-
{{if name == 'int64'}}
235-
nan_fill_val = np.iinfo(np.int64).min
236-
{{else}}
237-
nan_fill_val = -np.inf
238-
{{endif}}
289+
if rank_t is int64_t:
290+
nan_fill_val = np.iinfo(np.int64).min
291+
else:
292+
nan_fill_val = -np.inf
293+
239294
order = (masked_vals, ~mask, labels)
240295
np.putmask(masked_vals, mask, nan_fill_val)
241296

@@ -337,8 +392,13 @@ def group_rank_{{name}}(float64_t[:, :] out,
337392
out[i, 0] = NAN
338393
elif grp_sizes[i, 0] != 0:
339394
out[i, 0] = out[i, 0] / grp_sizes[i, 0]
340-
{{endif}}
341-
{{endfor}}
395+
396+
397+
group_rank_float64 = group_rank["float64_t"]
398+
group_rank_float32 = group_rank["float32_t"]
399+
group_rank_int64 = group_rank["int64_t"]
400+
# Note: we do not have a group_rank_object because that would require a
401+
# not-nogil implementation, see GH#19560
342402

343403

344404
# ----------------------------------------------------------------------
@@ -484,7 +544,8 @@ def group_cummin(groupby_t[:, :] out,
484544
const int64_t[:] labels,
485545
int ngroups,
486546
bint is_datetimelike):
487-
"""Cumulative minimum of columns of `values`, in row groups `labels`.
547+
"""
548+
Cumulative minimum of columns of `values`, in row groups `labels`.
488549

489550
Parameters
490551
----------
@@ -548,9 +609,10 @@ def group_cummin(groupby_t[:, :] out,
548609
def group_cummax(groupby_t[:, :] out,
549610
groupby_t[:, :] values,
550611
const int64_t[:] labels,
551-
int ngroups,
612+
int ngroups,
552613
bint is_datetimelike):
553-
"""Cumulative maximum of columns of `values`, in row groups `labels`.
614+
"""
615+
Cumulative maximum of columns of `values`, in row groups `labels`.
554616

555617
Parameters
556618
----------

0 commit comments

Comments
 (0)