@@ -1177,11 +1177,13 @@ def group_quantile(
1177
1177
ndarray[float64_t , ndim = 2 ] out,
1178
1178
ndarray[numeric_t , ndim = 1 ] values,
1179
1179
ndarray[intp_t] labels ,
1180
- ndarray[uint8_t] mask ,
1181
- const intp_t[:] sort_indexer ,
1180
+ const uint8_t[:] mask ,
1182
1181
const float64_t[:] qs ,
1182
+ ndarray[int64_t] starts ,
1183
+ ndarray[int64_t] ends ,
1183
1184
str interpolation ,
1184
- uint8_t[:, ::1] result_mask = None ,
1185
+ uint8_t[:, ::1] result_mask ,
1186
+ bint is_datetimelike ,
1185
1187
) -> None:
1186
1188
"""
1187
1189
Calculate the quantile per group.
@@ -1194,27 +1196,38 @@ def group_quantile(
1194
1196
Array containing the values to apply the function against.
1195
1197
labels : ndarray[np.intp]
1196
1198
Array containing the unique group labels.
1197
- sort_indexer : ndarray[np.intp]
1198
- Indices describing sort order by values and labels.
1199
1199
qs : ndarray[float64_t]
1200
1200
The quantile values to search for.
1201
+ starts : ndarray[int64]
1202
+ Positions at which each group begins.
1203
+ ends : ndarray[int64]
1204
+ Positions at which each group ends.
1201
1205
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}
1206
+ result_mask : ndarray[bool , ndim = 2 ] or None
1207
+ is_datetimelike : bool
1208
+ Whether int64 values represent datetime64-like values.
1202
1209
1203
1210
Notes
1204
1211
-----
1205
1212
Rather than explicitly returning a value , this function modifies the
1206
1213
provided `out` parameter.
1207
1214
"""
1208
1215
cdef:
1209
- Py_ssize_t i , N = len (labels), ngroups , grp_sz , non_na_sz , k , nqs
1210
- Py_ssize_t grp_start = 0 , idx = 0
1211
- intp_t lab
1216
+ Py_ssize_t i , N = len (labels), ngroups , non_na_sz , k , nqs
1217
+ Py_ssize_t idx = 0
1218
+ Py_ssize_t grp_size
1212
1219
InterpolationEnumType interp
1213
1220
float64_t q_val , q_idx , frac , val , next_val
1214
- int64_t[::1] counts , non_na_counts
1215
1221
bint uses_result_mask = result_mask is not None
1222
+ Py_ssize_t start , end
1223
+ ndarray[numeric_t] grp
1224
+ intp_t[::1] sort_indexer
1225
+ const uint8_t[:] sub_mask
1216
1226
1217
1227
assert values.shape[0] == N
1228
+ assert starts is not None
1229
+ assert ends is not None
1230
+ assert len(starts ) == len(ends )
1218
1231
1219
1232
if any(not (0 <= q <= 1) for q in qs ):
1220
1233
wrong = [x for x in qs if not (0 <= x <= 1 )][0 ]
@@ -1233,64 +1246,65 @@ def group_quantile(
1233
1246
1234
1247
nqs = len (qs)
1235
1248
ngroups = len (out)
1236
- counts = np.zeros(ngroups, dtype = np.int64)
1237
- non_na_counts = np.zeros(ngroups, dtype = np.int64)
1238
-
1239
- # First figure out the size of every group
1240
- with nogil:
1241
- for i in range (N):
1242
- lab = labels[i]
1243
- if lab == - 1 : # NA group label
1244
- continue
1245
1249
1246
- counts[lab] += 1
1247
- if not mask[i]:
1248
- non_na_counts[lab] += 1
1250
+ # TODO: get cnp.PyArray_ArgSort to work with nogil so we can restore the rest
1251
+ # of this function as being `with nogil:`
1252
+ for i in range (ngroups):
1253
+ start = starts[i]
1254
+ end = ends[i]
1255
+
1256
+ grp = values[start:end]
1257
+
1258
+ # Figure out how many group elements there are
1259
+ sub_mask = mask[start:end]
1260
+ grp_size = sub_mask.size
1261
+ non_na_sz = 0
1262
+ for k in range (grp_size):
1263
+ if sub_mask[k] == 0 :
1264
+ non_na_sz += 1
1265
+
1266
+ # equiv: sort_indexer = grp.argsort()
1267
+ if is_datetimelike:
1268
+ # We need the argsort to put NaTs at the end, not the beginning
1269
+ sort_indexer = cnp.PyArray_ArgSort(grp.view(" M8[ns]" ), 0 , cnp.NPY_QUICKSORT)
1270
+ else :
1271
+ sort_indexer = cnp.PyArray_ArgSort(grp, 0 , cnp.NPY_QUICKSORT)
1249
1272
1250
- with nogil:
1251
- for i in range (ngroups):
1252
- # Figure out how many group elements there are
1253
- grp_sz = counts[i]
1254
- non_na_sz = non_na_counts[i]
1255
-
1256
- if non_na_sz == 0 :
1257
- for k in range (nqs):
1258
- if uses_result_mask:
1259
- result_mask[i, k] = 1
1260
- else :
1261
- out[i, k] = NaN
1262
- else :
1263
- for k in range (nqs):
1264
- q_val = qs[k]
1273
+ if non_na_sz == 0 :
1274
+ for k in range (nqs):
1275
+ if uses_result_mask:
1276
+ result_mask[i, k] = 1
1277
+ else :
1278
+ out[i, k] = NaN
1279
+ else :
1280
+ for k in range (nqs):
1281
+ q_val = qs[k]
1265
1282
1266
- # Calculate where to retrieve the desired value
1267
- # Casting to int will intentionally truncate result
1268
- idx = grp_start + < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
1283
+ # Calculate where to retrieve the desired value
1284
+ # Casting to int will intentionally truncate result
1285
+ idx = < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
1269
1286
1270
- val = values [sort_indexer[idx]]
1271
- # If requested quantile falls evenly on a particular index
1272
- # then write that index's value out. Otherwise interpolate
1273
- q_idx = q_val * (non_na_sz - 1 )
1274
- frac = q_idx % 1
1287
+ val = grp [sort_indexer[idx]]
1288
+ # If requested quantile falls evenly on a particular index
1289
+ # then write that index's value out. Otherwise interpolate
1290
+ q_idx = q_val * (non_na_sz - 1 )
1291
+ frac = q_idx % 1
1275
1292
1276
- if frac == 0.0 or interp == INTERPOLATION_LOWER:
1277
- out[i, k] = val
1278
- else :
1279
- next_val = values[sort_indexer[idx + 1 ]]
1280
- if interp == INTERPOLATION_LINEAR:
1281
- out[i, k] = val + (next_val - val) * frac
1282
- elif interp == INTERPOLATION_HIGHER:
1293
+ if frac == 0.0 or interp == INTERPOLATION_LOWER:
1294
+ out[i, k] = val
1295
+ else :
1296
+ next_val = grp[sort_indexer[idx + 1 ]]
1297
+ if interp == INTERPOLATION_LINEAR:
1298
+ out[i, k] = val + (next_val - val) * frac
1299
+ elif interp == INTERPOLATION_HIGHER:
1300
+ out[i, k] = next_val
1301
+ elif interp == INTERPOLATION_MIDPOINT:
1302
+ out[i, k] = (val + next_val) / 2.0
1303
+ elif interp == INTERPOLATION_NEAREST:
1304
+ if frac > .5 or (frac == .5 and q_val > .5 ): # Always OK?
1283
1305
out[i, k] = next_val
1284
- elif interp == INTERPOLATION_MIDPOINT:
1285
- out[i, k] = (val + next_val) / 2.0
1286
- elif interp == INTERPOLATION_NEAREST:
1287
- if frac > .5 or (frac == .5 and q_val > .5 ): # Always OK?
1288
- out[i, k] = next_val
1289
- else :
1290
- out[i, k] = val
1291
-
1292
- # Increment the index reference in sorted_arr for the next group
1293
- grp_start += grp_sz
1306
+ else :
1307
+ out[i, k] = val
1294
1308
1295
1309
1296
1310
# ----------------------------------------------------------------------
0 commit comments