-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
CLN: refactor tests in test_moments_ewm.py #30570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
7e461a1
1314059
8bcb313
fd92136
bffc9c8
dae493b
2624900
ad959ee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import numpy as np | ||
from numpy.random import randn | ||
import pytest | ||
|
||
from pandas import Series | ||
|
||
|
||
@pytest.fixture
def binary_ew_data():
    """Return a pair of random Series with NaN padding at opposite ends.

    Returns
    -------
    tuple of Series
        ``A`` holds 50 draws from ``randn`` indexed 0..49, with its first ten
        entries set to NaN.  ``B`` is ``A[2:]`` plus 48 further ``randn``
        draws, so it carries index labels 2..49; its last ten entries are
        set to NaN.
    """
    A = Series(randn(50), index=np.arange(50))
    # B is derived before A is masked, so B's leading values stay finite.
    B = A[2:] + randn(48)

    # Use ``np.nan`` rather than ``np.NaN``: the capitalised alias was removed
    # in NumPy 2.0, while the lowercase name works on every NumPy version.
    A[:10] = np.nan
    B[-10:] = np.nan
    return A, B
|
||
|
||
@pytest.fixture(params=[0, 1, 2])
def min_periods(request):
    """Parametrized fixture that supplies 0, 1 and 2 in turn."""
    return request.param
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -216,6 +216,9 @@ def _check_ew(self, name=None, preserve_nan=False): | |
if preserve_nan: | ||
assert result[self._nan_locs].isna().all() | ||
|
||
@pytest.mark.parametrize("min_periods", [0, 1]) | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@pytest.mark.parametrize("name", ["mean", "var", "vol"]) | ||
def test_ew_min_periods(self, min_periods, name): | ||
# excluding NaNs correctly | ||
arr = randn(50) | ||
arr[:10] = np.NaN | ||
|
@@ -228,31 +231,30 @@ def _check_ew(self, name=None, preserve_nan=False): | |
assert result[:11].isna().all() | ||
assert not result[11:].isna().any() | ||
|
||
for min_periods in (0, 1): | ||
result = getattr(s.ewm(com=50, min_periods=min_periods), name)() | ||
if name == "mean": | ||
assert result[:10].isna().all() | ||
assert not result[10:].isna().any() | ||
else: | ||
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least | ||
# two values | ||
assert result[:11].isna().all() | ||
assert not result[11:].isna().any() | ||
|
||
# check series of length 0 | ||
result = getattr( | ||
Series(dtype=object).ewm(com=50, min_periods=min_periods), name | ||
)() | ||
tm.assert_series_equal(result, Series(dtype="float64")) | ||
|
||
# check series of length 1 | ||
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)() | ||
if name == "mean": | ||
tm.assert_series_equal(result, Series([1.0])) | ||
else: | ||
# ewm.std, ewm.vol, ewm.var with bias=False require at least | ||
# two values | ||
tm.assert_series_equal(result, Series([np.NaN])) | ||
result = getattr(s.ewm(com=50, min_periods=min_periods), name)() | ||
if name == "mean": | ||
assert result[:10].isna().all() | ||
assert not result[10:].isna().any() | ||
else: | ||
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least | ||
# two values | ||
assert result[:11].isna().all() | ||
assert not result[11:].isna().any() | ||
|
||
# check series of length 0 | ||
result = getattr( | ||
Series(dtype=object).ewm(com=50, min_periods=min_periods), name | ||
)() | ||
tm.assert_series_equal(result, Series(dtype="float64")) | ||
|
||
# check series of length 1 | ||
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)() | ||
if name == "mean": | ||
tm.assert_series_equal(result, Series([1.0])) | ||
else: | ||
# ewm.std, ewm.vol, ewm.var with bias=False require at least | ||
# two values | ||
tm.assert_series_equal(result, Series([np.NaN])) | ||
|
||
# pass in ints | ||
result2 = getattr(Series(np.arange(50)).ewm(span=10), name)() | ||
|
@@ -263,53 +265,57 @@ class TestEwmMomentsConsistency(ConsistencyBase): | |
def setup_method(self, method): | ||
self._create_data() | ||
|
||
def test_ewmcov(self): | ||
self._check_binary_ew("cov") | ||
def test_ewmcov(self, min_periods, binary_ew_data): | ||
A, B = binary_ew_data | ||
|
||
self._check_binary_ew(name="cov", A=A, B=B) | ||
self._check_binary_ew_min_periods("cov", min_periods, A, B) | ||
|
||
def test_ewmcov_pairwise(self): | ||
self._check_pairwise_moment("ewm", "cov", span=10, min_periods=5) | ||
|
||
def test_ewmcorr(self): | ||
self._check_binary_ew("corr") | ||
def test_ewmcorr(self, min_periods, binary_ew_data): | ||
A, B = binary_ew_data | ||
|
||
self._check_binary_ew(name="corr", A=A, B=B) | ||
self._check_binary_ew_min_periods("corr", min_periods, A, B) | ||
|
||
def test_ewmcorr_pairwise(self): | ||
self._check_pairwise_moment("ewm", "corr", span=10, min_periods=5) | ||
|
||
def _check_binary_ew(self, name): | ||
def func(A, B, com, **kwargs): | ||
return getattr(A.ewm(com, **kwargs), name)(B) | ||
|
||
A = Series(randn(50), index=np.arange(50)) | ||
B = A[2:] + randn(48) | ||
|
||
A[:10] = np.NaN | ||
B[-10:] = np.NaN | ||
def _check_binary_ew(self, name, A, B): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Make this a free function and put it in common.py. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changed. |
||
|
||
result = func(A, B, 20, min_periods=5) | ||
result = self._ew_func(A=A, B=B, com=20, name=name, min_periods=5) | ||
assert np.isnan(result.values[:14]).all() | ||
assert not np.isnan(result.values[14:]).any() | ||
|
||
def _check_binary_ew_min_periods(self, name, min_periods, A, B): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This does not need to be on the class; it can just be a free function in common.py. Also rename these to something like check_binary_ew_min_periods (de-privatize). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I fixed it! I also combined the two methods with parametrization, since they are almost identical. Thanks! |
||
# GH 7898 | ||
for min_periods in (0, 1, 2): | ||
result = func(A, B, 20, min_periods=min_periods) | ||
# binary functions (ewmcov, ewmcorr) with bias=False require at | ||
# least two values | ||
assert np.isnan(result.values[:11]).all() | ||
assert not np.isnan(result.values[11:]).any() | ||
|
||
# check series of length 0 | ||
empty = Series([], dtype=np.float64) | ||
result = func(empty, empty, 50, min_periods=min_periods) | ||
tm.assert_series_equal(result, empty) | ||
|
||
# check series of length 1 | ||
result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods) | ||
tm.assert_series_equal(result, Series([np.NaN])) | ||
result = self._ew_func(A, B, 20, name=name, min_periods=min_periods) | ||
# binary functions (ewmcov, ewmcorr) with bias=False require at | ||
# least two values | ||
assert np.isnan(result.values[:11]).all() | ||
assert not np.isnan(result.values[11:]).any() | ||
|
||
# check series of length 0 | ||
empty = Series([], dtype=np.float64) | ||
result = self._ew_func(empty, empty, 50, name=name, min_periods=min_periods) | ||
tm.assert_series_equal(result, empty) | ||
|
||
# check series of length 1 | ||
result = self._ew_func( | ||
Series([1.0]), Series([1.0]), 50, name=name, min_periods=min_periods | ||
) | ||
tm.assert_series_equal(result, Series([np.NaN])) | ||
|
||
@pytest.mark.parametrize("name", ["cov", "corr"]) | ||
def test_different_input_array_raise_exception(self, name, binary_ew_data): | ||
|
||
A, _ = binary_ew_data | ||
msg = "Input arrays must be of the same type!" | ||
# exception raised is Exception | ||
with pytest.raises(Exception, match=msg): | ||
func(A, randn(50), 20, min_periods=5) | ||
self._ew_func(A, randn(50), 20, name=name, min_periods=5) | ||
|
||
@pytest.mark.slow | ||
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just make this a module-level function; ideally we want to eventually get rid of all of the class-based stuff here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These also don't need to be private (no leading underscore).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh! I get what you meant before; sorry for the misunderstanding. Fixed!