Skip to content

Commit ba742da

Browse files
charlesdong1991 authored and jreback committed
CLN: refactor tests in test_moments_ewm.py (#30570)
1 parent c38a416 commit ba742da

File tree

3 files changed

+95
-62
lines changed

3 files changed

+95
-62
lines changed

pandas/tests/window/common.py

+31
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,34 @@ def get_result(obj, obj2=None):
353353
result.index = result.index.droplevel(1)
354354
expected = get_result(self.frame[1], self.frame[5])
355355
tm.assert_series_equal(result, expected, check_names=False)
356+
357+
358+
def ew_func(A, B, com, name, **kwargs):
    """Apply the binary exponentially weighted method ``name`` of ``A`` to ``B``.

    Builds ``A.ewm(com, **kwargs)``, looks up the attribute called ``name``
    on it and calls that attribute with ``B`` as its only argument.

    Parameters
    ----------
    A : object exposing ``.ewm``
        Left-hand operand; the EW window is built from it.
    B : object
        Right-hand operand passed to the looked-up method.
    com : number
        Forwarded to ``ewm`` as ``com``.
    name : str
        Name of the method to fetch from the EW window (the callers visible
        here use "cov" and "corr").
    **kwargs
        Extra keyword arguments forwarded to ``ewm`` (e.g. ``min_periods``).

    Returns
    -------
    Whatever the looked-up method returns.
    """
    window = A.ewm(com=com, **kwargs)
    method = getattr(window, name)
    return method(B)
360+
361+
362+
def check_binary_ew(name, A, B):
    """Assert the NaN layout of a binary EW moment with ``min_periods=5``.

    Calls ``ew_func`` with ``com=20`` and ``min_periods=5`` and checks that
    the first 14 result values are all NaN and that no later value is NaN.

    Parameters
    ----------
    name : str
        Name of the EW method to evaluate, forwarded to ``ew_func``.
    A, B
        The two operands, forwarded to ``ew_func``.
    """
    values = ew_func(A=A, B=B, com=20, name=name, min_periods=5).values

    leading, trailing = values[:14], values[14:]
    assert np.isnan(leading).all()
    assert not np.isnan(trailing).any()
367+
368+
369+
def check_binary_ew_min_periods(name, min_periods, A, B):
    """Assert binary EW moment behaviour for a given ``min_periods``.

    Three checks are made, all through ``ew_func``:

    * on ``A``/``B`` with ``com=20``: the first 11 result values are NaN and
      none of the remaining ones are;
    * on two empty float64 Series with ``com=50``: the result equals the
      empty Series;
    * on two length-1 Series with ``com=50``: the result is a single NaN.

    Parameters
    ----------
    name : str
        Name of the EW method to evaluate, forwarded to ``ew_func``.
    min_periods : int
        Forwarded to ``ew_func`` as the ``min_periods`` keyword.
    A, B
        Operands for the first check.
    """
    # GH 7898
    values = ew_func(A, B, 20, name=name, min_periods=min_periods).values
    # binary functions (ewmcov, ewmcorr) with bias=False require at
    # least two values
    assert np.isnan(values[:11]).all()
    assert not np.isnan(values[11:]).any()

    # check series of length 0
    empty = Series([], dtype=np.float64)
    empty_result = ew_func(empty, empty, 50, name=name, min_periods=min_periods)
    tm.assert_series_equal(empty_result, empty)

    # check series of length 1
    lhs = Series([1.0])
    rhs = Series([1.0])
    single_result = ew_func(lhs, rhs, 50, name=name, min_periods=min_periods)
    tm.assert_series_equal(single_result, Series([np.NaN]))
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numpy as np
2+
from numpy.random import randn
3+
import pytest
4+
5+
from pandas import Series
6+
7+
8+
@pytest.fixture
def binary_ew_data():
    """Provide a pair of random Series for binary EW moment tests.

    The first Series has 50 random values on an integer index ``0..49``.
    The second is the first one from position 2 onward plus 48 further
    random values. Afterwards the first 10 entries of the first Series and
    the last 10 entries of the second are set to NaN.

    Returns
    -------
    tuple
        ``(A, B)`` as described above.
    """
    first = Series(randn(50), index=np.arange(50))
    second = first[2:] + randn(48)

    # NaNs are injected only after ``second`` has been derived from ``first``.
    first[:10] = np.NaN
    second[-10:] = np.NaN
    return first, second
16+
17+
18+
@pytest.fixture(params=[0, 1, 2])
def min_periods(request):
    """Parametrized fixture yielding each of ``0``, ``1`` and ``2`` in turn."""
    value = request.param
    return value

pandas/tests/window/moments/test_moments_ewm.py

+44-62
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
import pandas as pd
66
from pandas import DataFrame, Series, concat
7-
from pandas.tests.window.common import Base, ConsistencyBase
7+
from pandas.tests.window.common import (
8+
Base,
9+
ConsistencyBase,
10+
check_binary_ew,
11+
check_binary_ew_min_periods,
12+
ew_func,
13+
)
814
import pandas.util.testing as tm
915

1016

@@ -216,6 +222,9 @@ def _check_ew(self, name=None, preserve_nan=False):
216222
if preserve_nan:
217223
assert result[self._nan_locs].isna().all()
218224

225+
@pytest.mark.parametrize("min_periods", [0, 1])
226+
@pytest.mark.parametrize("name", ["mean", "var", "vol"])
227+
def test_ew_min_periods(self, min_periods, name):
219228
# excluding NaNs correctly
220229
arr = randn(50)
221230
arr[:10] = np.NaN
@@ -228,31 +237,30 @@ def _check_ew(self, name=None, preserve_nan=False):
228237
assert result[:11].isna().all()
229238
assert not result[11:].isna().any()
230239

231-
for min_periods in (0, 1):
232-
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
233-
if name == "mean":
234-
assert result[:10].isna().all()
235-
assert not result[10:].isna().any()
236-
else:
237-
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
238-
# two values
239-
assert result[:11].isna().all()
240-
assert not result[11:].isna().any()
241-
242-
# check series of length 0
243-
result = getattr(
244-
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
245-
)()
246-
tm.assert_series_equal(result, Series(dtype="float64"))
247-
248-
# check series of length 1
249-
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
250-
if name == "mean":
251-
tm.assert_series_equal(result, Series([1.0]))
252-
else:
253-
# ewm.std, ewm.vol, ewm.var with bias=False require at least
254-
# two values
255-
tm.assert_series_equal(result, Series([np.NaN]))
240+
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
241+
if name == "mean":
242+
assert result[:10].isna().all()
243+
assert not result[10:].isna().any()
244+
else:
245+
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
246+
# two values
247+
assert result[:11].isna().all()
248+
assert not result[11:].isna().any()
249+
250+
# check series of length 0
251+
result = getattr(
252+
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
253+
)()
254+
tm.assert_series_equal(result, Series(dtype="float64"))
255+
256+
# check series of length 1
257+
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
258+
if name == "mean":
259+
tm.assert_series_equal(result, Series([1.0]))
260+
else:
261+
# ewm.std, ewm.vol, ewm.var with bias=False require at least
262+
# two values
263+
tm.assert_series_equal(result, Series([np.NaN]))
256264

257265
# pass in ints
258266
result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
@@ -263,53 +271,27 @@ class TestEwmMomentsConsistency(ConsistencyBase):
263271
def setup_method(self, method):
264272
self._create_data()
265273

266-
def test_ewmcov(self):
267-
self._check_binary_ew("cov")
268-
269274
def test_ewmcov_pairwise(self):
270275
self._check_pairwise_moment("ewm", "cov", span=10, min_periods=5)
271276

272-
def test_ewmcorr(self):
273-
self._check_binary_ew("corr")
277+
@pytest.mark.parametrize("name", ["cov", "corr"])
def test_ewm_corr_cov(self, name, min_periods, binary_ew_data):
    """Check NaN handling of the binary EW moment selected by ``name``.

    Runs once per combination of ``name`` ("cov", "corr") and the
    ``min_periods`` fixture value, using the ``binary_ew_data`` Series pair.
    """
    A, B = binary_ew_data

    # Forward the parametrized ``name``: hard-coding "corr" here would run
    # the corr checks twice and leave "cov" untested.
    check_binary_ew(name=name, A=A, B=B)
    check_binary_ew_min_periods(name, min_periods, A, B)
274283

275284
def test_ewmcorr_pairwise(self):
276285
self._check_pairwise_moment("ewm", "corr", span=10, min_periods=5)
277286

278-
def _check_binary_ew(self, name):
279-
def func(A, B, com, **kwargs):
280-
return getattr(A.ewm(com, **kwargs), name)(B)
281-
282-
A = Series(randn(50), index=np.arange(50))
283-
B = A[2:] + randn(48)
284-
285-
A[:10] = np.NaN
286-
B[-10:] = np.NaN
287-
288-
result = func(A, B, 20, min_periods=5)
289-
assert np.isnan(result.values[:14]).all()
290-
assert not np.isnan(result.values[14:]).any()
291-
292-
# GH 7898
293-
for min_periods in (0, 1, 2):
294-
result = func(A, B, 20, min_periods=min_periods)
295-
# binary functions (ewmcov, ewmcorr) with bias=False require at
296-
# least two values
297-
assert np.isnan(result.values[:11]).all()
298-
assert not np.isnan(result.values[11:]).any()
299-
300-
# check series of length 0
301-
empty = Series([], dtype=np.float64)
302-
result = func(empty, empty, 50, min_periods=min_periods)
303-
tm.assert_series_equal(result, empty)
304-
305-
# check series of length 1
306-
result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods)
307-
tm.assert_series_equal(result, Series([np.NaN]))
287+
@pytest.mark.parametrize("name", ["cov", "corr"])
288+
def test_different_input_array_raise_exception(self, name, binary_ew_data):
308289

290+
A, _ = binary_ew_data
309291
msg = "Input arrays must be of the same type!"
310292
# exception raised is Exception
311293
with pytest.raises(Exception, match=msg):
312-
func(A, randn(50), 20, min_periods=5)
294+
ew_func(A, randn(50), 20, name=name, min_periods=5)
313295

314296
@pytest.mark.slow
315297
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])

0 commit comments

Comments
 (0)