Skip to content

Commit 9f90bd4

Browse files
authored
ENH: Rolling rank (pandas-dev#43338)
1 parent 323595a commit 9f90bd4

File tree

13 files changed

+427
-11
lines changed

13 files changed

+427
-11
lines changed

asv_bench/benchmarks/rolling.py

+27
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,33 @@ def time_quantile(self, constructor, window, dtype, percentile, interpolation):
180180
self.roll.quantile(percentile, interpolation=interpolation)
181181

182182

183+
class Rank:
    """
    Benchmark ``rank`` on a rolling window over a Series or DataFrame.

    The parameter grid covers the container type, window size, input dtype,
    and the ``pct``, ``ascending`` and ``method`` arguments of ``rank``.
    The parameter called ``percentile`` is the boolean passed as ``pct``.
    """

    params = (
        ["DataFrame", "Series"],
        [10, 1000],
        ["int", "float"],
        [True, False],
        [True, False],
        ["min", "max", "average"],
    )
    param_names = [
        "constructor",
        "window",
        "dtype",
        "percentile",
        "ascending",
        "method",
    ]

    def setup(self, constructor, window, dtype, percentile, ascending, method):
        """Build the rolling object once so only ``rank`` itself is timed."""
        N = 10 ** 5
        # Scale before casting: samples drawn from [0, 1) all truncate to 0
        # under ``astype("int")``, which would reduce the integer case to a
        # single tie group spanning every window.
        arr = (100 * np.random.random(N)).astype(dtype)
        self.roll = getattr(pd, constructor)(arr).rolling(window)

    def time_rank(self, constructor, window, dtype, percentile, ascending, method):
        """Time one ``rank`` call with the parametrized keyword arguments."""
        self.roll.rank(pct=percentile, ascending=ascending, method=method)
208+
209+
183210
class PeakMemFixedWindowMinMax:
184211

185212
params = ["min", "max"]

doc/source/reference/window.rst

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Rolling window functions
3535
Rolling.aggregate
3636
Rolling.quantile
3737
Rolling.sem
38+
Rolling.rank
3839

3940
.. _api.functions_window:
4041

@@ -75,6 +76,7 @@ Expanding window functions
7576
Expanding.aggregate
7677
Expanding.quantile
7778
Expanding.sem
79+
Expanding.rank
7880

7981
.. _api.functions_ewm:
8082

doc/source/whatsnew/v1.4.0.rst

+15
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,21 @@ Multithreaded CSV reading with a new CSV Engine based on pyarrow
9494
:func:`pandas.read_csv` now accepts ``engine="pyarrow"`` (requires at least ``pyarrow`` 0.17.0) as an argument, allowing for faster csv parsing on multicore machines
9595
with pyarrow installed. See the :doc:`I/O docs </user_guide/io>` for more info. (:issue:`23697`)
9696

97+
.. _whatsnew_140.enhancements.window_rank:
98+
99+
Rank function for rolling and expanding windows
100+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
101+
102+
Added ``rank`` function to :class:`Rolling` and :class:`Expanding`. The new function supports the ``method``, ``ascending``, and ``pct`` flags of :meth:`DataFrame.rank`. The ``method`` argument supports ``min``, ``max``, and ``average`` ranking methods.
103+
Example:
104+
105+
.. ipython:: python
106+
107+
s = pd.Series([1, 4, 2, 3, 5, 3])
108+
s.rolling(3).rank()
109+
110+
s.rolling(3).rank(method="max")
111+
97112
.. _whatsnew_140.enhancements.other:
98113

99114
Other enhancements

pandas/_libs/algos.pxd

+8
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,11 @@ from pandas._libs.util cimport numeric
22

33

44
cdef numeric kth_smallest_c(numeric* arr, Py_ssize_t k, Py_ssize_t n) nogil
5+
6+
cdef enum TiebreakEnumType:
7+
# Tie-breaking strategies for ranking. Declared in this .pxd so that other
# Cython modules can cimport the enum instead of redefining it.
# Member order determines the integer values, so do not reorder.
TIEBREAK_AVERAGE  # tied values share the mean of their ranks
8+
TIEBREAK_MIN,  # tied values share the lowest rank of the group
9+
TIEBREAK_MAX  # tied values share the highest rank of the group
10+
TIEBREAK_FIRST
11+
TIEBREAK_FIRST_DESCENDING
12+
TIEBREAK_DENSE

pandas/_libs/algos.pyx

-7
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,6 @@ cdef:
6666
float64_t NaN = <float64_t>np.NaN
6767
int64_t NPY_NAT = get_nat()
6868

69-
cdef enum TiebreakEnumType:
70-
TIEBREAK_AVERAGE
71-
TIEBREAK_MIN,
72-
TIEBREAK_MAX
73-
TIEBREAK_FIRST
74-
TIEBREAK_FIRST_DESCENDING
75-
TIEBREAK_DENSE
7669

7770
tiebreakers = {
7871
"average": TIEBREAK_AVERAGE,

pandas/_libs/src/skiplist.h

+23-2
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,30 @@ PANDAS_INLINE double skiplist_get(skiplist_t *skp, int i, int *ret) {
180180
return node->value;
181181
}
182182

183+
// Returns the lowest rank of all elements with value `value`, as opposed to the
184+
// highest rank returned by `skiplist_insert`.
185+
PANDAS_INLINE int skiplist_min_rank(skiplist_t *skp, double value) {
186+
node_t *node;
187+
int level, rank = 0;
188+
189+
node = skp->head;
190+
for (level = skp->maxlevels - 1; level >= 0; --level) {
191+
while (_node_cmp(node->next[level], value) > 0) {
192+
rank += node->width[level];
193+
node = node->next[level];
194+
}
195+
}
196+
197+
return rank + 1;
198+
}
199+
200+
// Returns the rank of the inserted element. When there are duplicates,
201+
// `rank` is the highest of the group, i.e. the 'max' method of
202+
// https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rank.html
183203
PANDAS_INLINE int skiplist_insert(skiplist_t *skp, double value) {
184204
node_t *node, *prevnode, *newnode, *next_at_level;
185205
int *steps_at_level;
186-
int size, steps, level;
206+
int size, steps, level, rank = 0;
187207
node_t **chain;
188208

189209
chain = skp->tmp_chain;
@@ -197,6 +217,7 @@ PANDAS_INLINE int skiplist_insert(skiplist_t *skp, double value) {
197217
next_at_level = node->next[level];
198218
while (_node_cmp(next_at_level, value) >= 0) {
199219
steps_at_level[level] += node->width[level];
220+
rank += node->width[level];
200221
node = next_at_level;
201222
next_at_level = node->next[level];
202223
}
@@ -230,7 +251,7 @@ PANDAS_INLINE int skiplist_insert(skiplist_t *skp, double value) {
230251

231252
++(skp->size);
232253

233-
return 1;
254+
return rank + 1;
234255
}
235256

236257
PANDAS_INLINE int skiplist_remove(skiplist_t *skp, double value) {

pandas/_libs/window/aggregations.pyi

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ from typing import (
66

77
import numpy as np
88

9+
from pandas._typing import WindowingRankType
10+
911
def roll_sum(
1012
values: np.ndarray, # const float64_t[:]
1113
start: np.ndarray, # np.ndarray[np.int64]
@@ -63,6 +65,15 @@ def roll_quantile(
6365
quantile: float, # float64_t
6466
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
6567
) -> np.ndarray: ... # np.ndarray[float]
68+
def roll_rank(
69+
values: np.ndarray,  # const float64_t[:]
70+
start: np.ndarray,  # np.ndarray[np.int64]
71+
end: np.ndarray,  # np.ndarray[np.int64]
72+
minp: int,  # int64_t
73+
percentile: bool,  # bint; when True the rank is divided by the window's observation count
74+
method: WindowingRankType,  # str; one of "average", "min", "max"
75+
ascending: bool,  # bint
76+
) -> np.ndarray: ... # np.ndarray[float]
6677
def roll_apply(
6778
obj: object,
6879
start: np.ndarray, # np.ndarray[np.int64]

pandas/_libs/window/aggregations.pyx

+122-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import cython
55
from libc.math cimport round
66
from libcpp.deque cimport deque
77

8+
from pandas._libs.algos cimport TiebreakEnumType
9+
810
import numpy as np
911

1012
cimport numpy as cnp
@@ -50,6 +52,8 @@ cdef extern from "../src/skiplist.h":
5052
double skiplist_get(skiplist_t*, int, int*) nogil
5153
int skiplist_insert(skiplist_t*, double) nogil
5254
int skiplist_remove(skiplist_t*, double) nogil
55+
int skiplist_rank(skiplist_t*, double) nogil
56+
int skiplist_min_rank(skiplist_t*, double) nogil
5357

5458
cdef:
5559
float32_t MINfloat32 = np.NINF
@@ -795,7 +799,7 @@ def roll_median_c(const float64_t[:] values, ndarray[int64_t] start,
795799
val = values[j]
796800
if notnan(val):
797801
nobs += 1
798-
err = skiplist_insert(sl, val) != 1
802+
err = skiplist_insert(sl, val) == -1
799803
if err:
800804
break
801805

@@ -806,7 +810,7 @@ def roll_median_c(const float64_t[:] values, ndarray[int64_t] start,
806810
val = values[j]
807811
if notnan(val):
808812
nobs += 1
809-
err = skiplist_insert(sl, val) != 1
813+
err = skiplist_insert(sl, val) == -1
810814
if err:
811815
break
812816

@@ -1139,6 +1143,122 @@ def roll_quantile(const float64_t[:] values, ndarray[int64_t] start,
11391143
return output
11401144

11411145

1146+
# Ranking methods supported by roll_rank, keyed by the user-facing name.
rolling_rank_tiebreakers = {
    "average": TiebreakEnumType.TIEBREAK_AVERAGE,
    "min": TiebreakEnumType.TIEBREAK_MIN,
    "max": TiebreakEnumType.TIEBREAK_MAX,
}


def roll_rank(const float64_t[:] values, ndarray[int64_t] start,
              ndarray[int64_t] end, int64_t minp, bint percentile,
              str method, bint ascending) -> np.ndarray:
    """
    Rank the value most recently added to each window among the window's
    non-NaN values.

    O(N log(window)) implementation using skip list

    derived from roll_quantile

    Parameters
    ----------
    values : const float64_t[:]
        Input data.
    start, end : ndarray[int64_t]
        Window bounds; window ``i`` covers ``values[start[i]:end[i]]``.
    minp : int64_t
        Minimum number of non-NaN observations a window needs; windows with
        fewer produce NaN.
    percentile : bint
        If True, the rank is divided by the number of non-NaN observations
        in the window.
    method : str
        Tie-breaking method, one of the keys of ``rolling_rank_tiebreakers``
        ("average", "min", "max").
    ascending : bint
        If False, values are negated before insertion so that larger values
        receive lower ranks.

    Returns
    -------
    np.ndarray
        float64 array of length ``len(values)``.

    Raises
    ------
    ValueError
        If ``method`` is not a supported tie-breaking method.
    MemoryError
        If the skip list cannot be allocated or an insert fails.
    """
    cdef:
        Py_ssize_t i, j, s, e, N = len(values)
        float64_t rank_min = 0, rank = 0
        int64_t nobs = 0, win
        float64_t val
        skiplist_t *skiplist
        float64_t[::1] output
        TiebreakEnumType rank_type

    try:
        rank_type = rolling_rank_tiebreakers[method]
    except KeyError:
        raise ValueError(f"Method '{method}' is not supported")

    is_monotonic_increasing_bounds = is_monotonic_increasing_start_end_bounds(
        start, end
    )
    # we use the Fixed/Variable Indexer here as the
    # actual skiplist ops outweigh any window computation costs
    output = np.empty(N, dtype=np.float64)

    win = (end - start).max()
    if win == 0:
        output[:] = NaN
        return np.asarray(output)
    skiplist = skiplist_init(<int>win)
    if skiplist == NULL:
        raise MemoryError("skiplist_init failed")

    # try/finally guarantees the skip list is released even when a
    # MemoryError is raised from inside the nogil loop.
    try:
        with nogil:
            for i in range(N):
                s = start[i]
                e = end[i]

                if i == 0 or not is_monotonic_increasing_bounds:
                    if not is_monotonic_increasing_bounds:
                        # Bounds are not monotonic, so the previous window
                        # cannot be updated incrementally: start over.
                        nobs = 0
                        skiplist_destroy(skiplist)
                        skiplist = skiplist_init(<int>win)
                        if skiplist == NULL:
                            raise MemoryError("skiplist_init failed")

                    # setup
                    for j in range(s, e):
                        val = values[j] if ascending else -values[j]
                        if notnan(val):
                            nobs += 1
                            # skiplist_insert returns the highest rank among
                            # duplicates of `val`, i.e. the "max" method.
                            rank = skiplist_insert(skiplist, val)
                            if rank == -1:
                                raise MemoryError("skiplist_insert failed")
                            if rank_type == TiebreakEnumType.TIEBREAK_AVERAGE:
                                # Duplicates of `val` occupy the consecutive
                                # ranks rank_min..rank, so their average is
                                # the midpoint of the two ends.
                                rank_min = skiplist_min_rank(skiplist, val)
                                rank = (rank_min + rank) / 2
                            elif rank_type == TiebreakEnumType.TIEBREAK_MIN:
                                rank = skiplist_min_rank(skiplist, val)
                        else:
                            rank = NaN

                else:
                    # calculate deletes
                    for j in range(start[i - 1], s):
                        val = values[j] if ascending else -values[j]
                        if notnan(val):
                            skiplist_remove(skiplist, val)
                            nobs -= 1

                    # calculate adds
                    # NOTE(review): if this range is empty, `rank` keeps the
                    # value computed for the previous window -- confirm this
                    # is intended for bounds whose `end` does not advance.
                    for j in range(end[i - 1], e):
                        val = values[j] if ascending else -values[j]
                        if notnan(val):
                            nobs += 1
                            rank = skiplist_insert(skiplist, val)
                            if rank == -1:
                                raise MemoryError("skiplist_insert failed")
                            if rank_type == TiebreakEnumType.TIEBREAK_AVERAGE:
                                rank_min = skiplist_min_rank(skiplist, val)
                                rank = (rank_min + rank) / 2
                            elif rank_type == TiebreakEnumType.TIEBREAK_MIN:
                                rank = skiplist_min_rank(skiplist, val)
                        else:
                            rank = NaN
                if nobs >= minp:
                    output[i] = rank / nobs if percentile else rank
                else:
                    output[i] = NaN
    finally:
        # skiplist is NULL only if re-initialisation failed above.
        if skiplist != NULL:
            skiplist_destroy(skiplist)

    return np.asarray(output)
1260+
1261+
11421262
def roll_apply(object obj,
11431263
ndarray[int64_t] start, ndarray[int64_t] end,
11441264
int64_t minp,

pandas/_typing.py

+3
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,6 @@
219219
PositionalIndexer = Union[ScalarIndexer, SequenceIndexer]
220220
PositionalIndexerTuple = Tuple[PositionalIndexer, PositionalIndexer]
221221
PositionalIndexer2D = Union[PositionalIndexer, PositionalIndexerTuple]
222+
223+
# Windowing rank methods: the tie-breaking ``method`` names accepted by the
# windowed ``rank`` implementation (see the ``roll_rank`` stub).
224+
WindowingRankType = Literal["average", "min", "max"]

0 commit comments

Comments
 (0)