Skip to content

Commit e0ef148

Browse files
authored
CLN/PERF: group quantile (#43510)
1 parent e1c7151 commit e0ef148

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

pandas/_libs/groupby.pyi

+7-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ from typing import Literal
22

33
import numpy as np
44

5+
from pandas._typing import npt
6+
57
def group_median_float64(
68
out: np.ndarray, # ndarray[float64_t, ndim=2]
79
counts: np.ndarray, # ndarray[int64_t]
@@ -84,11 +86,12 @@ def group_ohlc(
8486
min_count: int = ...,
8587
) -> None: ...
8688
def group_quantile(
87-
out: np.ndarray, # ndarray[float64_t, ndim=2]
89+
out: npt.NDArray[np.float64],
8890
values: np.ndarray, # ndarray[numeric, ndim=1]
89-
labels: np.ndarray, # ndarray[int64_t]
90-
mask: np.ndarray, # ndarray[uint8_t]
91-
qs: np.ndarray, # const float64_t[:]
91+
labels: npt.NDArray[np.intp],
92+
mask: npt.NDArray[np.uint8],
93+
sort_indexer: npt.NDArray[np.intp], # const
94+
qs: npt.NDArray[np.float64], # const
9295
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
9396
) -> None: ...
9497
def group_last(

pandas/_libs/groupby.pyx

+7-14
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
774774
ndarray[numeric, ndim=1] values,
775775
ndarray[intp_t] labels,
776776
ndarray[uint8_t] mask,
777+
const intp_t[:] sort_indexer,
777778
const float64_t[:] qs,
778779
str interpolation) -> None:
779780
"""
@@ -787,6 +788,8 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
787788
Array containing the values to apply the function against.
788789
labels : ndarray[np.intp]
789790
Array containing the unique group labels.
791+
sort_indexer : ndarray[np.intp]
792+
Indices describing sort order by values and labels.
790793
qs : ndarray[float64_t]
791794
The quantile values to search for.
792795
interpolation : {'linear', 'lower', 'highest', 'nearest', 'midpoint'}
@@ -800,9 +803,9 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
800803
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
801804
Py_ssize_t grp_start=0, idx=0
802805
intp_t lab
803-
uint8_t interp
806+
InterpolationEnumType interp
804807
float64_t q_val, q_idx, frac, val, next_val
805-
ndarray[int64_t] counts, non_na_counts, sort_arr
808+
int64_t[::1] counts, non_na_counts
806809

807810
assert values.shape[0] == N
808811

@@ -837,16 +840,6 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
837840
if not mask[i]:
838841
non_na_counts[lab] += 1
839842

840-
# Get an index of values sorted by labels and then values
841-
if labels.any():
842-
# Put '-1' (NaN) labels as the last group so it does not interfere
843-
# with the calculations.
844-
labels_for_lexsort = np.where(labels == -1, labels.max() + 1, labels)
845-
else:
846-
labels_for_lexsort = labels
847-
order = (values, labels_for_lexsort)
848-
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
849-
850843
with nogil:
851844
for i in range(ngroups):
852845
# Figure out how many group elements there are
@@ -864,7 +857,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
864857
# Casting to int will intentionally truncate result
865858
idx = grp_start + <int64_t>(q_val * <float64_t>(non_na_sz - 1))
866859

867-
val = values[sort_arr[idx]]
860+
val = values[sort_indexer[idx]]
868861
# If requested quantile falls evenly on a particular index
869862
# then write that index's value out. Otherwise interpolate
870863
q_idx = q_val * (non_na_sz - 1)
@@ -873,7 +866,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
873866
if frac == 0.0 or interp == INTERPOLATION_LOWER:
874867
out[i, k] = val
875868
else:
876-
next_val = values[sort_arr[idx + 1]]
869+
next_val = values[sort_indexer[idx + 1]]
877870
if interp == INTERPOLATION_LINEAR:
878871
out[i, k] = val + (next_val - val) * frac
879872
elif interp == INTERPOLATION_HIGHER:

pandas/core/groupby/groupby.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -2648,24 +2648,39 @@ def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
26482648
libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation
26492649
)
26502650

2651+
# Put '-1' (NaN) labels as the last group so it does not interfere
2652+
# with the calculations. Note: length check avoids failure on empty
2653+
# labels. In that case, the value doesn't matter
2654+
na_label_for_sorting = ids.max() + 1 if len(ids) > 0 else 0
2655+
labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)
2656+
26512657
def blk_func(values: ArrayLike) -> ArrayLike:
26522658
mask = isna(values)
26532659
vals, inference = pre_processor(values)
26542660

26552661
ncols = 1
26562662
if vals.ndim == 2:
26572663
ncols = vals.shape[0]
2664+
shaped_labels = np.broadcast_to(
2665+
labels_for_lexsort, (ncols, len(labels_for_lexsort))
2666+
)
2667+
else:
2668+
shaped_labels = labels_for_lexsort
26582669

26592670
out = np.empty((ncols, ngroups, nqs), dtype=np.float64)
26602671

2672+
# Get an index of values sorted by values and then labels
2673+
order = (vals, shaped_labels)
2674+
sort_arr = np.lexsort(order).astype(np.intp, copy=False)
2675+
26612676
if vals.ndim == 1:
2662-
func(out[0], values=vals, mask=mask)
2677+
func(out[0], values=vals, mask=mask, sort_indexer=sort_arr)
26632678
else:
26642679
for i in range(ncols):
2665-
func(out[i], values=vals[i], mask=mask[i])
2680+
func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])
26662681

26672682
if vals.ndim == 1:
2668-
out = out[0].ravel("K")
2683+
out = out.ravel("K")
26692684
else:
26702685
out = out.reshape(ncols, ngroups * nqs)
26712686
return post_processor(out, inference)

0 commit comments

Comments
 (0)