Skip to content

Commit 66865fc

Browse files
authored
BENCH: Add more numba rolling benchmarks (#44283)
1 parent f97915f commit 66865fc

File tree

1 file changed

+44
-15
lines changed

1 file changed

+44
-15
lines changed

asv_bench/benchmarks/rolling.py

+44-15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24

35
import pandas as pd
@@ -44,29 +46,56 @@ def time_rolling(self, constructor, window, dtype, function, raw):
4446
self.roll.apply(function, raw=raw)
4547

4648

47-
class Engine:
49+
class NumbaEngine:
4850
params = (
4951
["DataFrame", "Series"],
5052
["int", "float"],
5153
[np.sum, lambda x: np.sum(x) + 5],
52-
["cython", "numba"],
5354
["sum", "max", "min", "median", "mean"],
55+
[True, False],
56+
[None, 100],
5457
)
55-
param_names = ["constructor", "dtype", "function", "engine", "method"]
58+
param_names = ["constructor", "dtype", "function", "method", "parallel", "cols"]
5659

57-
def setup(self, constructor, dtype, function, engine, method):
60+
def setup(self, constructor, dtype, function, method, parallel, cols):
5861
N = 10 ** 3
59-
arr = (100 * np.random.random(N)).astype(dtype)
60-
self.data = getattr(pd, constructor)(arr)
61-
62-
def time_rolling_apply(self, constructor, dtype, function, engine, method):
63-
self.data.rolling(10).apply(function, raw=True, engine=engine)
64-
65-
def time_expanding_apply(self, constructor, dtype, function, engine, method):
66-
self.data.expanding().apply(function, raw=True, engine=engine)
67-
68-
def time_rolling_methods(self, constructor, dtype, function, engine, method):
69-
getattr(self.data.rolling(10), method)(engine=engine)
62+
shape = (N, cols) if cols is not None and constructor != "Series" else N
63+
arr = (100 * np.random.random(shape)).astype(dtype)
64+
data = getattr(pd, constructor)(arr)
65+
66+
# Warm the cache
67+
with warnings.catch_warnings(record=True):
68+
# Catch parallel=True not being applicable e.g. 1D data
69+
self.roll = data.rolling(10)
70+
self.roll.apply(
71+
function, raw=True, engine="numba", engine_kwargs={"parallel": parallel}
72+
)
73+
getattr(self.roll, method)(
74+
engine="numba", engine_kwargs={"parallel": parallel}
75+
)
76+
77+
self.expand = data.expanding()
78+
self.expand.apply(
79+
function, raw=True, engine="numba", engine_kwargs={"parallel": parallel}
80+
)
81+
82+
def time_rolling_apply(self, constructor, dtype, function, method, parallel, col):
83+
with warnings.catch_warnings(record=True):
84+
self.roll.apply(
85+
function, raw=True, engine="numba", engine_kwargs={"parallel": parallel}
86+
)
87+
88+
def time_expanding_apply(self, constructor, dtype, function, method, parallel, col):
89+
with warnings.catch_warnings(record=True):
90+
self.expand.apply(
91+
function, raw=True, engine="numba", engine_kwargs={"parallel": parallel}
92+
)
93+
94+
def time_rolling_methods(self, constructor, dtype, function, method, parallel, col):
95+
with warnings.catch_warnings(record=True):
96+
getattr(self.roll, method)(
97+
engine="numba", engine_kwargs={"parallel": parallel}
98+
)
7099

71100

72101
class ExpandingMethods:

0 commit comments

Comments
 (0)