forked from pandas-dev/pandas
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_online.py
112 lines (96 loc) · 3.56 KB
/
test_online.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import pytest
from pandas.compat import is_platform_arm
from pandas import (
DataFrame,
Series,
)
import pandas._testing as tm
from pandas.util.version import Version
# Skip the whole module when numba is not installed; every test here needs it.
numba = pytest.importorskip("numba")

# Module-wide marks applied to every test collected from this file:
#   * single_cpu -- run these tests on a single CPU worker.
#   * skipif     -- skip on ARM when numba 0.61 is installed (see `reason`).
# NOTE: `Version(...) == Version("0.61")` matches 0.61 / 0.61.0 only, not later
# 0.61.x patch releases.
pytestmark = [
    pytest.mark.single_cpu,
    pytest.mark.skipif(
        Version(numba.__version__) == Version("0.61") and is_platform_arm(),
        reason=f"Segfaults on ARM platforms with numba {numba.__version__}",
    ),
]
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.filterwarnings("ignore")
class TestEWM:
    """Tests for online exponentially weighted moving windows.

    Each test builds an object via ``obj.ewm(...).online(...)``. Its ``mean``
    can first be computed on an initial chunk and then continued by passing
    further chunks through ``update=``.
    """

    def test_invalid_update(self):
        """Passing ``update`` before any initial ``mean()`` call raises ValueError."""
        df = DataFrame({"a": range(5), "b": range(5)})
        online_ewm = df.head(2).ewm(0.5).online()
        with pytest.raises(
            ValueError,
            match="Must call mean with update=None first before passing update",
        ):
            online_ewm.mean(update=df.head(1))

    @pytest.mark.slow
    @pytest.mark.parametrize(
        "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
    )
    def test_online_vs_non_online_mean(
        self, obj, nogil, parallel, nopython, adjust, ignore_na
    ):
        """The online mean over head(2) then tail(3) matches the batch EWM mean.

        ``nogil``, ``parallel``, ``nopython``, ``adjust`` and ``ignore_na`` are
        fixtures defined outside this file.
        """
        expected = obj.ewm(0.5, adjust=adjust, ignore_na=ignore_na).mean()
        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        online_ewm = (
            obj.head(2)
            .ewm(0.5, adjust=adjust, ignore_na=ignore_na)
            .online(engine_kwargs=engine_kwargs)
        )

        # Run the same sequence twice: the second pass checks that reset()
        # brings the online object back to its initial state.
        for _ in range(2):
            result = online_ewm.mean()
            tm.assert_equal(result, expected.head(2))

            result = online_ewm.mean(update=obj.tail(3))
            tm.assert_equal(result, expected.tail(3))

            online_ewm.reset()

    @pytest.mark.xfail(raises=NotImplementedError)
    @pytest.mark.parametrize(
        "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
    )
    def test_update_times_mean(
        self, obj, nogil, parallel, nopython, adjust, ignore_na, halflife_with_times
    ):
        """Online mean with ``times`` / ``update_times`` matches the batch mean.

        Expected to fail with NotImplementedError (see the xfail mark),
        presumably because time-based decay is not yet supported by online
        windows -- NOTE(review): confirm which call raises.
        """
        times = Series(
            np.array(
                ["2020-01-01", "2020-01-05", "2020-01-07", "2020-01-17", "2020-01-21"],
                dtype="datetime64[ns]",
            )
        )
        expected = obj.ewm(
            0.5,
            adjust=adjust,
            ignore_na=ignore_na,
            times=times,
            halflife=halflife_with_times,
        ).mean()

        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        online_ewm = (
            obj.head(2)
            .ewm(
                0.5,
                adjust=adjust,
                ignore_na=ignore_na,
                times=times.head(2),
                halflife=halflife_with_times,
            )
            .online(engine_kwargs=engine_kwargs)
        )

        # Run the same sequence twice: the second pass checks that reset()
        # brings the online object back to its initial state.
        for _ in range(2):
            result = online_ewm.mean()
            tm.assert_equal(result, expected.head(2))

            result = online_ewm.mean(update=obj.tail(3), update_times=times.tail(3))
            tm.assert_equal(result, expected.tail(3))

            online_ewm.reset()

    @pytest.mark.parametrize("method", ["aggregate", "std", "corr", "cov", "var"])
    def test_ewm_notimplementederror_raises(self, method):
        """Each unsupported online method raises NotImplementedError naming itself."""
        ser = Series(range(10))
        kwargs = {}
        if method == "aggregate":
            # aggregate needs a function argument to be callable at all.
            kwargs["func"] = lambda x: x
        # Anchor the pattern on the method name and escape the literal period.
        # pytest.raises uses re.search, so the previous ".* is not implemented."
        # accepted any " is not implemented" message, from any method.
        msg = rf"{method} is not implemented\."
        with pytest.raises(NotImplementedError, match=msg):
            getattr(ser.ewm(1).online(), method)(**kwargs)