Skip to content

Commit ce8f6e8

Browse files
kornilova203 authored and jreback committed
Add interpolation options to rolling quantile (#20497)
1 parent 0ae19a1 commit ce8f6e8

File tree

7 files changed

+198
-60
lines changed

7 files changed

+198
-60
lines changed

asv_bench/benchmarks/rolling.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def setup(self, constructor, window, dtype, method):
2222
def time_rolling(self, constructor, window, dtype, method):
2323
getattr(self.roll, method)()
2424

25+
2526
class VariableWindowMethods(Methods):
2627
sample_time = 0.2
2728
params = (['DataFrame', 'Series'],
@@ -37,6 +38,7 @@ def setup(self, constructor, window, dtype, method):
3738
index = pd.date_range('2017-01-01', periods=N, freq='5s')
3839
self.roll = getattr(pd, constructor)(arr, index=index).rolling(window)
3940

41+
4042
class Pairwise(object):
4143

4244
sample_time = 0.2
@@ -59,18 +61,19 @@ def time_pairwise(self, window, method, pairwise):
5961

6062

6163
class Quantile(object):
62-
6364
sample_time = 0.2
6465
params = (['DataFrame', 'Series'],
6566
[10, 1000],
6667
['int', 'float'],
67-
[0, 0.5, 1])
68+
[0, 0.5, 1],
69+
['linear', 'nearest', 'lower', 'higher', 'midpoint'])
6870
# One name per entry in ``params``, in the same order as the arguments of
# ``setup`` / ``time_quantile``; the fifth axis is the interpolation method.
param_names = ['constructor', 'window', 'dtype', 'percentile',
               'interpolation']
6971

70-
def setup(self, constructor, window, dtype, percentile):
71-
N = 10**5
72+
def setup(self, constructor, window, dtype, percentile, interpolation):
73+
N = 10 ** 5
7274
arr = np.random.random(N).astype(dtype)
7375
self.roll = getattr(pd, constructor)(arr).rolling(window)
7476

75-
def time_quantile(self, constructor, window, dtype, percentile):
76-
self.roll.quantile(percentile)
77+
def time_quantile(self, constructor, window, dtype, percentile,
78+
interpolation):
79+
self.roll.quantile(percentile, interpolation=interpolation)

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ Other Enhancements
443443
- :meth:`DataFrame.to_sql` now performs a multivalue insert if the underlying connection supports it, rather than inserting row by row.
444444
``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`)
445445
- :func:`read_html` now accepts a ``displayed_only`` keyword argument to control whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`)
446+
- :meth:`Rolling.quantile` and :meth:`Expanding.quantile` now accept the ``interpolation`` keyword, ``linear`` by default (:issue:`20497`)
446447
- zip compression is supported via ``compression=zip`` in :func:`DataFrame.to_pickle`, :func:`Series.to_pickle`, :func:`DataFrame.to_csv`, :func:`Series.to_csv`, :func:`DataFrame.to_json`, :func:`Series.to_json`. (:issue:`17778`)
447448
- :class:`WeekOfMonth` constructor now supports ``n=0`` (:issue:`20517`).
448449
- :class:`DataFrame` and :class:`Series` now support the matrix multiplication (``@``) operator (:issue:`10259`) for Python>=3.5

pandas/_libs/window.pyx

+91-39
Original file line numberDiff line numberDiff line change
@@ -1357,77 +1357,129 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp,
13571357
return output
13581358

13591359

1360+
cdef enum InterpolationType:
1361+
LINEAR,
1362+
LOWER,
1363+
HIGHER,
1364+
NEAREST,
1365+
MIDPOINT
1366+
1367+
1368+
interpolation_types = {
1369+
'linear': LINEAR,
1370+
'lower': LOWER,
1371+
'higher': HIGHER,
1372+
'nearest': NEAREST,
1373+
'midpoint': MIDPOINT,
1374+
}
1375+
1376+
13601377
def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
13611378
int64_t minp, object index, object closed,
1362-
double quantile):
1379+
double quantile, str interpolation):
13631380
"""
13641381
O(N log(window)) implementation using skip list
13651382
"""
13661383
cdef:
1367-
double val, prev, midpoint
1368-
IndexableSkiplist skiplist
1384+
double val, prev, midpoint, idx_with_fraction
1385+
skiplist_t *skiplist
13691386
int64_t nobs = 0, i, j, s, e, N
13701387
Py_ssize_t idx
13711388
bint is_variable
13721389
ndarray[int64_t] start, end
13731390
ndarray[double_t] output
13741391
double vlow, vhigh
1392+
InterpolationType interpolation_type
1393+
int ret = 0
13751394

13761395
if quantile <= 0.0 or quantile >= 1.0:
13771396
raise ValueError("quantile value {0} not in [0, 1]".format(quantile))
13781397

1398+
try:
1399+
interpolation_type = interpolation_types[interpolation]
1400+
except KeyError:
1401+
raise ValueError("Interpolation '{}' is not supported"
1402+
.format(interpolation))
1403+
13791404
# we use the Fixed/Variable Indexer here as the
13801405
# actual skiplist ops outweigh any window computation costs
13811406
start, end, N, win, minp, is_variable = get_window_indexer(
13821407
input, win,
13831408
minp, index, closed,
13841409
use_mock=False)
13851410
output = np.empty(N, dtype=float)
1386-
skiplist = IndexableSkiplist(win)
1387-
1388-
for i in range(0, N):
1389-
s = start[i]
1390-
e = end[i]
1391-
1392-
if i == 0:
1393-
1394-
# setup
1395-
val = input[i]
1396-
if val == val:
1397-
nobs += 1
1398-
skiplist.insert(val)
1411+
skiplist = skiplist_init(<int>win)
1412+
if skiplist == NULL:
1413+
raise MemoryError("skiplist_init failed")
13991414

1400-
else:
1415+
with nogil:
1416+
for i in range(0, N):
1417+
s = start[i]
1418+
e = end[i]
14011419

1402-
# calculate deletes
1403-
for j in range(start[i - 1], s):
1404-
val = input[j]
1405-
if val == val:
1406-
skiplist.remove(val)
1407-
nobs -= 1
1420+
if i == 0:
14081421

1409-
# calculate adds
1410-
for j in range(end[i - 1], e):
1411-
val = input[j]
1422+
# setup
1423+
val = input[i]
14121424
if val == val:
14131425
nobs += 1
1414-
skiplist.insert(val)
1426+
skiplist_insert(skiplist, val)
14151427

1416-
if nobs >= minp:
1417-
idx = int(quantile * <double>(nobs - 1))
1428+
else:
14181429

1419-
# Single value in skip list
1420-
if nobs == 1:
1421-
output[i] = skiplist.get(0)
1430+
# calculate deletes
1431+
for j in range(start[i - 1], s):
1432+
val = input[j]
1433+
if val == val:
1434+
skiplist_remove(skiplist, val)
1435+
nobs -= 1
14221436

1423-
# Interpolated quantile
1437+
# calculate adds
1438+
for j in range(end[i - 1], e):
1439+
val = input[j]
1440+
if val == val:
1441+
nobs += 1
1442+
skiplist_insert(skiplist, val)
1443+
1444+
if nobs >= minp:
1445+
if nobs == 1:
1446+
# Single value in skip list
1447+
output[i] = skiplist_get(skiplist, 0, &ret)
1448+
else:
1449+
idx_with_fraction = quantile * (nobs - 1)
1450+
idx = <int> idx_with_fraction
1451+
1452+
if idx_with_fraction == idx:
1453+
# no need to interpolate
1454+
output[i] = skiplist_get(skiplist, idx, &ret)
1455+
continue
1456+
1457+
if interpolation_type == LINEAR:
1458+
vlow = skiplist_get(skiplist, idx, &ret)
1459+
vhigh = skiplist_get(skiplist, idx + 1, &ret)
1460+
output[i] = ((vlow + (vhigh - vlow) *
1461+
(idx_with_fraction - idx)))
1462+
elif interpolation_type == LOWER:
1463+
output[i] = skiplist_get(skiplist, idx, &ret)
1464+
elif interpolation_type == HIGHER:
1465+
output[i] = skiplist_get(skiplist, idx + 1, &ret)
1466+
elif interpolation_type == NEAREST:
1467+
# the same behaviour as round()
1468+
if idx_with_fraction - idx == 0.5:
1469+
if idx % 2 == 0:
1470+
output[i] = skiplist_get(skiplist, idx, &ret)
1471+
else:
1472+
output[i] = skiplist_get(skiplist, idx + 1, &ret)
1473+
elif idx_with_fraction - idx < 0.5:
1474+
output[i] = skiplist_get(skiplist, idx, &ret)
1475+
else:
1476+
output[i] = skiplist_get(skiplist, idx + 1, &ret)
1477+
elif interpolation_type == MIDPOINT:
1478+
vlow = skiplist_get(skiplist, idx, &ret)
1479+
vhigh = skiplist_get(skiplist, idx + 1, &ret)
1480+
output[i] = <double> (vlow + vhigh) / 2
14241481
else:
1425-
vlow = skiplist.get(idx)
1426-
vhigh = skiplist.get(idx + 1)
1427-
output[i] = ((vlow + (vhigh - vlow) *
1428-
(quantile * (nobs - 1) - idx)))
1429-
else:
1430-
output[i] = NaN
1482+
output[i] = NaN
14311483

14321484
# Release the C skiplist obtained from skiplist_init() above; nothing else in
# this function frees it, so without this call the allocation leaks on every
# invocation.
# NOTE(review): assumes skiplist_destroy is declared alongside skiplist_init
# in this module's skiplist extern block -- confirm before merging.
skiplist_destroy(skiplist)
return output
14331485

pandas/core/frame.py

+4
Original file line numberDiff line numberDiff line change
@@ -7079,6 +7079,10 @@ def quantile(self, q=0.5, axis=0, numeric_only=True,
70797079
a b
70807080
0.1 1.3 3.7
70817081
0.5 2.5 55.0
7082+
7083+
See Also
7084+
--------
7085+
pandas.core.window.Rolling.quantile
70827086
"""
70837087
self._check_percentile(q)
70847088

pandas/core/series.py

+3
Original file line numberDiff line numberDiff line change
@@ -1855,6 +1855,9 @@ def quantile(self, q=0.5, interpolation='linear'):
18551855
0.75 3.25
18561856
dtype: float64
18571857
1858+
See Also
1859+
--------
1860+
pandas.core.window.Rolling.quantile
18581861
"""
18591862

18601863
self._check_percentile(q)

pandas/core/window.py

+56-7
Original file line numberDiff line numberDiff line change
@@ -1276,9 +1276,53 @@ def kurt(self, **kwargs):
12761276
Parameters
12771277
----------
12781278
quantile : float
1279-
0 <= quantile <= 1""")
1279+
0 <= quantile <= 1
1280+
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
1281+
.. versionadded:: 0.23.0
1282+
1283+
This optional parameter specifies the interpolation method to use,
1284+
when the desired quantile lies between two data points `i` and `j`:
1285+
1286+
* linear: `i + (j - i) * fraction`, where `fraction` is the
1287+
fractional part of the index surrounded by `i` and `j`.
1288+
* lower: `i`.
1289+
* higher: `j`.
1290+
* nearest: `i` or `j` whichever is nearest.
1291+
* midpoint: (`i` + `j`) / 2.
1292+
1293+
Returns
1294+
-------
1295+
Series or DataFrame
1296+
Returned object type is determined by the caller of the %(name)s
1297+
calculation.
1298+
1299+
Examples
1300+
--------
1301+
>>> s = Series([1, 2, 3, 4])
1302+
>>> s.rolling(2).quantile(.4, interpolation='lower')
1303+
0 NaN
1304+
1 1.0
1305+
2 2.0
1306+
3 3.0
1307+
dtype: float64
1308+
1309+
>>> s.rolling(2).quantile(.4, interpolation='midpoint')
1310+
0 NaN
1311+
1 1.5
1312+
2 2.5
1313+
3 3.5
1314+
dtype: float64
1315+
1316+
See Also
1317+
--------
1318+
pandas.Series.quantile : Computes value at the given quantile over all data
1319+
in Series.
1320+
pandas.DataFrame.quantile : Computes values at the given quantile over
1321+
requested axis in DataFrame.
1322+
1323+
""")
12801324

1281-
def quantile(self, quantile, **kwargs):
1325+
def quantile(self, quantile, interpolation='linear', **kwargs):
12821326
window = self._get_window()
12831327
index, indexi = self._get_index()
12841328

@@ -1292,7 +1336,8 @@ def f(arg, *args, **kwargs):
12921336
self.closed)
12931337
else:
12941338
return _window.roll_quantile(arg, window, minp, indexi,
1295-
self.closed, quantile)
1339+
self.closed, quantile,
1340+
interpolation)
12961341

12971342
return self._apply(f, 'quantile', quantile=quantile,
12981343
**kwargs)
@@ -1613,8 +1658,10 @@ def kurt(self, **kwargs):
16131658
@Substitution(name='rolling')
16141659
@Appender(_doc_template)
16151660
@Appender(_shared_docs['quantile'])
1616-
def quantile(self, quantile, **kwargs):
1617-
return super(Rolling, self).quantile(quantile=quantile, **kwargs)
1661+
def quantile(self, quantile, interpolation='linear', **kwargs):
1662+
return super(Rolling, self).quantile(quantile=quantile,
1663+
interpolation=interpolation,
1664+
**kwargs)
16181665

16191666
@Substitution(name='rolling')
16201667
@Appender(_doc_template)
@@ -1872,8 +1919,10 @@ def kurt(self, **kwargs):
18721919
@Substitution(name='expanding')
18731920
@Appender(_doc_template)
18741921
@Appender(_shared_docs['quantile'])
1875-
def quantile(self, quantile, **kwargs):
1876-
return super(Expanding, self).quantile(quantile=quantile, **kwargs)
1922+
def quantile(self, quantile, interpolation='linear', **kwargs):
1923+
return super(Expanding, self).quantile(quantile=quantile,
1924+
interpolation=interpolation,
1925+
**kwargs)
18771926

18781927
@Substitution(name='expanding')
18791928
@Appender(_doc_template)

pandas/tests/test_window.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datetime import datetime, timedelta
77
from numpy.random import randn
88
import numpy as np
9+
from pandas import _np_version_under1p12
910

1011
import pandas as pd
1112
from pandas import (Series, DataFrame, bdate_range,
@@ -1166,15 +1167,40 @@ def test_rolling_quantile_np_percentile(self):
11661167

11671168
tm.assert_almost_equal(df_quantile.values, np.array(np_percentile))
11681169

1169-
def test_rolling_quantile_series(self):
1170-
# #16211: Tests that rolling window's quantile default behavior
1171-
# is analogus to Series' quantile
1172-
arr = np.arange(100)
1173-
s = Series(arr)
1174-
q1 = s.quantile(0.1)
1175-
q2 = s.rolling(100).quantile(0.1).iloc[-1]
1170+
@pytest.mark.skipif(_np_version_under1p12,
1171+
reason='numpy midpoint interpolation is broken')
1172+
@pytest.mark.parametrize('quantile', [0.0, 0.1, 0.45, 0.5, 1])
1173+
@pytest.mark.parametrize('interpolation', ['linear', 'lower', 'higher',
1174+
'nearest', 'midpoint'])
1175+
@pytest.mark.parametrize('data', [[1., 2., 3., 4., 5., 6., 7.],
1176+
[8., 1., 3., 4., 5., 2., 6., 7.],
1177+
[0., np.nan, 0.2, np.nan, 0.4],
1178+
[np.nan, np.nan, np.nan, np.nan],
1179+
[np.nan, 0.1, np.nan, 0.3, 0.4, 0.5],
1180+
[0.5], [np.nan, 0.7, 0.6]])
1181+
def test_rolling_quantile_interpolation_options(self, quantile,
1182+
interpolation, data):
1183+
# Tests that rolling window's quantile behavior is analogous to
1184+
# Series' quantile for each interpolation option
1185+
s = Series(data)
1186+
1187+
q1 = s.quantile(quantile, interpolation)
1188+
q2 = s.expanding(min_periods=1).quantile(
1189+
quantile, interpolation).iloc[-1]
1190+
1191+
if np.isnan(q1):
1192+
assert np.isnan(q2)
1193+
else:
1194+
assert q1 == q2
1195+
1196+
def test_invalid_quantile_value(self):
1197+
data = np.arange(5)
1198+
s = Series(data)
11761199

1177-
tm.assert_almost_equal(q1, q2)
1200+
with pytest.raises(ValueError, match="Interpolation 'invalid'"
1201+
" is not supported"):
1202+
s.rolling(len(data), min_periods=1).quantile(
1203+
0.5, interpolation='invalid')
11781204

11791205
def test_rolling_quantile_param(self):
11801206
ser = Series([0.0, .1, .5, .9, 1.0])

0 commit comments

Comments
 (0)