Skip to content

Commit e8961f1

Browse files
authored
PERF: GroupBy.quantile (#51722)
1 parent 2bb3557 commit e8961f1

File tree

4 files changed

+114
-96
lines changed

4 files changed

+114
-96
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ Other enhancements
176176
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
177177
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
178178
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype objects (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
179+
- Performance improvement in :meth:`GroupBy.quantile` (:issue:`51722`)
179180
-
180181

181182
.. ---------------------------------------------------------------------------

pandas/_libs/groupby.pyi

+4-2
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,12 @@ def group_quantile(
121121
values: np.ndarray, # ndarray[numeric, ndim=1]
122122
labels: npt.NDArray[np.intp],
123123
mask: npt.NDArray[np.uint8],
124-
sort_indexer: npt.NDArray[np.intp], # const
125124
qs: npt.NDArray[np.float64], # const
125+
starts: npt.NDArray[np.int64],
126+
ends: npt.NDArray[np.int64],
126127
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
127-
result_mask: np.ndarray | None = ...,
128+
result_mask: np.ndarray | None,
129+
is_datetimelike: bool,
128130
) -> None: ...
129131
def group_last(
130132
out: np.ndarray, # rank_t[:, ::1]

pandas/_libs/groupby.pyx

+75-61
Original file line numberDiff line numberDiff line change
@@ -1177,11 +1177,13 @@ def group_quantile(
11771177
ndarray[float64_t, ndim=2] out,
11781178
ndarray[numeric_t, ndim=1] values,
11791179
ndarray[intp_t] labels,
1180-
ndarray[uint8_t] mask,
1181-
const intp_t[:] sort_indexer,
1180+
const uint8_t[:] mask,
11821181
const float64_t[:] qs,
1182+
ndarray[int64_t] starts,
1183+
ndarray[int64_t] ends,
11831184
str interpolation,
1184-
uint8_t[:, ::1] result_mask=None,
1185+
uint8_t[:, ::1] result_mask,
1186+
bint is_datetimelike,
11851187
) -> None:
11861188
"""
11871189
Calculate the quantile per group.
@@ -1194,27 +1196,38 @@ def group_quantile(
11941196
Array containing the values to apply the function against.
11951197
labels : ndarray[np.intp]
11961198
Array containing the unique group labels.
1197-
sort_indexer : ndarray[np.intp]
1198-
Indices describing sort order by values and labels.
11991199
qs : ndarray[float64_t]
12001200
The quantile values to search for.
1201+
starts : ndarray[int64]
1202+
Positions at which each group begins.
1203+
ends : ndarray[int64]
1204+
Positions at which each group ends.
12011205
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}
1206+
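result_mask : ndarray[bool, ndim=2] or None
If not None, entries are set to 1 for groups that have no non-NA values, instead of writing NaN into `out`.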
1207+
is_datetimelike : bool
1208+
Whether int64 values represent datetime64-like values.
12021209

12031210
Notes
12041211
-----
12051212
Rather than explicitly returning a value, this function modifies the
12061213
provided `out` parameter.
12071214
"""
12081215
cdef:
1209-
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
1210-
Py_ssize_t grp_start=0, idx=0
1211-
intp_t lab
1216+
Py_ssize_t i, N=len(labels), ngroups, non_na_sz, k, nqs
1217+
Py_ssize_t idx=0
1218+
Py_ssize_t grp_size
12121219
InterpolationEnumType interp
12131220
float64_t q_val, q_idx, frac, val, next_val
1214-
int64_t[::1] counts, non_na_counts
12151221
bint uses_result_mask = result_mask is not None
1222+
Py_ssize_t start, end
1223+
ndarray[numeric_t] grp
1224+
intp_t[::1] sort_indexer
1225+
const uint8_t[:] sub_mask
12161226

12171227
assert values.shape[0] == N
1228+
assert starts is not None
1229+
assert ends is not None
1230+
assert len(starts) == len(ends)
12181231

12191232
if any(not (0 <= q <= 1) for q in qs):
12201233
wrong = [x for x in qs if not (0 <= x <= 1)][0]
@@ -1233,64 +1246,65 @@ def group_quantile(
12331246

12341247
nqs = len(qs)
12351248
ngroups = len(out)
1236-
counts = np.zeros(ngroups, dtype=np.int64)
1237-
non_na_counts = np.zeros(ngroups, dtype=np.int64)
1238-
1239-
# First figure out the size of every group
1240-
with nogil:
1241-
for i in range(N):
1242-
lab = labels[i]
1243-
if lab == -1: # NA group label
1244-
continue
12451249

1246-
counts[lab] += 1
1247-
if not mask[i]:
1248-
non_na_counts[lab] += 1
1250+
# TODO: get cnp.PyArray_ArgSort to work with nogil so we can restore the rest
1251+
# of this function as being `with nogil:`
1252+
for i in range(ngroups):
1253+
start = starts[i]
1254+
end = ends[i]
1255+
1256+
grp = values[start:end]
1257+
1258+
# Figure out how many group elements there are
1259+
sub_mask = mask[start:end]
1260+
grp_size = sub_mask.size
1261+
non_na_sz = 0
1262+
for k in range(grp_size):
1263+
if sub_mask[k] == 0:
1264+
non_na_sz += 1
1265+
1266+
# equiv: sort_indexer = grp.argsort()
1267+
if is_datetimelike:
1268+
# We need the argsort to put NaTs at the end, not the beginning
1269+
sort_indexer = cnp.PyArray_ArgSort(grp.view("M8[ns]"), 0, cnp.NPY_QUICKSORT)
1270+
else:
1271+
sort_indexer = cnp.PyArray_ArgSort(grp, 0, cnp.NPY_QUICKSORT)
12491272

1250-
with nogil:
1251-
for i in range(ngroups):
1252-
# Figure out how many group elements there are
1253-
grp_sz = counts[i]
1254-
non_na_sz = non_na_counts[i]
1255-
1256-
if non_na_sz == 0:
1257-
for k in range(nqs):
1258-
if uses_result_mask:
1259-
result_mask[i, k] = 1
1260-
else:
1261-
out[i, k] = NaN
1262-
else:
1263-
for k in range(nqs):
1264-
q_val = qs[k]
1273+
if non_na_sz == 0:
1274+
for k in range(nqs):
1275+
if uses_result_mask:
1276+
result_mask[i, k] = 1
1277+
else:
1278+
out[i, k] = NaN
1279+
else:
1280+
for k in range(nqs):
1281+
q_val = qs[k]
12651282

1266-
# Calculate where to retrieve the desired value
1267-
# Casting to int will intentionally truncate result
1268-
idx = grp_start + <int64_t>(q_val * <float64_t>(non_na_sz - 1))
1283+
# Calculate where to retrieve the desired value
1284+
# Casting to int will intentionally truncate result
1285+
idx = <int64_t>(q_val * <float64_t>(non_na_sz - 1))
12691286

1270-
val = values[sort_indexer[idx]]
1271-
# If requested quantile falls evenly on a particular index
1272-
# then write that index's value out. Otherwise interpolate
1273-
q_idx = q_val * (non_na_sz - 1)
1274-
frac = q_idx % 1
1287+
val = grp[sort_indexer[idx]]
1288+
# If requested quantile falls evenly on a particular index
1289+
# then write that index's value out. Otherwise interpolate
1290+
q_idx = q_val * (non_na_sz - 1)
1291+
frac = q_idx % 1
12751292

1276-
if frac == 0.0 or interp == INTERPOLATION_LOWER:
1277-
out[i, k] = val
1278-
else:
1279-
next_val = values[sort_indexer[idx + 1]]
1280-
if interp == INTERPOLATION_LINEAR:
1281-
out[i, k] = val + (next_val - val) * frac
1282-
elif interp == INTERPOLATION_HIGHER:
1293+
if frac == 0.0 or interp == INTERPOLATION_LOWER:
1294+
out[i, k] = val
1295+
else:
1296+
next_val = grp[sort_indexer[idx + 1]]
1297+
if interp == INTERPOLATION_LINEAR:
1298+
out[i, k] = val + (next_val - val) * frac
1299+
elif interp == INTERPOLATION_HIGHER:
1300+
out[i, k] = next_val
1301+
elif interp == INTERPOLATION_MIDPOINT:
1302+
out[i, k] = (val + next_val) / 2.0
1303+
elif interp == INTERPOLATION_NEAREST:
1304+
if frac > .5 or (frac == .5 and q_val > .5): # Always OK?
12831305
out[i, k] = next_val
1284-
elif interp == INTERPOLATION_MIDPOINT:
1285-
out[i, k] = (val + next_val) / 2.0
1286-
elif interp == INTERPOLATION_NEAREST:
1287-
if frac > .5 or (frac == .5 and q_val > .5): # Always OK?
1288-
out[i, k] = next_val
1289-
else:
1290-
out[i, k] = val
1291-
1292-
# Increment the index reference in sorted_arr for the next group
1293-
grp_start += grp_sz
1306+
else:
1307+
out[i, k] = val
12941308

12951309

12961310
# ----------------------------------------------------------------------

pandas/core/groupby/groupby.py

+34-33
Original file line numberDiff line numberDiff line change
@@ -4226,6 +4226,16 @@ def quantile(
42264226
a 2.0
42274227
b 3.0
42284228
"""
4229+
mgr = self._get_data_to_aggregate(numeric_only=numeric_only, name="quantile")
4230+
obj = self._wrap_agged_manager(mgr)
4231+
if self.axis == 1:
4232+
splitter = self.grouper._get_splitter(obj.T, axis=self.axis)
4233+
sdata = splitter._sorted_data.T
4234+
else:
4235+
splitter = self.grouper._get_splitter(obj, axis=self.axis)
4236+
sdata = splitter._sorted_data
4237+
4238+
starts, ends = lib.generate_slices(splitter._slabels, splitter.ngroups)
42294239

42304240
def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, DtypeObj | None]:
42314241
if is_object_dtype(vals.dtype):
@@ -4323,24 +4333,24 @@ def post_processor(
43234333

43244334
return vals
43254335

4326-
orig_scalar = is_scalar(q)
4327-
if orig_scalar:
4336+
qs = np.array(q, dtype=np.float64)
4337+
pass_qs: np.ndarray | None = qs
4338+
if is_scalar(q):
43284339
qs = np.array([q], dtype=np.float64)
4329-
else:
4330-
qs = np.array(q, dtype=np.float64)
4340+
pass_qs = None
4341+
43314342
ids, _, ngroups = self.grouper.group_info
43324343
nqs = len(qs)
43334344

43344345
func = partial(
4335-
libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation
4346+
libgroupby.group_quantile,
4347+
labels=ids,
4348+
qs=qs,
4349+
interpolation=interpolation,
4350+
starts=starts,
4351+
ends=ends,
43364352
)
43374353

4338-
# Put '-1' (NaN) labels as the last group so it does not interfere
4339-
# with the calculations. Note: length check avoids failure on empty
4340-
# labels. In that case, the value doesn't matter
4341-
na_label_for_sorting = ids.max() + 1 if len(ids) > 0 else 0
4342-
labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)
4343-
43444354
def blk_func(values: ArrayLike) -> ArrayLike:
43454355
orig_vals = values
43464356
if isinstance(values, BaseMaskedArray):
@@ -4357,53 +4367,44 @@ def blk_func(values: ArrayLike) -> ArrayLike:
43574367
ncols = 1
43584368
if vals.ndim == 2:
43594369
ncols = vals.shape[0]
4360-
shaped_labels = np.broadcast_to(
4361-
labels_for_lexsort, (ncols, len(labels_for_lexsort))
4362-
)
4363-
else:
4364-
shaped_labels = labels_for_lexsort
43654370

43664371
out = np.empty((ncols, ngroups, nqs), dtype=np.float64)
43674372

4368-
# Get an index of values sorted by values and then labels
4369-
order = (vals, shaped_labels)
4370-
sort_arr = np.lexsort(order).astype(np.intp, copy=False)
4371-
43724373
if is_datetimelike:
4373-
# This casting needs to happen after the lexsort in order
4374-
# to ensure that NaTs are placed at the end and not the front
4375-
vals = vals.view("i8").astype(np.float64)
4374+
vals = vals.view("i8")
43764375

43774376
if vals.ndim == 1:
4378-
# Ea is always 1d
4377+
# EA is always 1d
43794378
func(
43804379
out[0],
43814380
values=vals,
43824381
mask=mask,
4383-
sort_indexer=sort_arr,
43844382
result_mask=result_mask,
4383+
is_datetimelike=is_datetimelike,
43854384
)
43864385
else:
43874386
for i in range(ncols):
4388-
func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])
4387+
func(
4388+
out[i],
4389+
values=vals[i],
4390+
mask=mask[i],
4391+
result_mask=None,
4392+
is_datetimelike=is_datetimelike,
4393+
)
43894394

43904395
if vals.ndim == 1:
43914396
out = out.ravel("K")
43924397
if result_mask is not None:
43934398
result_mask = result_mask.ravel("K")
43944399
else:
43954400
out = out.reshape(ncols, ngroups * nqs)
4401+
43964402
return post_processor(out, inference, result_mask, orig_vals)
43974403

4398-
data = self._get_data_to_aggregate(numeric_only=numeric_only, name="quantile")
4399-
res_mgr = data.grouped_reduce(blk_func)
4404+
res_mgr = sdata._mgr.grouped_reduce(blk_func)
44004405

44014406
res = self._wrap_agged_manager(res_mgr)
4402-
4403-
if orig_scalar:
4404-
# Avoid expensive MultiIndex construction
4405-
return self._wrap_aggregated_output(res)
4406-
return self._wrap_aggregated_output(res, qs=qs)
4407+
return self._wrap_aggregated_output(res, qs=pass_qs)
44074408

44084409
@final
44094410
@Substitution(name="groupby")

0 commit comments

Comments
 (0)