Skip to content

Commit ace46fd

Browse files
authored
PERF: expanding/rolling.min/max with engine='numba' (#45170)
1 parent 0ccb0ef commit ace46fd

File tree

6 files changed

+101
-23
lines changed

6 files changed

+101
-23
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ Performance improvements
665665
- :meth:`SparseArray.min` and :meth:`SparseArray.max` no longer require converting to a dense array (:issue:`43526`)
666666
- Indexing into a :class:`SparseArray` with a ``slice`` with ``step=1`` no longer requires converting to a dense array (:issue:`43777`)
667667
- Performance improvement in :meth:`SparseArray.take` with ``allow_fill=False`` (:issue:`43654`)
668-
- Performance improvement in :meth:`.Rolling.mean`, :meth:`.Expanding.mean`, :meth:`.Rolling.sum`, :meth:`.Expanding.sum` with ``engine="numba"`` (:issue:`43612`, :issue:`44176`)
668+
- Performance improvement in :meth:`.Rolling.mean`, :meth:`.Expanding.mean`, :meth:`.Rolling.sum`, :meth:`.Expanding.sum`, :meth:`.Rolling.max`, :meth:`.Expanding.max`, :meth:`.Rolling.min` and :meth:`.Expanding.min` with ``engine="numba"`` (:issue:`43612`, :issue:`44176`, :issue:`45170`)
669669
- Improved performance of :meth:`pandas.read_csv` with ``memory_map=True`` when file encoding is UTF-8 (:issue:`43787`)
670670
- Performance improvement in :meth:`RangeIndex.sort_values` overriding :meth:`Index.sort_values` (:issue:`43666`)
671671
- Performance improvement in :meth:`RangeIndex.insert` (:issue:`43988`)

pandas/_libs/window/aggregations.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ cdef _roll_min_max(ndarray[numeric_t] values,
960960
with nogil:
961961

962962
# This is using a modified version of the C++ code in this
963-
# SO post: http://bit.ly/2nOoHlY
963+
# SO post: https://stackoverflow.com/a/12239580
964964
# The original impl didn't deal with variable window sizes
965965
# So the code was optimized for that
966966

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pandas.core._numba.kernels.mean_ import sliding_mean
2+
from pandas.core._numba.kernels.min_max_ import sliding_min_max
23
from pandas.core._numba.kernels.sum_ import sliding_sum
34
from pandas.core._numba.kernels.var_ import sliding_var
45

5-
__all__ = ["sliding_mean", "sliding_sum", "sliding_var"]
6+
__all__ = ["sliding_mean", "sliding_sum", "sliding_var", "sliding_min_max"]
+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Numba 1D min/max kernels that can be shared by
3+
* DataFrame / Series
4+
* groupby
5+
* rolling / expanding
6+
7+
Mirrors pandas/_libs/window/aggregations.pyx
8+
"""
9+
from __future__ import annotations
10+
11+
import numba
12+
import numpy as np
13+
14+
15+
@numba.jit(nopython=True, nogil=True, parallel=False)
def sliding_min_max(
    values: np.ndarray,
    start: np.ndarray,
    end: np.ndarray,
    min_periods: int,
    is_max: bool,
) -> np.ndarray:
    """
    Compute the minimum or maximum of ``values`` over a sequence of windows.

    Parameters
    ----------
    values : np.ndarray
        Input values. NaN entries are not counted as observations.
    start : np.ndarray
        ``start[i]`` is the first index (inclusive) of window ``i``.
    end : np.ndarray
        ``end[i]`` is one past the last index (exclusive) of window ``i``.
    min_periods : int
        Number of non-NaN observations a window needs to produce a value.
    is_max : bool
        True to compute the maximum, False to compute the minimum.

    Returns
    -------
    np.ndarray
        float64 array with one entry per window. An entry is NaN when the
        window is empty or holds fewer than ``min_periods`` non-NaN values.

    Notes
    -----
    New elements for window ``i`` are read from ``end[i - 1]`` (``start[0]``
    for the first window) up to ``end[i]``, so the windows are expected to
    advance from left to right.
    """
    n_windows = len(start)
    valid_count = 0
    result = np.empty(n_windows, dtype=np.float64)
    # Plain lists stand in for deques until numba supports them
    # https://github.com/numba/numba/issues/7417
    # ``candidates`` holds indices whose values can still be a window
    # extremum; the current extremum is always at the front.
    candidates: list = []
    # ``in_window`` holds every index read so far that has not yet been
    # evicted; it is used to keep ``valid_count`` in sync.
    in_window: list = []
    for i in range(n_windows):
        left = start[i]
        right = end[i]
        # Only elements not already consumed by the previous window are read.
        first_new = left if i == 0 else end[i - 1]

        for k in range(first_new, right):
            candidate = values[k]
            if np.isnan(candidate):
                # A NaN must never beat a real value, so compare it as the
                # worst possible candidate for the requested extremum.
                candidate = -np.inf if is_max else np.inf
            else:
                valid_count += 1

            # Drop trailing entries that the new element supersedes, along
            # with any trailing NaN entries.
            while candidates:
                tail = values[candidates[-1]]
                if is_max:
                    superseded = candidate >= tail
                else:
                    superseded = candidate <= tail
                if not (superseded or tail != tail):
                    break
                candidates.pop()
            candidates.append(k)
            in_window.append(k)

        # Evict everything that lies to the left of the current window.
        last_outside = left - 1
        while candidates and candidates[0] <= last_outside:
            candidates.pop(0)
        while in_window and in_window[0] <= last_outside:
            if not np.isnan(values[in_window[0]]):
                valid_count -= 1
            in_window.pop(0)

        # The front of ``candidates`` indexes the extremum of this window.
        if candidates and right - left > 0 and valid_count >= min_periods:
            result[i] = values[candidates[0]]
        else:
            result[i] = np.nan

    return result

pandas/core/window/rolling.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -1385,15 +1385,18 @@ def max(
13851385
if maybe_use_numba(engine):
13861386
if self.method == "table":
13871387
func = generate_manual_numpy_nan_agg_with_axis(np.nanmax)
1388+
return self.apply(
1389+
func,
1390+
raw=True,
1391+
engine=engine,
1392+
engine_kwargs=engine_kwargs,
1393+
)
13881394
else:
1389-
func = np.nanmax
1395+
from pandas.core._numba.kernels import sliding_min_max
13901396

1391-
return self.apply(
1392-
func,
1393-
raw=True,
1394-
engine=engine,
1395-
engine_kwargs=engine_kwargs,
1396-
)
1397+
return self._numba_apply(
1398+
sliding_min_max, "rolling_max", engine_kwargs, True
1399+
)
13971400
window_func = window_aggregations.roll_max
13981401
return self._apply(window_func, name="max", **kwargs)
13991402

@@ -1408,15 +1411,18 @@ def min(
14081411
if maybe_use_numba(engine):
14091412
if self.method == "table":
14101413
func = generate_manual_numpy_nan_agg_with_axis(np.nanmin)
1414+
return self.apply(
1415+
func,
1416+
raw=True,
1417+
engine=engine,
1418+
engine_kwargs=engine_kwargs,
1419+
)
14111420
else:
1412-
func = np.nanmin
1421+
from pandas.core._numba.kernels import sliding_min_max
14131422

1414-
return self.apply(
1415-
func,
1416-
raw=True,
1417-
engine=engine,
1418-
engine_kwargs=engine_kwargs,
1419-
)
1423+
return self._numba_apply(
1424+
sliding_min_max, "rolling_min", engine_kwargs, False
1425+
)
14201426
window_func = window_aggregations.roll_min
14211427
return self._apply(window_func, name="min", **kwargs)
14221428

pandas/tests/window/test_numba.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_numba_vs_cython_rolling_methods(
6161
expected = getattr(roll, method)(engine="cython", **kwargs)
6262

6363
# Check the cache
64-
if method not in ("mean", "sum", "var", "std"):
64+
if method not in ("mean", "sum", "var", "std", "max", "min"):
6565
assert (
6666
getattr(np, f"nan{method}"),
6767
"Rolling_apply_single",
@@ -88,7 +88,7 @@ def test_numba_vs_cython_expanding_methods(
8888
expected = getattr(expand, method)(engine="cython", **kwargs)
8989

9090
# Check the cache
91-
if method not in ("mean", "sum", "var", "std"):
91+
if method not in ("mean", "sum", "var", "std", "max", "min"):
9292
assert (
9393
getattr(np, f"nan{method}"),
9494
"Expanding_apply_single",
@@ -150,15 +150,16 @@ def test_dont_cache_args(
150150
def add(values, x):
151151
return np.sum(values) + x
152152

153+
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
153154
df = DataFrame({"value": [0, 0, 0]})
154-
result = getattr(df, window)(**window_kwargs).apply(
155-
add, raw=True, engine="numba", args=(1,)
155+
result = getattr(df, window)(method=method, **window_kwargs).apply(
156+
add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(1,)
156157
)
157158
expected = DataFrame({"value": [1.0, 1.0, 1.0]})
158159
tm.assert_frame_equal(result, expected)
159160

160-
result = getattr(df, window)(**window_kwargs).apply(
161-
add, raw=True, engine="numba", args=(2,)
161+
result = getattr(df, window)(method=method, **window_kwargs).apply(
162+
add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(2,)
162163
)
163164
expected = DataFrame({"value": [2.0, 2.0, 2.0]})
164165
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)