@@ -770,25 +770,25 @@ def group_ohlc(floating[:, ::1] out,
770
770
771
771
@ cython.boundscheck (False )
772
772
@ cython.wraparound (False )
773
- def group_quantile (ndarray[float64_t] out ,
773
+ def group_quantile (ndarray[float64_t , ndim = 2 ] out,
774
774
ndarray[numeric , ndim = 1 ] values,
775
775
ndarray[intp_t] labels ,
776
776
ndarray[uint8_t] mask ,
777
- float64_t q ,
777
+ const float64_t[:] qs ,
778
778
str interpolation ) -> None:
779
779
"""
780
780
Calculate the quantile per group.
781
781
782
782
Parameters
783
783
----------
784
- out : np.ndarray[np.float64]
784
+ out : np.ndarray[np.float64 , ndim = 2 ]
785
785
Array of aggregated values that will be written to.
786
786
values : np.ndarray
787
787
Array containing the values to apply the function against.
788
788
labels : ndarray[np.intp]
789
789
Array containing the unique group labels.
790
- q : float
791
- The quantile value to search for.
790
+ qs : ndarray[float64_t]
791
+ The quantile values to search for.
792
792
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}
793
793
794
794
Notes
@@ -797,17 +797,20 @@ def group_quantile(ndarray[float64_t] out,
797
797
provided `out` parameter.
798
798
"""
799
799
cdef:
800
- Py_ssize_t i , N = len (labels), ngroups , grp_sz , non_na_sz
800
+ Py_ssize_t i , N = len (labels), ngroups , grp_sz , non_na_sz , k , nqs
801
801
Py_ssize_t grp_start = 0 , idx = 0
802
802
intp_t lab
803
803
uint8_t interp
804
- float64_t q_idx , frac , val , next_val
804
+ float64_t q_val , q_idx , frac , val , next_val
805
805
ndarray[int64_t] counts , non_na_counts , sort_arr
806
806
807
807
assert values.shape[0] == N
808
808
809
- if not (0 <= q <= 1):
810
- raise ValueError (f" 'q' must be between 0 and 1. Got '{q}' instead" )
809
+ if any(not (0 <= q <= 1) for q in qs ):
810
+ wrong = [x for x in qs if not (0 <= x <= 1 )][0 ]
811
+ raise ValueError (
812
+ f" Each 'q' must be between 0 and 1. Got '{wrong}' instead"
813
+ )
811
814
812
815
inter_methods = {
813
816
' linear' : INTERPOLATION_LINEAR,
@@ -818,9 +821,10 @@ def group_quantile(ndarray[float64_t] out,
818
821
}
819
822
interp = inter_methods[interpolation]
820
823
821
- counts = np.zeros_like(out, dtype = np.int64)
822
- non_na_counts = np.zeros_like(out, dtype = np.int64)
823
- ngroups = len (counts)
824
+ nqs = len (qs)
825
+ ngroups = len (out)
826
+ counts = np.zeros(ngroups, dtype = np.int64)
827
+ non_na_counts = np.zeros(ngroups, dtype = np.int64)
824
828
825
829
# First figure out the size of every group
826
830
with nogil:
@@ -850,33 +854,37 @@ def group_quantile(ndarray[float64_t] out,
850
854
non_na_sz = non_na_counts[i]
851
855
852
856
if non_na_sz == 0 :
853
- out[i] = NaN
857
+ for k in range (nqs):
858
+ out[i, k] = NaN
854
859
else :
855
- # Calculate where to retrieve the desired value
856
- # Casting to int will intentionally truncate result
857
- idx = grp_start + < int64_t> (q * < float64_t> (non_na_sz - 1 ))
858
-
859
- val = values[sort_arr[idx]]
860
- # If requested quantile falls evenly on a particular index
861
- # then write that index's value out. Otherwise interpolate
862
- q_idx = q * (non_na_sz - 1 )
863
- frac = q_idx % 1
864
-
865
- if frac == 0.0 or interp == INTERPOLATION_LOWER:
866
- out[i] = val
867
- else :
868
- next_val = values[sort_arr[idx + 1 ]]
869
- if interp == INTERPOLATION_LINEAR:
870
- out[i] = val + (next_val - val) * frac
871
- elif interp == INTERPOLATION_HIGHER:
872
- out[i] = next_val
873
- elif interp == INTERPOLATION_MIDPOINT:
874
- out[i] = (val + next_val) / 2.0
875
- elif interp == INTERPOLATION_NEAREST:
876
- if frac > .5 or (frac == .5 and q > .5 ): # Always OK?
877
- out[i] = next_val
878
- else :
879
- out[i] = val
860
+ for k in range (nqs):
861
+ q_val = qs[k]
862
+
863
+ # Calculate where to retrieve the desired value
864
+ # Casting to int will intentionally truncate result
865
+ idx = grp_start + < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
866
+
867
+ val = values[sort_arr[idx]]
868
+ # If requested quantile falls evenly on a particular index
869
+ # then write that index's value out. Otherwise interpolate
870
+ q_idx = q_val * (non_na_sz - 1 )
871
+ frac = q_idx % 1
872
+
873
+ if frac == 0.0 or interp == INTERPOLATION_LOWER:
874
+ out[i, k] = val
875
+ else :
876
+ next_val = values[sort_arr[idx + 1 ]]
877
+ if interp == INTERPOLATION_LINEAR:
878
+ out[i, k] = val + (next_val - val) * frac
879
+ elif interp == INTERPOLATION_HIGHER:
880
+ out[i, k] = next_val
881
+ elif interp == INTERPOLATION_MIDPOINT:
882
+ out[i, k] = (val + next_val) / 2.0
883
+ elif interp == INTERPOLATION_NEAREST:
884
+ if frac > .5 or (frac == .5 and q_val > .5 ): # Always OK?
885
+ out[i, k] = next_val
886
+ else :
887
+ out[i, k] = val
880
888
881
889
# Increment the index reference in sorted_arr for the next group
882
890
grp_start += grp_sz
0 commit comments