Skip to content

Commit 3db38fb

Browse files
authored
CLN: Use more pytest idioms in test_momemts_ewm.py (#36801)
1 parent dc6bdab commit 3db38fb

File tree

1 file changed

+141
-53
lines changed

1 file changed

+141
-53
lines changed

pandas/tests/window/moments/test_moments_ewm.py

+141-53
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,19 @@
77
import pandas._testing as tm
88

99

10-
def check_ew(name=None, preserve_nan=False, series=None, frame=None, nan_locs=None):
10+
@pytest.mark.parametrize("name", ["var", "vol", "mean"])
11+
def test_ewma_series(series, name):
1112
series_result = getattr(series.ewm(com=10), name)()
1213
assert isinstance(series_result, Series)
1314

14-
frame_result = getattr(frame.ewm(com=10), name)()
15-
assert type(frame_result) == DataFrame
16-
17-
result = getattr(series.ewm(com=10), name)()
18-
if preserve_nan:
19-
assert result[nan_locs].isna().all()
2015

16+
@pytest.mark.parametrize("name", ["var", "vol", "mean"])
17+
def test_ewma_frame(frame, name):
18+
frame_result = getattr(frame.ewm(com=10), name)()
19+
assert isinstance(frame_result, DataFrame)
2120

22-
def test_ewma(series, frame, nan_locs):
23-
check_ew(name="mean", frame=frame, series=series, nan_locs=nan_locs)
2421

22+
def test_ewma_adjust():
2523
vals = pd.Series(np.zeros(1000))
2624
vals[5] = 1
2725
result = vals.ewm(span=100, adjust=False).mean().sum()
@@ -53,63 +51,153 @@ def test_ewma_nan_handling():
5351
result = s.ewm(com=5).mean()
5452
tm.assert_series_equal(result, Series([np.nan] * 2 + [1.0] * 4))
5553

56-
# GH 7603
57-
s0 = Series([np.nan, 1.0, 101.0])
58-
s1 = Series([1.0, np.nan, 101.0])
59-
s2 = Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan])
60-
s3 = Series([1.0, np.nan, 101.0, 50.0])
61-
com = 2.0
62-
alpha = 1.0 / (1.0 + com)
63-
64-
def simple_wma(s, w):
65-
return (s.multiply(w).cumsum() / w.cumsum()).fillna(method="ffill")
66-
67-
for (s, adjust, ignore_na, w) in [
68-
(s0, True, False, [np.nan, (1.0 - alpha), 1.0]),
69-
(s0, True, True, [np.nan, (1.0 - alpha), 1.0]),
70-
(s0, False, False, [np.nan, (1.0 - alpha), alpha]),
71-
(s0, False, True, [np.nan, (1.0 - alpha), alpha]),
72-
(s1, True, False, [(1.0 - alpha) ** 2, np.nan, 1.0]),
73-
(s1, True, True, [(1.0 - alpha), np.nan, 1.0]),
74-
(s1, False, False, [(1.0 - alpha) ** 2, np.nan, alpha]),
75-
(s1, False, True, [(1.0 - alpha), np.nan, alpha]),
76-
(s2, True, False, [np.nan, (1.0 - alpha) ** 3, np.nan, np.nan, 1.0, np.nan]),
77-
(s2, True, True, [np.nan, (1.0 - alpha), np.nan, np.nan, 1.0, np.nan]),
54+
55+
@pytest.mark.parametrize(
56+
"s, adjust, ignore_na, w",
57+
[
58+
(
59+
Series([np.nan, 1.0, 101.0]),
60+
True,
61+
False,
62+
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
63+
),
64+
(
65+
Series([np.nan, 1.0, 101.0]),
66+
True,
67+
True,
68+
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
69+
),
70+
(
71+
Series([np.nan, 1.0, 101.0]),
72+
False,
73+
False,
74+
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
75+
),
76+
(
77+
Series([np.nan, 1.0, 101.0]),
78+
False,
79+
True,
80+
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
81+
),
82+
(
83+
Series([1.0, np.nan, 101.0]),
84+
True,
85+
False,
86+
[(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0],
87+
),
7888
(
79-
s2,
89+
Series([1.0, np.nan, 101.0]),
90+
True,
91+
True,
92+
[(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0],
93+
),
94+
(
95+
Series([1.0, np.nan, 101.0]),
96+
False,
8097
False,
98+
[(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))],
99+
),
100+
(
101+
Series([1.0, np.nan, 101.0]),
102+
False,
103+
True,
104+
[(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))],
105+
),
106+
(
107+
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
108+
True,
81109
False,
82-
[np.nan, (1.0 - alpha) ** 3, np.nan, np.nan, alpha, np.nan],
110+
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan],
111+
),
112+
(
113+
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
114+
True,
115+
True,
116+
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan],
83117
),
84-
(s2, False, True, [np.nan, (1.0 - alpha), np.nan, np.nan, alpha, np.nan]),
85-
(s3, True, False, [(1.0 - alpha) ** 3, np.nan, (1.0 - alpha), 1.0]),
86-
(s3, True, True, [(1.0 - alpha) ** 2, np.nan, (1.0 - alpha), 1.0]),
87118
(
88-
s3,
119+
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
89120
False,
90121
False,
91122
[
92-
(1.0 - alpha) ** 3,
93123
np.nan,
94-
(1.0 - alpha) * alpha,
95-
alpha * ((1.0 - alpha) ** 2 + alpha),
124+
(1.0 - (1.0 / (1.0 + 2.0))) ** 3,
125+
np.nan,
126+
np.nan,
127+
(1.0 / (1.0 + 2.0)),
128+
np.nan,
96129
],
97130
),
98-
(s3, False, True, [(1.0 - alpha) ** 2, np.nan, (1.0 - alpha) * alpha, alpha]),
99-
]:
100-
expected = simple_wma(s, Series(w))
101-
result = s.ewm(com=com, adjust=adjust, ignore_na=ignore_na).mean()
131+
(
132+
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
133+
False,
134+
True,
135+
[
136+
np.nan,
137+
(1.0 - (1.0 / (1.0 + 2.0))),
138+
np.nan,
139+
np.nan,
140+
(1.0 / (1.0 + 2.0)),
141+
np.nan,
142+
],
143+
),
144+
(
145+
Series([1.0, np.nan, 101.0, 50.0]),
146+
True,
147+
False,
148+
[
149+
(1.0 - (1.0 / (1.0 + 2.0))) ** 3,
150+
np.nan,
151+
(1.0 - (1.0 / (1.0 + 2.0))),
152+
1.0,
153+
],
154+
),
155+
(
156+
Series([1.0, np.nan, 101.0, 50.0]),
157+
True,
158+
True,
159+
[
160+
(1.0 - (1.0 / (1.0 + 2.0))) ** 2,
161+
np.nan,
162+
(1.0 - (1.0 / (1.0 + 2.0))),
163+
1.0,
164+
],
165+
),
166+
(
167+
Series([1.0, np.nan, 101.0, 50.0]),
168+
False,
169+
False,
170+
[
171+
(1.0 - (1.0 / (1.0 + 2.0))) ** 3,
172+
np.nan,
173+
(1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
174+
(1.0 / (1.0 + 2.0))
175+
* ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))),
176+
],
177+
),
178+
(
179+
Series([1.0, np.nan, 101.0, 50.0]),
180+
False,
181+
True,
182+
[
183+
(1.0 - (1.0 / (1.0 + 2.0))) ** 2,
184+
np.nan,
185+
(1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
186+
(1.0 / (1.0 + 2.0)),
187+
],
188+
),
189+
],
190+
)
191+
def test_ewma_nan_handling_cases(s, adjust, ignore_na, w):
192+
# GH 7603
193+
expected = (s.multiply(w).cumsum() / Series(w).cumsum()).fillna(method="ffill")
194+
result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
102195

196+
tm.assert_series_equal(result, expected)
197+
if ignore_na is False:
198+
# check that ignore_na defaults to False
199+
result = s.ewm(com=2.0, adjust=adjust).mean()
103200
tm.assert_series_equal(result, expected)
104-
if ignore_na is False:
105-
# check that ignore_na defaults to False
106-
result = s.ewm(com=com, adjust=adjust).mean()
107-
tm.assert_series_equal(result, expected)
108-
109-
110-
@pytest.mark.parametrize("name", ["var", "vol"])
111-
def test_ewmvar_ewmvol(series, frame, nan_locs, name):
112-
check_ew(name=name, frame=frame, series=series, nan_locs=nan_locs)
113201

114202

115203
def test_ewma_span_com_args(series):

0 commit comments

Comments
 (0)