Skip to content

CLN: refactor tests in test_moments_ewm.py #30570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions pandas/tests/window/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,34 @@ def get_result(obj, obj2=None):
result.index = result.index.droplevel(1)
expected = get_result(self.frame[1], self.frame[5])
tm.assert_series_equal(result, expected, check_names=False)


def ew_func(A, B, com, name, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)


def check_binary_ew(name, A, B):

result = ew_func(A=A, B=B, com=20, name=name, min_periods=5)
assert np.isnan(result.values[:14]).all()
assert not np.isnan(result.values[14:]).any()


def check_binary_ew_min_periods(name, min_periods, A, B):
# GH 7898
result = ew_func(A, B, 20, name=name, min_periods=min_periods)
# binary functions (ewmcov, ewmcorr) with bias=False require at
# least two values
assert np.isnan(result.values[:11]).all()
assert not np.isnan(result.values[11:]).any()

# check series of length 0
empty = Series([], dtype=np.float64)
result = ew_func(empty, empty, 50, name=name, min_periods=min_periods)
tm.assert_series_equal(result, empty)

# check series of length 1
result = ew_func(
Series([1.0]), Series([1.0]), 50, name=name, min_periods=min_periods
)
tm.assert_series_equal(result, Series([np.NaN]))
20 changes: 20 additions & 0 deletions pandas/tests/window/moments/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
from numpy.random import randn
import pytest

from pandas import Series


@pytest.fixture
def binary_ew_data():
A = Series(randn(50), index=np.arange(50))
B = A[2:] + randn(48)

A[:10] = np.NaN
B[-10:] = np.NaN
return A, B


@pytest.fixture(params=[0, 1, 2])
def min_periods(request):
return request.param
106 changes: 44 additions & 62 deletions pandas/tests/window/moments/test_moments_ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import pandas as pd
from pandas import DataFrame, Series, concat
from pandas.tests.window.common import Base, ConsistencyBase
from pandas.tests.window.common import (
Base,
ConsistencyBase,
check_binary_ew,
check_binary_ew_min_periods,
ew_func,
)
import pandas.util.testing as tm


Expand Down Expand Up @@ -216,6 +222,9 @@ def _check_ew(self, name=None, preserve_nan=False):
if preserve_nan:
assert result[self._nan_locs].isna().all()

@pytest.mark.parametrize("min_periods", [0, 1])
@pytest.mark.parametrize("name", ["mean", "var", "vol"])
def test_ew_min_periods(self, min_periods, name):
# excluding NaNs correctly
arr = randn(50)
arr[:10] = np.NaN
Expand All @@ -228,31 +237,30 @@ def _check_ew(self, name=None, preserve_nan=False):
assert result[:11].isna().all()
assert not result[11:].isna().any()

for min_periods in (0, 1):
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
if name == "mean":
assert result[:10].isna().all()
assert not result[10:].isna().any()
else:
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
# two values
assert result[:11].isna().all()
assert not result[11:].isna().any()

# check series of length 0
result = getattr(
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
)()
tm.assert_series_equal(result, Series(dtype="float64"))

# check series of length 1
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
if name == "mean":
tm.assert_series_equal(result, Series([1.0]))
else:
# ewm.std, ewm.vol, ewm.var with bias=False require at least
# two values
tm.assert_series_equal(result, Series([np.NaN]))
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
if name == "mean":
assert result[:10].isna().all()
assert not result[10:].isna().any()
else:
# ewm.std, ewm.vol, ewm.var (with bias=False) require at least
# two values
assert result[:11].isna().all()
assert not result[11:].isna().any()

# check series of length 0
result = getattr(
Series(dtype=object).ewm(com=50, min_periods=min_periods), name
)()
tm.assert_series_equal(result, Series(dtype="float64"))

# check series of length 1
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
if name == "mean":
tm.assert_series_equal(result, Series([1.0]))
else:
# ewm.std, ewm.vol, ewm.var with bias=False require at least
# two values
tm.assert_series_equal(result, Series([np.NaN]))

# pass in ints
result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
Expand All @@ -263,53 +271,27 @@ class TestEwmMomentsConsistency(ConsistencyBase):
def setup_method(self, method):
self._create_data()

def test_ewmcov(self):
self._check_binary_ew("cov")

def test_ewmcov_pairwise(self):
self._check_pairwise_moment("ewm", "cov", span=10, min_periods=5)

def test_ewmcorr(self):
self._check_binary_ew("corr")
@pytest.mark.parametrize("name", ["cov", "corr"])
def test_ewm_corr_cov(self, name, min_periods, binary_ew_data):
A, B = binary_ew_data

check_binary_ew(name="corr", A=A, B=B)
check_binary_ew_min_periods("corr", min_periods, A, B)

def test_ewmcorr_pairwise(self):
self._check_pairwise_moment("ewm", "corr", span=10, min_periods=5)

def _check_binary_ew(self, name):
def func(A, B, com, **kwargs):
return getattr(A.ewm(com, **kwargs), name)(B)

A = Series(randn(50), index=np.arange(50))
B = A[2:] + randn(48)

A[:10] = np.NaN
B[-10:] = np.NaN

result = func(A, B, 20, min_periods=5)
assert np.isnan(result.values[:14]).all()
assert not np.isnan(result.values[14:]).any()

# GH 7898
for min_periods in (0, 1, 2):
result = func(A, B, 20, min_periods=min_periods)
# binary functions (ewmcov, ewmcorr) with bias=False require at
# least two values
assert np.isnan(result.values[:11]).all()
assert not np.isnan(result.values[11:]).any()

# check series of length 0
empty = Series([], dtype=np.float64)
result = func(empty, empty, 50, min_periods=min_periods)
tm.assert_series_equal(result, empty)

# check series of length 1
result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods)
tm.assert_series_equal(result, Series([np.NaN]))
@pytest.mark.parametrize("name", ["cov", "corr"])
def test_different_input_array_raise_exception(self, name, binary_ew_data):

A, _ = binary_ew_data
msg = "Input arrays must be of the same type!"
# exception raised is Exception
with pytest.raises(Exception, match=msg):
func(A, randn(50), 20, min_periods=5)
ew_func(A, randn(50), 20, name=name, min_periods=5)

@pytest.mark.slow
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
Expand Down