Skip to content

CLN: refactor tests in test_moments_ewm.py #30570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 3, 2020
123 changes: 72 additions & 51 deletions pandas/tests/window/moments/test_moments_ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def _check_ew(self, name=None, preserve_nan=False):
if preserve_nan:
assert result[self._nan_locs].isna().all()

@pytest.mark.parametrize("min_periods", [0, 1])
@pytest.mark.parametrize("name", ["mean", "var", "vol"])
def test_ew_min_periods(self, min_periods, name):
# excluding NaNs correctly
arr = randn(50)
arr[:10] = np.NaN
Expand All @@ -228,31 +231,30 @@ def _check_ew(self, name=None, preserve_nan=False):
assert result[:11].isna().all()
assert not result[11:].isna().any()

for min_periods in (0, 1):
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
if name == "mean":
assert result[:10].isna().all()
assert not result[10:].isna().any()
else:
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
# two values
assert result[:11].isna().all()
assert not result[11:].isna().any()

# check series of length 0
result = getattr(
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
)()
tm.assert_series_equal(result, Series(dtype="float64"))

# check series of length 1
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
if name == "mean":
tm.assert_series_equal(result, Series([1.0]))
else:
# ewm.std, ewm.vol, ewm.var with bias=False require at least
# two values
tm.assert_series_equal(result, Series([np.NaN]))
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
if name == "mean":
assert result[:10].isna().all()
assert not result[10:].isna().any()
else:
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
# two values
assert result[:11].isna().all()
assert not result[11:].isna().any()

# check series of length 0
result = getattr(
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
)()
tm.assert_series_equal(result, Series(dtype="float64"))

# check series of length 1
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
if name == "mean":
tm.assert_series_equal(result, Series([1.0]))
else:
# ewm.std, ewm.vol, ewm.var with bias=False require at least
# two values
tm.assert_series_equal(result, Series([np.NaN]))

# pass in ints
result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
Expand All @@ -263,49 +265,68 @@ class TestEwmMomentsConsistency(ConsistencyBase):
def setup_method(self, method):
self._create_data()

def test_ewmcov(self):
self._check_binary_ew("cov")
def _create_binary_ew_data(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a fixture

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made

A = Series(randn(50), index=np.arange(50))
B = A[2:] + randn(48)

A[:10] = np.NaN
B[-10:] = np.NaN
return A, B

@pytest.mark.parametrize("min_periods", [0, 1, 2])
def test_ewmcov(self, min_periods):
A, B = self._create_binary_ew_data()

self._check_binary_ew(name="cov", A=A, B=B)
self._check_binary_ew_min_periods("cov", min_periods, A, B)

def test_ewmcov_pairwise(self):
self._check_pairwise_moment("ewm", "cov", span=10, min_periods=5)

def test_ewmcorr(self):
self._check_binary_ew("corr")
@pytest.mark.parametrize("min_periods", [0, 1, 2])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can just create a fixture with these min_periods (you can put in a conftest.py in this directory)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i made a conftest.py under moments directory, and add couple fixtures there

def test_ewmcorr(self, min_periods):
A, B = self._create_binary_ew_data()

self._check_binary_ew(name="corr", A=A, B=B)
self._check_binary_ew_min_periods("corr", min_periods, A, B)

def test_ewmcorr_pairwise(self):
self._check_pairwise_moment("ewm", "corr", span=10, min_periods=5)

def _check_binary_ew(self, name):
def _check_binary_ew(self, name, A, B):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a free function, put in common.py

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

def func(A, B, com, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)

A = Series(randn(50), index=np.arange(50))
B = A[2:] + randn(48)

A[:10] = np.NaN
B[-10:] = np.NaN

result = func(A, B, 20, min_periods=5)
assert np.isnan(result.values[:14]).all()
assert not np.isnan(result.values[14:]).any()

def _check_binary_ew_min_periods(self, name, min_periods, A, B):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not need to be on the class, it can just be a free function in common.py

also rename these to something like

check_binary_ew_min_periods (de-prevatize)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i fixed! and also combined two methods with parametrization since they are almost identical!

thanks!

def func(A, B, com, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)

# GH 7898
for min_periods in (0, 1, 2):
result = func(A, B, 20, min_periods=min_periods)
# binary functions (ewmcov, ewmcorr) with bias=False require at
# least two values
assert np.isnan(result.values[:11]).all()
assert not np.isnan(result.values[11:]).any()

# check series of length 0
empty = Series([], dtype=np.float64)
result = func(empty, empty, 50, min_periods=min_periods)
tm.assert_series_equal(result, empty)

# check series of length 1
result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods)
tm.assert_series_equal(result, Series([np.NaN]))
result = func(A, B, 20, min_periods=min_periods)
# binary functions (ewmcov, ewmcorr) with bias=False require at
# least two values
assert np.isnan(result.values[:11]).all()
assert not np.isnan(result.values[11:]).any()

# check series of length 0
empty = Series([], dtype=np.float64)
result = func(empty, empty, 50, min_periods=min_periods)
tm.assert_series_equal(result, empty)

# check series of length 1
result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods)
tm.assert_series_equal(result, Series([np.NaN]))

@pytest.mark.parametrize("name", ["cov", "corr"])
def test_different_input_array_raise_exception(self, name):
def func(A, B, com, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)

A, _ = self._create_binary_ew_data()
msg = "Input arrays must be of the same type!"
# exception raised is Exception
with pytest.raises(Exception, match=msg):
Expand Down