diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 3d2273b6d7324..406b27dd37ea5 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np import pandas as pd @@ -44,29 +46,56 @@ def time_rolling(self, constructor, window, dtype, function, raw): self.roll.apply(function, raw=raw) -class Engine: +class NumbaEngine: params = ( ["DataFrame", "Series"], ["int", "float"], [np.sum, lambda x: np.sum(x) + 5], - ["cython", "numba"], ["sum", "max", "min", "median", "mean"], + [True, False], + [None, 100], ) - param_names = ["constructor", "dtype", "function", "engine", "method"] + param_names = ["constructor", "dtype", "function", "method", "parallel", "cols"] - def setup(self, constructor, dtype, function, engine, method): + def setup(self, constructor, dtype, function, method, parallel, cols): N = 10 ** 3 - arr = (100 * np.random.random(N)).astype(dtype) - self.data = getattr(pd, constructor)(arr) - - def time_rolling_apply(self, constructor, dtype, function, engine, method): - self.data.rolling(10).apply(function, raw=True, engine=engine) - - def time_expanding_apply(self, constructor, dtype, function, engine, method): - self.data.expanding().apply(function, raw=True, engine=engine) - - def time_rolling_methods(self, constructor, dtype, function, engine, method): - getattr(self.data.rolling(10), method)(engine=engine) + shape = (N, cols) if cols is not None and constructor != "Series" else N + arr = (100 * np.random.random(shape)).astype(dtype) + data = getattr(pd, constructor)(arr) + + # Warm the cache + with warnings.catch_warnings(record=True): + # Catch parallel=True not being applicable e.g. 1D data + self.roll = data.rolling(10) + self.roll.apply( + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} + ) + getattr(self.roll, method)( + engine="numba", engine_kwargs={"parallel": parallel} + ) + + self.expand = data.expanding() + self.expand.apply( + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} + ) + + def time_rolling_apply(self, constructor, dtype, function, method, parallel, col): + with warnings.catch_warnings(record=True): + self.roll.apply( + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} + ) + + def time_expanding_apply(self, constructor, dtype, function, method, parallel, col): + with warnings.catch_warnings(record=True): + self.expand.apply( + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} + ) + + def time_rolling_methods(self, constructor, dtype, function, method, parallel, col): + with warnings.catch_warnings(record=True): + getattr(self.roll, method)( + engine="numba", engine_kwargs={"parallel": parallel} + ) class ExpandingMethods: