Skip to content

Commit c03f49c

Browse files
authored
PERF: groupby.quantile (#43469)
1 parent 805f0d9 commit c03f49c

File tree

6 files changed

+174
-105
lines changed

6 files changed

+174
-105
lines changed

doc/source/whatsnew/v1.4.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ Performance improvements
294294
- Performance improvement in :meth:`to_datetime` with ``uint`` dtypes (:issue:`42606`)
295295
- Performance improvement in :meth:`Series.sparse.to_coo` (:issue:`42880`)
296296
- Performance improvement in indexing with a :class:`MultiIndex` indexer on another :class:`MultiIndex` (:issue:`43370`)
297+
- Performance improvement in :meth:`GroupBy.quantile` (:issue:`43469`)
298+
-
297299

298300
.. ---------------------------------------------------------------------------
299301

pandas/_libs/groupby.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ def group_ohlc(
8484
min_count: int = ...,
8585
) -> None: ...
8686
def group_quantile(
87-
out: np.ndarray, # ndarray[float64_t]
87+
out: np.ndarray, # ndarray[float64_t, ndim=2]
8888
values: np.ndarray, # ndarray[numeric, ndim=1]
8989
labels: np.ndarray, # ndarray[intp_t]
9090
mask: np.ndarray, # ndarray[uint8_t]
91-
q: float, # float64_t
91+
qs: np.ndarray, # const float64_t[:]
9292
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
9393
) -> None: ...
9494
def group_last(

pandas/_libs/groupby.pyx

+46-38
Original file line numberDiff line numberDiff line change
@@ -770,25 +770,25 @@ def group_ohlc(floating[:, ::1] out,
770770

771771
@cython.boundscheck(False)
772772
@cython.wraparound(False)
773-
def group_quantile(ndarray[float64_t] out,
773+
def group_quantile(ndarray[float64_t, ndim=2] out,
774774
ndarray[numeric, ndim=1] values,
775775
ndarray[intp_t] labels,
776776
ndarray[uint8_t] mask,
777-
float64_t q,
777+
const float64_t[:] qs,
778778
str interpolation) -> None:
779779
"""
780780
Calculate the quantile per group.
781781

782782
Parameters
783783
----------
784-
out : np.ndarray[np.float64]
784+
out : np.ndarray[np.float64, ndim=2]
785785
Array of aggregated values that will be written to.
786786
values : np.ndarray
787787
Array containing the values to apply the function against.
788788
labels : ndarray[np.intp]
789789
Array containing the unique group labels.
790-
q : float
791-
The quantile value to search for.
790+
qs : ndarray[float64_t]
791+
The quantile values to search for.
792792
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}
793793

794794
Notes
@@ -797,17 +797,20 @@ def group_quantile(ndarray[float64_t] out,
797797
provided `out` parameter.
798798
"""
799799
cdef:
800-
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
800+
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
801801
Py_ssize_t grp_start=0, idx=0
802802
intp_t lab
803803
uint8_t interp
804-
float64_t q_idx, frac, val, next_val
804+
float64_t q_val, q_idx, frac, val, next_val
805805
ndarray[int64_t] counts, non_na_counts, sort_arr
806806

807807
assert values.shape[0] == N
808808

809-
if not (0 <= q <= 1):
810-
raise ValueError(f"'q' must be between 0 and 1. Got '{q}' instead")
809+
if any(not (0 <= q <= 1) for q in qs):
810+
wrong = [x for x in qs if not (0 <= x <= 1)][0]
811+
raise ValueError(
812+
f"Each 'q' must be between 0 and 1. Got '{wrong}' instead"
813+
)
811814

812815
inter_methods = {
813816
'linear': INTERPOLATION_LINEAR,
@@ -818,9 +821,10 @@ def group_quantile(ndarray[float64_t] out,
818821
}
819822
interp = inter_methods[interpolation]
820823

821-
counts = np.zeros_like(out, dtype=np.int64)
822-
non_na_counts = np.zeros_like(out, dtype=np.int64)
823-
ngroups = len(counts)
824+
nqs = len(qs)
825+
ngroups = len(out)
826+
counts = np.zeros(ngroups, dtype=np.int64)
827+
non_na_counts = np.zeros(ngroups, dtype=np.int64)
824828

825829
# First figure out the size of every group
826830
with nogil:
@@ -850,33 +854,37 @@ def group_quantile(ndarray[float64_t] out,
850854
non_na_sz = non_na_counts[i]
851855

852856
if non_na_sz == 0:
853-
out[i] = NaN
857+
for k in range(nqs):
858+
out[i, k] = NaN
854859
else:
855-
# Calculate where to retrieve the desired value
856-
# Casting to int will intentionally truncate result
857-
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
858-
859-
val = values[sort_arr[idx]]
860-
# If requested quantile falls evenly on a particular index
861-
# then write that index's value out. Otherwise interpolate
862-
q_idx = q * (non_na_sz - 1)
863-
frac = q_idx % 1
864-
865-
if frac == 0.0 or interp == INTERPOLATION_LOWER:
866-
out[i] = val
867-
else:
868-
next_val = values[sort_arr[idx + 1]]
869-
if interp == INTERPOLATION_LINEAR:
870-
out[i] = val + (next_val - val) * frac
871-
elif interp == INTERPOLATION_HIGHER:
872-
out[i] = next_val
873-
elif interp == INTERPOLATION_MIDPOINT:
874-
out[i] = (val + next_val) / 2.0
875-
elif interp == INTERPOLATION_NEAREST:
876-
if frac > .5 or (frac == .5 and q > .5): # Always OK?
877-
out[i] = next_val
878-
else:
879-
out[i] = val
860+
for k in range(nqs):
861+
q_val = qs[k]
862+
863+
# Calculate where to retrieve the desired value
864+
# Casting to int will intentionally truncate result
865+
idx = grp_start + <int64_t>(q_val * <float64_t>(non_na_sz - 1))
866+
867+
val = values[sort_arr[idx]]
868+
# If requested quantile falls evenly on a particular index
869+
# then write that index's value out. Otherwise interpolate
870+
q_idx = q_val * (non_na_sz - 1)
871+
frac = q_idx % 1
872+
873+
if frac == 0.0 or interp == INTERPOLATION_LOWER:
874+
out[i, k] = val
875+
else:
876+
next_val = values[sort_arr[idx + 1]]
877+
if interp == INTERPOLATION_LINEAR:
878+
out[i, k] = val + (next_val - val) * frac
879+
elif interp == INTERPOLATION_HIGHER:
880+
out[i, k] = next_val
881+
elif interp == INTERPOLATION_MIDPOINT:
882+
out[i, k] = (val + next_val) / 2.0
883+
elif interp == INTERPOLATION_NEAREST:
884+
if frac > .5 or (frac == .5 and q_val > .5): # Always OK?
885+
out[i, k] = next_val
886+
else:
887+
out[i, k] = val
880888

881889
# Increment the index reference in sorted_arr for the next group
882890
grp_start += grp_sz

0 commit comments

Comments
 (0)