|
| 1 | +import warnings |
| 2 | + |
1 | 3 | import numpy as np
|
2 | 4 |
|
3 | 5 | import pandas as pd
|
@@ -44,29 +46,56 @@ def time_rolling(self, constructor, window, dtype, function, raw):
|
44 | 46 | self.roll.apply(function, raw=raw)
|
45 | 47 |
|
46 | 48 |
|
47 |
| -class Engine: |
| 49 | +class NumbaEngine: |
48 | 50 | params = (
|
49 | 51 | ["DataFrame", "Series"],
|
50 | 52 | ["int", "float"],
|
51 | 53 | [np.sum, lambda x: np.sum(x) + 5],
|
52 |
| - ["cython", "numba"], |
53 | 54 | ["sum", "max", "min", "median", "mean"],
|
| 55 | + [True, False], |
| 56 | + [None, 100], |
54 | 57 | )
|
55 |
| - param_names = ["constructor", "dtype", "function", "engine", "method"] |
| 58 | + param_names = ["constructor", "dtype", "function", "method", "parallel", "cols"] |
56 | 59 |
|
57 |
| - def setup(self, constructor, dtype, function, engine, method): |
| 60 | + def setup(self, constructor, dtype, function, method, parallel, cols): |
58 | 61 | N = 10 ** 3
|
59 |
| - arr = (100 * np.random.random(N)).astype(dtype) |
60 |
| - self.data = getattr(pd, constructor)(arr) |
61 |
| - |
62 |
| - def time_rolling_apply(self, constructor, dtype, function, engine, method): |
63 |
| - self.data.rolling(10).apply(function, raw=True, engine=engine) |
64 |
| - |
65 |
| - def time_expanding_apply(self, constructor, dtype, function, engine, method): |
66 |
| - self.data.expanding().apply(function, raw=True, engine=engine) |
67 |
| - |
68 |
| - def time_rolling_methods(self, constructor, dtype, function, engine, method): |
69 |
| - getattr(self.data.rolling(10), method)(engine=engine) |
| 62 | + shape = (N, cols) if cols is not None and constructor != "Series" else N |
| 63 | + arr = (100 * np.random.random(shape)).astype(dtype) |
| 64 | + data = getattr(pd, constructor)(arr) |
| 65 | + |
| 66 | + # Warm the cache |
| 67 | + with warnings.catch_warnings(record=True): |
| 68 | + # Catch parallel=True not being applicable e.g. 1D data |
| 69 | + self.roll = data.rolling(10) |
| 70 | + self.roll.apply( |
| 71 | + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} |
| 72 | + ) |
| 73 | + getattr(self.roll, method)( |
| 74 | + engine="numba", engine_kwargs={"parallel": parallel} |
| 75 | + ) |
| 76 | + |
| 77 | + self.expand = data.expanding() |
| 78 | + self.expand.apply( |
| 79 | + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} |
| 80 | + ) |
| 81 | + |
| 82 | + def time_rolling_apply(self, constructor, dtype, function, method, parallel, col): |
| 83 | + with warnings.catch_warnings(record=True): |
| 84 | + self.roll.apply( |
| 85 | + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} |
| 86 | + ) |
| 87 | + |
| 88 | + def time_expanding_apply(self, constructor, dtype, function, method, parallel, col): |
| 89 | + with warnings.catch_warnings(record=True): |
| 90 | + self.expand.apply( |
| 91 | + function, raw=True, engine="numba", engine_kwargs={"parallel": parallel} |
| 92 | + ) |
| 93 | + |
| 94 | + def time_rolling_methods(self, constructor, dtype, function, method, parallel, col): |
| 95 | + with warnings.catch_warnings(record=True): |
| 96 | + getattr(self.roll, method)( |
| 97 | + engine="numba", engine_kwargs={"parallel": parallel} |
| 98 | + ) |
70 | 99 |
|
71 | 100 |
|
72 | 101 | class ExpandingMethods:
|
|
0 commit comments