@@ -774,6 +774,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
774
774
ndarray[numeric , ndim = 1 ] values,
775
775
ndarray[intp_t] labels ,
776
776
ndarray[uint8_t] mask ,
777
+ const intp_t[:] sort_indexer ,
777
778
const float64_t[:] qs ,
778
779
str interpolation ) -> None:
779
780
"""
@@ -787,6 +788,8 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
787
788
Array containing the values to apply the function against.
788
789
labels : ndarray[np.intp]
789
790
Array containing the unique group labels.
791
+ sort_indexer : ndarray[np.intp]
792
+ Indices describing sort order by values and labels.
790
793
qs : ndarray[float64_t]
791
794
The quantile values to search for.
792
795
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}
@@ -800,9 +803,9 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
800
803
Py_ssize_t i , N = len (labels), ngroups , grp_sz , non_na_sz , k , nqs
801
804
Py_ssize_t grp_start = 0 , idx = 0
802
805
intp_t lab
803
- uint8_t interp
806
+ InterpolationEnumType interp
804
807
float64_t q_val , q_idx , frac , val , next_val
805
- ndarray[ int64_t] counts , non_na_counts , sort_arr
808
+ int64_t[::1 ] counts , non_na_counts
806
809
807
810
assert values.shape[0] == N
808
811
@@ -837,16 +840,6 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
837
840
if not mask[i]:
838
841
non_na_counts[lab] += 1
839
842
840
- # Get an index of values sorted by labels and then values
841
- if labels.any():
842
- # Put '-1' (NaN) labels as the last group so it does not interfere
843
- # with the calculations.
844
- labels_for_lexsort = np.where(labels == - 1 , labels.max() + 1 , labels)
845
- else :
846
- labels_for_lexsort = labels
847
- order = (values, labels_for_lexsort)
848
- sort_arr = np.lexsort(order).astype(np.int64, copy = False )
849
-
850
843
with nogil:
851
844
for i in range (ngroups):
852
845
# Figure out how many group elements there are
@@ -864,7 +857,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
864
857
# Casting to int will intentionally truncate result
865
858
idx = grp_start + < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
866
859
867
- val = values[sort_arr [idx]]
860
+ val = values[sort_indexer [idx]]
868
861
# If requested quantile falls evenly on a particular index
869
862
# then write that index's value out. Otherwise interpolate
870
863
q_idx = q_val * (non_na_sz - 1 )
@@ -873,7 +866,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
873
866
if frac == 0.0 or interp == INTERPOLATION_LOWER:
874
867
out[i, k] = val
875
868
else :
876
- next_val = values[sort_arr [idx + 1 ]]
869
+ next_val = values[sort_indexer [idx + 1 ]]
877
870
if interp == INTERPOLATION_LINEAR:
878
871
out[i, k] = val + (next_val - val) * frac
879
872
elif interp == INTERPOLATION_HIGHER:
0 commit comments