Skip to content

Commit 64e5612

Browse files
WillAydjreback
authored andcommitted
Cythonized GroupBy Quantile (#20405)
1 parent e52f063 commit 64e5612

File tree

7 files changed

+258
-19
lines changed

7 files changed

+258
-19
lines changed

asv_bench/benchmarks/groupby.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
method_blacklist = {
1515
'object': {'median', 'prod', 'sem', 'cumsum', 'sum', 'cummin', 'mean',
1616
'max', 'skew', 'cumprod', 'cummax', 'rank', 'pct_change', 'min',
17-
'var', 'mad', 'describe', 'std'},
17+
'var', 'mad', 'describe', 'std', 'quantile'},
1818
'datetime': {'median', 'prod', 'sem', 'cumsum', 'sum', 'mean', 'skew',
1919
'cumprod', 'cummax', 'pct_change', 'var', 'mad', 'describe',
2020
'std'}
@@ -316,8 +316,9 @@ class GroupByMethods(object):
316316
['all', 'any', 'bfill', 'count', 'cumcount', 'cummax', 'cummin',
317317
'cumprod', 'cumsum', 'describe', 'ffill', 'first', 'head',
318318
'last', 'mad', 'max', 'min', 'median', 'mean', 'nunique',
319-
'pct_change', 'prod', 'rank', 'sem', 'shift', 'size', 'skew',
320-
'std', 'sum', 'tail', 'unique', 'value_counts', 'var'],
319+
'pct_change', 'prod', 'quantile', 'rank', 'sem', 'shift',
320+
'size', 'skew', 'std', 'sum', 'tail', 'unique', 'value_counts',
321+
'var'],
321322
['direct', 'transformation']]
322323

323324
def setup(self, dtype, method, application):

doc/source/whatsnew/v0.25.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Performance Improvements
112112
- `DataFrame.to_stata()` is now faster when outputting data with any string or non-native endian columns (:issue:`25045`)
113113
- Improved performance of :meth:`Series.searchsorted`. The speedup is especially large when the dtype is
114114
int8/int16/int32 and the searched key is within the integer bounds for the dtype (:issue:`22034`)
115+
- Improved performance of :meth:`pandas.core.groupby.GroupBy.quantile` (:issue:`20405`)
115116

116117

117118
.. _whatsnew_0250.bug_fixes:

pandas/_libs/groupby.pxd

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
cdef enum InterpolationEnumType:
2+
INTERPOLATION_LINEAR,
3+
INTERPOLATION_LOWER,
4+
INTERPOLATION_HIGHER,
5+
INTERPOLATION_NEAREST,
6+
INTERPOLATION_MIDPOINT

pandas/_libs/groupby.pyx

+101
Original file line numberDiff line numberDiff line change
@@ -644,5 +644,106 @@ def _group_ohlc(floating[:, :] out,
644644
group_ohlc_float32 = _group_ohlc['float']
645645
group_ohlc_float64 = _group_ohlc['double']
646646

647+
648+
@cython.boundscheck(False)
649+
@cython.wraparound(False)
650+
def group_quantile(ndarray[float64_t] out,
651+
ndarray[int64_t] labels,
652+
numeric[:] values,
653+
ndarray[uint8_t] mask,
654+
float64_t q,
655+
object interpolation):
656+
"""
657+
Calculate the quantile per group.
658+
659+
Parameters
660+
----------
661+
out : ndarray
662+
Array of aggregated values that will be written to.
663+
labels : ndarray
664+
Array containing the unique group labels.
665+
values : ndarray
666+
Array containing the values to apply the function against.
667+
q : float
668+
The quantile value to search for.
669+
670+
Notes
671+
-----
672+
Rather than explicitly returning a value, this function modifies the
673+
provided `out` parameter.
674+
"""
675+
cdef:
676+
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
677+
Py_ssize_t grp_start=0, idx=0
678+
int64_t lab
679+
uint8_t interp
680+
float64_t q_idx, frac, val, next_val
681+
ndarray[int64_t] counts, non_na_counts, sort_arr
682+
683+
assert values.shape[0] == N
684+
inter_methods = {
685+
'linear': INTERPOLATION_LINEAR,
686+
'lower': INTERPOLATION_LOWER,
687+
'higher': INTERPOLATION_HIGHER,
688+
'nearest': INTERPOLATION_NEAREST,
689+
'midpoint': INTERPOLATION_MIDPOINT,
690+
}
691+
interp = inter_methods[interpolation]
692+
693+
counts = np.zeros_like(out, dtype=np.int64)
694+
non_na_counts = np.zeros_like(out, dtype=np.int64)
695+
ngroups = len(counts)
696+
697+
# First figure out the size of every group
698+
with nogil:
699+
for i in range(N):
700+
lab = labels[i]
701+
counts[lab] += 1
702+
if not mask[i]:
703+
non_na_counts[lab] += 1
704+
705+
# Get an index of values sorted by labels and then values
706+
order = (values, labels)
707+
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
708+
709+
with nogil:
710+
for i in range(ngroups):
711+
# Figure out how many group elements there are
712+
grp_sz = counts[i]
713+
non_na_sz = non_na_counts[i]
714+
715+
if non_na_sz == 0:
716+
out[i] = NaN
717+
else:
718+
# Calculate where to retrieve the desired value
719+
# Casting to int will intentionaly truncate result
720+
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))
721+
722+
val = values[sort_arr[idx]]
723+
# If requested quantile falls evenly on a particular index
724+
# then write that index's value out. Otherwise interpolate
725+
q_idx = q * (non_na_sz - 1)
726+
frac = q_idx % 1
727+
728+
if frac == 0.0 or interp == INTERPOLATION_LOWER:
729+
out[i] = val
730+
else:
731+
next_val = values[sort_arr[idx + 1]]
732+
if interp == INTERPOLATION_LINEAR:
733+
out[i] = val + (next_val - val) * frac
734+
elif interp == INTERPOLATION_HIGHER:
735+
out[i] = next_val
736+
elif interp == INTERPOLATION_MIDPOINT:
737+
out[i] = (val + next_val) / 2.0
738+
elif interp == INTERPOLATION_NEAREST:
739+
if frac > .5 or (frac == .5 and q > .5): # Always OK?
740+
out[i] = next_val
741+
else:
742+
out[i] = val
743+
744+
# Increment the index reference in sorted_arr for the next group
745+
grp_start += grp_sz
746+
747+
647748
# generated from template
648749
include "groupby_helper.pxi"

pandas/core/groupby/groupby.py

+92-11
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class providing the base-class of operations.
2929
ensure_float, is_extension_array_dtype, is_numeric_dtype, is_scalar)
3030
from pandas.core.dtypes.missing import isna, notna
3131

32+
from pandas.api.types import (
33+
is_datetime64_dtype, is_integer_dtype, is_object_dtype)
3234
import pandas.core.algorithms as algorithms
3335
from pandas.core.base import (
3436
DataError, GroupByError, PandasObject, SelectionMixin, SpecificationError)
@@ -1024,15 +1026,17 @@ def _bool_agg(self, val_test, skipna):
10241026
"""
10251027

10261028
def objs_to_bool(vals):
1027-
try:
1028-
vals = vals.astype(np.bool)
1029-
except ValueError: # for objects
1029+
# type: np.ndarray -> (np.ndarray, typing.Type)
1030+
if is_object_dtype(vals):
10301031
vals = np.array([bool(x) for x in vals])
1032+
else:
1033+
vals = vals.astype(np.bool)
10311034

1032-
return vals.view(np.uint8)
1035+
return vals.view(np.uint8), np.bool
10331036

1034-
def result_to_bool(result):
1035-
return result.astype(np.bool, copy=False)
1037+
def result_to_bool(result, inference):
1038+
# type: (np.ndarray, typing.Type) -> np.ndarray
1039+
return result.astype(inference, copy=False)
10361040

10371041
return self._get_cythonized_result('group_any_all', self.grouper,
10381042
aggregate=True,
@@ -1688,6 +1692,75 @@ def nth(self, n, dropna=None):
16881692

16891693
return result
16901694

1695+
def quantile(self, q=0.5, interpolation='linear'):
1696+
"""
1697+
Return group values at the given quantile, a la numpy.percentile.
1698+
1699+
Parameters
1700+
----------
1701+
q : float or array-like, default 0.5 (50% quantile)
1702+
Value(s) between 0 and 1 providing the quantile(s) to compute.
1703+
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
1704+
Method to use when the desired quantile falls between two points.
1705+
1706+
Returns
1707+
-------
1708+
Series or DataFrame
1709+
Return type determined by caller of GroupBy object.
1710+
1711+
See Also
1712+
--------
1713+
Series.quantile : Similar method for Series.
1714+
DataFrame.quantile : Similar method for DataFrame.
1715+
numpy.percentile : NumPy method to compute qth percentile.
1716+
1717+
Examples
1718+
--------
1719+
>>> df = pd.DataFrame([
1720+
... ['a', 1], ['a', 2], ['a', 3],
1721+
... ['b', 1], ['b', 3], ['b', 5]
1722+
... ], columns=['key', 'val'])
1723+
>>> df.groupby('key').quantile()
1724+
val
1725+
key
1726+
a 2.0
1727+
b 3.0
1728+
"""
1729+
1730+
def pre_processor(vals):
1731+
# type: np.ndarray -> (np.ndarray, Optional[typing.Type])
1732+
if is_object_dtype(vals):
1733+
raise TypeError("'quantile' cannot be performed against "
1734+
"'object' dtypes!")
1735+
1736+
inference = None
1737+
if is_integer_dtype(vals):
1738+
inference = np.int64
1739+
elif is_datetime64_dtype(vals):
1740+
inference = 'datetime64[ns]'
1741+
vals = vals.astype(np.float)
1742+
1743+
return vals, inference
1744+
1745+
def post_processor(vals, inference):
1746+
# type: (np.ndarray, Optional[typing.Type]) -> np.ndarray
1747+
if inference:
1748+
# Check for edge case
1749+
if not (is_integer_dtype(inference) and
1750+
interpolation in {'linear', 'midpoint'}):
1751+
vals = vals.astype(inference)
1752+
1753+
return vals
1754+
1755+
return self._get_cythonized_result('group_quantile', self.grouper,
1756+
aggregate=True,
1757+
needs_values=True,
1758+
needs_mask=True,
1759+
cython_dtype=np.float64,
1760+
pre_processing=pre_processor,
1761+
post_processing=post_processor,
1762+
q=q, interpolation=interpolation)
1763+
16911764
@Substitution(name='groupby')
16921765
def ngroup(self, ascending=True):
16931766
"""
@@ -1924,10 +1997,16 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,
19241997
Whether the result of the Cython operation is an index of
19251998
values to be retrieved, instead of the actual values themselves
19261999
pre_processing : function, default None
1927-
Function to be applied to `values` prior to passing to Cython
1928-
Raises if `needs_values` is False
2000+
Function to be applied to `values` prior to passing to Cython.
2001+
Function should return a tuple where the first element is the
2002+
values to be passed to Cython and the second element is an optional
2003+
type which the values should be converted to after being returned
2004+
by the Cython operation. Raises if `needs_values` is False.
19292005
post_processing : function, default None
1930-
Function to be applied to result of Cython function
2006+
Function to be applied to result of Cython function. Should accept
2007+
an array of values as the first argument and type inferences as its
2008+
second argument, i.e. the signature should be
2009+
(ndarray, typing.Type).
19312010
**kwargs : dict
19322011
Extra arguments to be passed back to Cython funcs
19332012
@@ -1963,10 +2042,12 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,
19632042

19642043
result = np.zeros(result_sz, dtype=cython_dtype)
19652044
func = partial(base_func, result, labels)
2045+
inferences = None
2046+
19662047
if needs_values:
19672048
vals = obj.values
19682049
if pre_processing:
1969-
vals = pre_processing(vals)
2050+
vals, inferences = pre_processing(vals)
19702051
func = partial(func, vals)
19712052

19722053
if needs_mask:
@@ -1982,7 +2063,7 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,
19822063
result = algorithms.take_nd(obj.values, result)
19832064

19842065
if post_processing:
1985-
result = post_processing(result)
2066+
result = post_processing(result, inferences)
19862067

19872068
output[name] = result
19882069

pandas/tests/groupby/test_function.py

+49
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,55 @@ def test_size(df):
10691069
tm.assert_series_equal(df.groupby('A').size(), out)
10701070

10711071

1072+
# quantile
1073+
# --------------------------------
1074+
@pytest.mark.parametrize("interpolation", [
1075+
"linear", "lower", "higher", "nearest", "midpoint"])
1076+
@pytest.mark.parametrize("a_vals,b_vals", [
1077+
# Ints
1078+
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]),
1079+
([1, 2, 3, 4], [4, 3, 2, 1]),
1080+
([1, 2, 3, 4, 5], [4, 3, 2, 1]),
1081+
# Floats
1082+
([1., 2., 3., 4., 5.], [5., 4., 3., 2., 1.]),
1083+
# Missing data
1084+
([1., np.nan, 3., np.nan, 5.], [5., np.nan, 3., np.nan, 1.]),
1085+
([np.nan, 4., np.nan, 2., np.nan], [np.nan, 4., np.nan, 2., np.nan]),
1086+
# Timestamps
1087+
([x for x in pd.date_range('1/1/18', freq='D', periods=5)],
1088+
[x for x in pd.date_range('1/1/18', freq='D', periods=5)][::-1]),
1089+
# All NA
1090+
([np.nan] * 5, [np.nan] * 5),
1091+
])
1092+
@pytest.mark.parametrize('q', [0, .25, .5, .75, 1])
1093+
def test_quantile(interpolation, a_vals, b_vals, q):
1094+
if interpolation == 'nearest' and q == 0.5 and b_vals == [4, 3, 2, 1]:
1095+
pytest.skip("Unclear numpy expectation for nearest result with "
1096+
"equidistant data")
1097+
1098+
a_expected = pd.Series(a_vals).quantile(q, interpolation=interpolation)
1099+
b_expected = pd.Series(b_vals).quantile(q, interpolation=interpolation)
1100+
1101+
df = DataFrame({
1102+
'key': ['a'] * len(a_vals) + ['b'] * len(b_vals),
1103+
'val': a_vals + b_vals})
1104+
1105+
expected = DataFrame([a_expected, b_expected], columns=['val'],
1106+
index=Index(['a', 'b'], name='key'))
1107+
result = df.groupby('key').quantile(q, interpolation=interpolation)
1108+
1109+
tm.assert_frame_equal(result, expected)
1110+
1111+
1112+
def test_quantile_raises():
1113+
df = pd.DataFrame([
1114+
['foo', 'a'], ['foo', 'b'], ['foo', 'c']], columns=['key', 'val'])
1115+
1116+
with pytest.raises(TypeError, match="cannot be performed against "
1117+
"'object' dtypes"):
1118+
df.groupby('key').quantile()
1119+
1120+
10721121
# pipe
10731122
# --------------------------------
10741123

pandas/tests/groupby/test_groupby.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def f(x, q=None, axis=0):
208208
trans_expected = ts_grouped.transform(g)
209209

210210
assert_series_equal(apply_result, agg_expected)
211-
assert_series_equal(agg_result, agg_expected, check_names=False)
211+
assert_series_equal(agg_result, agg_expected)
212212
assert_series_equal(trans_result, trans_expected)
213213

214214
agg_result = ts_grouped.agg(f, q=80)
@@ -223,13 +223,13 @@ def f(x, q=None, axis=0):
223223
agg_result = df_grouped.agg(np.percentile, 80, axis=0)
224224
apply_result = df_grouped.apply(DataFrame.quantile, .8)
225225
expected = df_grouped.quantile(.8)
226-
assert_frame_equal(apply_result, expected)
227-
assert_frame_equal(agg_result, expected, check_names=False)
226+
assert_frame_equal(apply_result, expected, check_names=False)
227+
assert_frame_equal(agg_result, expected)
228228

229229
agg_result = df_grouped.agg(f, q=80)
230230
apply_result = df_grouped.apply(DataFrame.quantile, q=.8)
231-
assert_frame_equal(agg_result, expected, check_names=False)
232-
assert_frame_equal(apply_result, expected)
231+
assert_frame_equal(agg_result, expected)
232+
assert_frame_equal(apply_result, expected, check_names=False)
233233

234234

235235
def test_len():

0 commit comments

Comments
 (0)