Skip to content

Commit ace2bdd

Browse files
authored
CLN: test_moments_consistency_*.py (#36810)
1 parent dc1bf59 commit ace2bdd

File tree

3 files changed

+40
-54
lines changed

3 files changed

+40
-54
lines changed

pandas/tests/window/common.py

-43
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,6 @@
44
import pandas._testing as tm
55

66

7-
def check_pairwise_moment(frame, dispatch, name, **kwargs):
8-
def get_result(obj, obj2=None):
9-
return getattr(getattr(obj, dispatch)(**kwargs), name)(obj2)
10-
11-
result = get_result(frame)
12-
result = result.loc[(slice(None), 1), 5]
13-
result.index = result.index.droplevel(1)
14-
expected = get_result(frame[1], frame[5])
15-
expected.index = expected.index._with_freq(None)
16-
tm.assert_series_equal(result, expected, check_names=False)
17-
18-
19-
def ew_func(A, B, com, name, **kwargs):
20-
return getattr(A.ewm(com, **kwargs), name)(B)
21-
22-
23-
def check_binary_ew(name, A, B):
24-
25-
result = ew_func(A=A, B=B, com=20, name=name, min_periods=5)
26-
assert np.isnan(result.values[:14]).all()
27-
assert not np.isnan(result.values[14:]).any()
28-
29-
30-
def check_binary_ew_min_periods(name, min_periods, A, B):
31-
# GH 7898
32-
result = ew_func(A, B, 20, name=name, min_periods=min_periods)
33-
# binary functions (ewmcov, ewmcorr) with bias=False require at
34-
# least two values
35-
assert np.isnan(result.values[:11]).all()
36-
assert not np.isnan(result.values[11:]).any()
37-
38-
# check series of length 0
39-
empty = Series([], dtype=np.float64)
40-
result = ew_func(empty, empty, 50, name=name, min_periods=min_periods)
41-
tm.assert_series_equal(result, empty)
42-
43-
# check series of length 1
44-
result = ew_func(
45-
Series([1.0]), Series([1.0]), 50, name=name, min_periods=min_periods
46-
)
47-
tm.assert_series_equal(result, Series([np.NaN]))
48-
49-
507
def moments_consistency_mock_mean(x, mean, mock_mean):
518
mean_x = mean(x)
529
# check that correlation of a series with itself is either 1 or NaN

pandas/tests/window/moments/test_moments_consistency_ewm.py

+34-9
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
import pytest
44

55
from pandas import DataFrame, Series, concat
6+
import pandas._testing as tm
67
from pandas.tests.window.common import (
7-
check_binary_ew,
8-
check_binary_ew_min_periods,
9-
check_pairwise_moment,
10-
ew_func,
118
moments_consistency_cov_data,
129
moments_consistency_is_constant,
1310
moments_consistency_mock_mean,
@@ -20,15 +17,43 @@
2017

2118
@pytest.mark.parametrize("func", ["cov", "corr"])
2219
def test_ewm_pairwise_cov_corr(func, frame):
23-
check_pairwise_moment(frame, "ewm", func, span=10, min_periods=5)
20+
result = getattr(frame.ewm(span=10, min_periods=5), func)()
21+
result = result.loc[(slice(None), 1), 5]
22+
result.index = result.index.droplevel(1)
23+
expected = getattr(frame[1].ewm(span=10, min_periods=5), func)(frame[5])
24+
expected.index = expected.index._with_freq(None)
25+
tm.assert_series_equal(result, expected, check_names=False)
2426

2527

2628
@pytest.mark.parametrize("name", ["cov", "corr"])
27-
def test_ewm_corr_cov(name, min_periods, binary_ew_data):
29+
def test_ewm_corr_cov(name, binary_ew_data):
2830
A, B = binary_ew_data
2931

30-
check_binary_ew(name="corr", A=A, B=B)
31-
check_binary_ew_min_periods("corr", min_periods, A, B)
32+
result = getattr(A.ewm(com=20, min_periods=5), name)(B)
33+
assert np.isnan(result.values[:14]).all()
34+
assert not np.isnan(result.values[14:]).any()
35+
36+
37+
@pytest.mark.parametrize("name", ["cov", "corr"])
38+
def test_ewm_corr_cov_min_periods(name, min_periods, binary_ew_data):
39+
# GH 7898
40+
A, B = binary_ew_data
41+
result = getattr(A.ewm(com=20, min_periods=min_periods), name)(B)
42+
# binary functions (ewmcov, ewmcorr) with bias=False require at
43+
# least two values
44+
assert np.isnan(result.values[:11]).all()
45+
assert not np.isnan(result.values[11:]).any()
46+
47+
# check series of length 0
48+
empty = Series([], dtype=np.float64)
49+
result = getattr(empty.ewm(com=50, min_periods=min_periods), name)(empty)
50+
tm.assert_series_equal(result, empty)
51+
52+
# check series of length 1
53+
result = getattr(Series([1.0]).ewm(com=50, min_periods=min_periods), name)(
54+
Series([1.0])
55+
)
56+
tm.assert_series_equal(result, Series([np.NaN]))
3257

3358

3459
@pytest.mark.parametrize("name", ["cov", "corr"])
@@ -38,7 +63,7 @@ def test_different_input_array_raise_exception(name, binary_ew_data):
3863
msg = "Input arrays must be of the same type!"
3964
# exception raised is Exception
4065
with pytest.raises(Exception, match=msg):
41-
ew_func(A, randn(50), 20, name=name, min_periods=5)
66+
getattr(A.ewm(com=20, min_periods=5), name)(randn(50))
4267

4368

4469
@pytest.mark.slow

pandas/tests/window/moments/test_moments_consistency_rolling.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import pandas._testing as tm
1313
from pandas.core.window.common import flex_binary_moment
1414
from pandas.tests.window.common import (
15-
check_pairwise_moment,
1615
moments_consistency_cov_data,
1716
moments_consistency_is_constant,
1817
moments_consistency_mock_mean,
@@ -60,7 +59,12 @@ def test_rolling_corr(series):
6059

6160
@pytest.mark.parametrize("func", ["cov", "corr"])
6261
def test_rolling_pairwise_cov_corr(func, frame):
63-
check_pairwise_moment(frame, "rolling", func, window=10, min_periods=5)
62+
result = getattr(frame.rolling(window=10, min_periods=5), func)()
63+
result = result.loc[(slice(None), 1), 5]
64+
result.index = result.index.droplevel(1)
65+
expected = getattr(frame[1].rolling(window=10, min_periods=5), func)(frame[5])
66+
expected.index = expected.index._with_freq(None)
67+
tm.assert_series_equal(result, expected, check_names=False)
6468

6569

6670
@pytest.mark.parametrize("method", ["corr", "cov"])

0 commit comments

Comments
 (0)