@@ -1127,18 +1127,40 @@ ctypedef fused groupby_t:
1127
1127
1128
1128
@ cython.wraparound (False )
1129
1129
@ cython.boundscheck (False )
1130
- def group_max (groupby_t[:, ::1] out ,
1131
- int64_t[::1] counts ,
1132
- ndarray[groupby_t , ndim = 2 ] values,
1133
- const int64_t[:] labels ,
1134
- Py_ssize_t min_count = - 1 ):
1130
+ cdef group_min_max(groupby_t[:, ::1 ] out,
1131
+ int64_t[::1 ] counts,
1132
+ ndarray[groupby_t, ndim= 2 ] values,
1133
+ const int64_t[:] labels,
1134
+ Py_ssize_t min_count = - 1 ,
1135
+ bint compute_max = True ):
1135
1136
"""
1136
- Only aggregates on axis=0
1137
+ Compute minimum/maximum of columns of `values`, in row groups `labels`.
1138
+
1139
+ Parameters
1140
+ ----------
1141
+ out : array
1142
+ Array to store result in.
1143
+ counts : int64 array
1144
+ Input as a zeroed array, populated by group sizes during algorithm
1145
+ values : array
1146
+ Values to find column-wise min/max of.
1147
+ labels : int64 array
1148
+ Labels to group by.
1149
+ min_count : Py_ssize_t, default -1
1150
+ The minimum number of non-NA group elements, NA result if threshold
1151
+ is not met
1152
+ compute_max : bint, default True
1153
+ True to compute group-wise max, False to compute min
1154
+
1155
+ Notes
1156
+ -----
1157
+ This method modifies the `out` parameter, rather than returning an object.
1158
+ `counts` is modified to hold group sizes
1137
1159
"""
1138
1160
cdef:
1139
- Py_ssize_t i, j, N, K, lab, ncounts = len (counts)
1140
- groupby_t val, count, nan_val
1141
- ndarray[groupby_t, ndim= 2 ] maxx
1161
+ Py_ssize_t i, j, N, K, lab, ngroups = len (counts)
1162
+ groupby_t val, nan_val
1163
+ ndarray[groupby_t, ndim= 2 ] group_min_or_max
1142
1164
bint runtime_error = False
1143
1165
int64_t[:, ::1 ] nobs
1144
1166
@@ -1150,18 +1172,17 @@ def group_max(groupby_t[:, ::1] out,
1150
1172
min_count = max (min_count, 1 )
1151
1173
nobs = np.zeros((< object > out).shape, dtype = np.int64)
1152
1174
1153
- maxx = np.empty_like(out)
1175
+ group_min_or_max = np.empty_like(out)
1154
1176
if groupby_t is int64_t:
1155
- # Note: evaluated at compile-time
1156
- maxx[:] = - _int64_max
1177
+ group_min_or_max[:] = - _int64_max if compute_max else _int64_max
1157
1178
nan_val = NPY_NAT
1158
1179
elif groupby_t is uint64_t:
1159
1180
# NB: We do not define nan_val because there is no such thing
1160
- # for uint64_t. We carefully avoid having to reference it in this
1161
- # case.
1162
- maxx [:] = 0
1181
+ # for uint64_t. We carefully avoid having to reference it in this
1182
+ # case.
1183
+ group_min_or_max [:] = 0 if compute_max else np.iinfo(np.uint64).max
1163
1184
else :
1164
- maxx [:] = - np.inf
1185
+ group_min_or_max [:] = - np.inf if compute_max else np.inf
1165
1186
nan_val = NAN
1166
1187
1167
1188
N, K = (< object > values).shape
@@ -1179,20 +1200,23 @@ def group_max(groupby_t[:, ::1] out,
1179
1200
if not _treat_as_na(val, True ):
1180
1201
# TODO: Sure we always want is_datetimelike=True?
1181
1202
nobs[lab, j] += 1
1182
- if val > maxx[lab, j]:
1183
- maxx[lab, j] = val
1203
+ if compute_max:
1204
+ if val > group_min_or_max[lab, j]:
1205
+ group_min_or_max[lab, j] = val
1206
+ else :
1207
+ if val < group_min_or_max[lab, j]:
1208
+ group_min_or_max[lab, j] = val
1184
1209
1185
- for i in range (ncounts ):
1210
+ for i in range (ngroups ):
1186
1211
for j in range (K):
1187
1212
if nobs[i, j] < min_count:
1188
1213
if groupby_t is uint64_t:
1189
1214
runtime_error = True
1190
1215
break
1191
1216
else :
1192
-
1193
1217
out[i, j] = nan_val
1194
1218
else :
1195
- out[i, j] = maxx [i, j]
1219
+ out[i, j] = group_min_or_max [i, j]
1196
1220
1197
1221
if runtime_error:
1198
1222
# We cannot raise directly above because that is within a nogil
@@ -1202,75 +1226,24 @@ def group_max(groupby_t[:, ::1] out,
1202
1226
1203
1227
@ cython.wraparound (False )
1204
1228
@ cython.boundscheck (False )
1205
- def group_min (groupby_t[:, ::1] out ,
1229
+ def group_max (groupby_t[:, ::1] out ,
1206
1230
int64_t[::1] counts ,
1207
1231
ndarray[groupby_t , ndim = 2 ] values,
1208
1232
const int64_t[:] labels ,
1209
1233
Py_ssize_t min_count = - 1 ):
1210
- """
1211
- Only aggregates on axis=0
1212
- """
1213
- cdef:
1214
- Py_ssize_t i, j, N, K, lab, ncounts = len (counts)
1215
- groupby_t val, count, nan_val
1216
- ndarray[groupby_t, ndim= 2 ] minx
1217
- bint runtime_error = False
1218
- int64_t[:, ::1 ] nobs
1219
-
1220
- # TODO(cython 3.0):
1221
- # Instead of `labels.shape[0]` use `len(labels)`
1222
- if not len (values) == labels.shape[0 ]:
1223
- raise AssertionError (" len(index) != len(labels)" )
1224
-
1225
- min_count = max (min_count, 1 )
1226
- nobs = np.zeros((< object > out).shape, dtype = np.int64)
1227
-
1228
- minx = np.empty_like(out)
1229
- if groupby_t is int64_t:
1230
- minx[:] = _int64_max
1231
- nan_val = NPY_NAT
1232
- elif groupby_t is uint64_t:
1233
- # NB: We do not define nan_val because there is no such thing
1234
- # for uint64_t. We carefully avoid having to reference it in this
1235
- # case.
1236
- minx[:] = np.iinfo(np.uint64).max
1237
- else :
1238
- minx[:] = np.inf
1239
- nan_val = NAN
1234
+ """ See group_min_max.__doc__"""
1235
+ group_min_max(out, counts, values, labels, min_count = min_count, compute_max = True )
1240
1236
1241
- N, K = (< object > values).shape
1242
1237
1243
- with nogil:
1244
- for i in range (N):
1245
- lab = labels[i]
1246
- if lab < 0 :
1247
- continue
1248
-
1249
- counts[lab] += 1
1250
- for j in range (K):
1251
- val = values[i, j]
1252
-
1253
- if not _treat_as_na(val, True ):
1254
- # TODO: Sure we always want is_datetimelike=True?
1255
- nobs[lab, j] += 1
1256
- if val < minx[lab, j]:
1257
- minx[lab, j] = val
1258
-
1259
- for i in range (ncounts):
1260
- for j in range (K):
1261
- if nobs[i, j] < min_count:
1262
- if groupby_t is uint64_t:
1263
- runtime_error = True
1264
- break
1265
- else :
1266
- out[i, j] = nan_val
1267
- else :
1268
- out[i, j] = minx[i, j]
1269
-
1270
- if runtime_error:
1271
- # We cannot raise directly above because that is within a nogil
1272
- # block.
1273
- raise RuntimeError (" empty group with uint64_t" )
1238
+ @ cython.wraparound (False )
1239
+ @ cython.boundscheck (False )
1240
+ def group_min (groupby_t[:, ::1] out ,
1241
+ int64_t[::1] counts ,
1242
+ ndarray[groupby_t , ndim = 2 ] values,
1243
+ const int64_t[:] labels ,
1244
+ Py_ssize_t min_count = - 1 ):
1245
+ """ See group_min_max.__doc__"""
1246
+ group_min_max(out, counts, values, labels, min_count = min_count, compute_max = False )
1274
1247
1275
1248
1276
1249
@ cython.boundscheck (False )
0 commit comments