Skip to content

CLN: refactor tests in test_moments_ewm.py #30570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 3, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas/tests/window/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,7 @@ def get_result(obj, obj2=None):
result.index = result.index.droplevel(1)
expected = get_result(self.frame[1], self.frame[5])
tm.assert_series_equal(result, expected, check_names=False)

@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just make this a module level function, ideally we want to get rid of all of the class based stuff eventually here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these also don't need to be private (leading underscore)

Copy link
Member Author

@charlesdong1991 charlesdong1991 Jan 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just make this a module level function, ideally we want to get rid of all of the class based stuff eventually here.

ahh! I get what you meant before, sorry for misunderstanding! fixed!

def _ew_func(A, B, com, name, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)
20 changes: 20 additions & 0 deletions pandas/tests/window/moments/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
from numpy.random import randn
import pytest

from pandas import Series


@pytest.fixture
def binary_ew_data():
    """
    Build a pair of partially overlapping random Series for binary ewm tests.

    Returns
    -------
    tuple of (Series, Series)
        ``A`` holds 50 random values on the index 0..49 with its first 10
        entries set to NaN.  ``B`` is ``A[2:]`` plus 48 random values, with
        its last 10 entries set to NaN.
    """
    A = Series(randn(50), index=np.arange(50))
    B = A[2:] + randn(48)

    # np.nan is the canonical spelling; the np.NaN alias was removed in NumPy 2.0
    A[:10] = np.nan
    B[-10:] = np.nan
    return A, B


@pytest.fixture(params=[0, 1, 2])
def min_periods(request):
    """Parametrized fixture supplying each ``min_periods`` value: 0, 1 and 2."""
    return request.param
116 changes: 61 additions & 55 deletions pandas/tests/window/moments/test_moments_ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def _check_ew(self, name=None, preserve_nan=False):
if preserve_nan:
assert result[self._nan_locs].isna().all()

@pytest.mark.parametrize("min_periods", [0, 1])
@pytest.mark.parametrize("name", ["mean", "var", "vol"])
def test_ew_min_periods(self, min_periods, name):
# excluding NaNs correctly
arr = randn(50)
arr[:10] = np.NaN
Expand All @@ -228,31 +231,30 @@ def _check_ew(self, name=None, preserve_nan=False):
assert result[:11].isna().all()
assert not result[11:].isna().any()

for min_periods in (0, 1):
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
if name == "mean":
assert result[:10].isna().all()
assert not result[10:].isna().any()
else:
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
# two values
assert result[:11].isna().all()
assert not result[11:].isna().any()

# check series of length 0
result = getattr(
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
)()
tm.assert_series_equal(result, Series(dtype="float64"))

# check series of length 1
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
if name == "mean":
tm.assert_series_equal(result, Series([1.0]))
else:
# ewm.std, ewm.vol, ewm.var with bias=False require at least
# two values
tm.assert_series_equal(result, Series([np.NaN]))
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
if name == "mean":
assert result[:10].isna().all()
assert not result[10:].isna().any()
else:
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
# two values
assert result[:11].isna().all()
assert not result[11:].isna().any()

# check series of length 0
result = getattr(
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
)()
tm.assert_series_equal(result, Series(dtype="float64"))

# check series of length 1
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
if name == "mean":
tm.assert_series_equal(result, Series([1.0]))
else:
# ewm.std, ewm.vol, ewm.var with bias=False require at least
# two values
tm.assert_series_equal(result, Series([np.NaN]))

# pass in ints
result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
Expand All @@ -263,53 +265,57 @@ class TestEwmMomentsConsistency(ConsistencyBase):
def setup_method(self, method):
    """Prepare the test data by calling ``_create_data`` (pytest per-test setup hook)."""
    self._create_data()

def test_ewmcov(self, min_periods, binary_ew_data):
    """Run the binary ewm checks for "cov" on the ``binary_ew_data`` pair."""
    A, B = binary_ew_data

    self._check_binary_ew(name="cov", A=A, B=B)
    self._check_binary_ew_min_periods("cov", min_periods, A, B)

def test_ewmcov_pairwise(self):
    """Run the pairwise-moment check for ewm "cov" with span=10, min_periods=5."""
    self._check_pairwise_moment("ewm", "cov", span=10, min_periods=5)

def test_ewmcorr(self, min_periods, binary_ew_data):
    """Run the binary ewm checks for "corr" on the ``binary_ew_data`` pair."""
    A, B = binary_ew_data

    self._check_binary_ew(name="corr", A=A, B=B)
    self._check_binary_ew_min_periods("corr", min_periods, A, B)

def test_ewmcorr_pairwise(self):
    """Run the pairwise-moment check for ewm "corr" with span=10, min_periods=5."""
    self._check_pairwise_moment("ewm", "corr", span=10, min_periods=5)

def _check_binary_ew(self, name):
def func(A, B, com, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)

A = Series(randn(50), index=np.arange(50))
B = A[2:] + randn(48)

A[:10] = np.NaN
B[-10:] = np.NaN
def _check_binary_ew(self, name, A, B):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a free function, put in common.py

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed


result = func(A, B, 20, min_periods=5)
result = self._ew_func(A=A, B=B, com=20, name=name, min_periods=5)
assert np.isnan(result.values[:14]).all()
assert not np.isnan(result.values[14:]).any()

def _check_binary_ew_min_periods(self, name, min_periods, A, B):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not need to be on the class, it can just be a free function in common.py

also rename these to something like

check_binary_ew_min_periods (de-privatize)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i fixed! and also combined two methods with parametrization since they are almost identical!

thanks!

# GH 7898
for min_periods in (0, 1, 2):
result = func(A, B, 20, min_periods=min_periods)
# binary functions (ewmcov, ewmcorr) with bias=False require at
# least two values
assert np.isnan(result.values[:11]).all()
assert not np.isnan(result.values[11:]).any()

# check series of length 0
empty = Series([], dtype=np.float64)
result = func(empty, empty, 50, min_periods=min_periods)
tm.assert_series_equal(result, empty)

# check series of length 1
result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods)
tm.assert_series_equal(result, Series([np.NaN]))
result = self._ew_func(A, B, 20, name=name, min_periods=min_periods)
# binary functions (ewmcov, ewmcorr) with bias=False require at
# least two values
assert np.isnan(result.values[:11]).all()
assert not np.isnan(result.values[11:]).any()

# check series of length 0
empty = Series([], dtype=np.float64)
result = self._ew_func(empty, empty, 50, name=name, min_periods=min_periods)
tm.assert_series_equal(result, empty)

# check series of length 1
result = self._ew_func(
Series([1.0]), Series([1.0]), 50, name=name, min_periods=min_periods
)
tm.assert_series_equal(result, Series([np.NaN]))

@pytest.mark.parametrize("name", ["cov", "corr"])
def test_different_input_array_raise_exception(self, name, binary_ew_data):
    """Check that a raw ``randn`` array as the second operand raises."""
    A, _ = binary_ew_data
    msg = "Input arrays must be of the same type!"
    # exception raised is Exception
    with pytest.raises(Exception, match=msg):
        self._ew_func(A, randn(50), 20, name=name, min_periods=5)

@pytest.mark.slow
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
Expand Down