Skip to content

Commit b3f54d7

Browse files
TST/REF: Fixturize constant functions in ConsistencyBase (#33943)
1 parent e30971a commit b3f54d7

File tree

4 files changed

+158
-129
lines changed

4 files changed

+158
-129
lines changed

pandas/tests/window/common.py

-34
Original file line numberDiff line numberDiff line change
@@ -25,40 +25,6 @@ def _create_data(self):
2525

2626

2727
class ConsistencyBase(Base):
28-
base_functions = [
29-
(lambda v: Series(v).count(), None, "count"),
30-
(lambda v: Series(v).max(), None, "max"),
31-
(lambda v: Series(v).min(), None, "min"),
32-
(lambda v: Series(v).sum(), None, "sum"),
33-
(lambda v: Series(v).mean(), None, "mean"),
34-
(lambda v: Series(v).std(), 1, "std"),
35-
(lambda v: Series(v).cov(Series(v)), None, "cov"),
36-
(lambda v: Series(v).corr(Series(v)), None, "corr"),
37-
(lambda v: Series(v).var(), 1, "var"),
38-
# restore once GH 8086 is fixed
39-
# lambda v: Series(v).skew(), 3, 'skew'),
40-
# (lambda v: Series(v).kurt(), 4, 'kurt'),
41-
# restore once GH 8084 is fixed
42-
# lambda v: Series(v).quantile(0.3), None, 'quantile'),
43-
(lambda v: Series(v).median(), None, "median"),
44-
(np.nanmax, 1, "max"),
45-
(np.nanmin, 1, "min"),
46-
(np.nansum, 1, "sum"),
47-
(np.nanmean, 1, "mean"),
48-
(lambda v: np.nanstd(v, ddof=1), 1, "std"),
49-
(lambda v: np.nanvar(v, ddof=1), 1, "var"),
50-
(np.nanmedian, 1, "median"),
51-
]
52-
no_nan_functions = [
53-
(np.max, None, "max"),
54-
(np.min, None, "min"),
55-
(np.sum, None, "sum"),
56-
(np.mean, None, "mean"),
57-
(lambda v: np.std(v, ddof=1), 1, "std"),
58-
(lambda v: np.var(v, ddof=1), 1, "var"),
59-
(np.median, None, "median"),
60-
]
61-
6228
def _create_data(self):
6329
super()._create_data()
6430

pandas/tests/window/moments/conftest.py

+58
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,61 @@ def binary_ew_data():
1818
@pytest.fixture(params=[0, 1, 2])
1919
def min_periods(request):
2020
return request.param
21+
22+
23+
base_functions_list = [
24+
(lambda v: Series(v).count(), None, "count"),
25+
(lambda v: Series(v).max(), None, "max"),
26+
(lambda v: Series(v).min(), None, "min"),
27+
(lambda v: Series(v).sum(), None, "sum"),
28+
(lambda v: Series(v).mean(), None, "mean"),
29+
(lambda v: Series(v).std(), 1, "std"),
30+
(lambda v: Series(v).cov(Series(v)), None, "cov"),
31+
(lambda v: Series(v).corr(Series(v)), None, "corr"),
32+
(lambda v: Series(v).var(), 1, "var"),
33+
# restore once GH 8086 is fixed
34+
# lambda v: Series(v).skew(), 3, 'skew'),
35+
# (lambda v: Series(v).kurt(), 4, 'kurt'),
36+
# restore once GH 8084 is fixed
37+
# lambda v: Series(v).quantile(0.3), None, 'quantile'),
38+
(lambda v: Series(v).median(), None, "median"),
39+
(np.nanmax, 1, "max"),
40+
(np.nanmin, 1, "min"),
41+
(np.nansum, 1, "sum"),
42+
(np.nanmean, 1, "mean"),
43+
(lambda v: np.nanstd(v, ddof=1), 1, "std"),
44+
(lambda v: np.nanvar(v, ddof=1), 1, "var"),
45+
(np.nanmedian, 1, "median"),
46+
]
47+
48+
no_nan_functions_list = [
49+
(np.max, None, "max"),
50+
(np.min, None, "min"),
51+
(np.sum, None, "sum"),
52+
(np.mean, None, "mean"),
53+
(lambda v: np.std(v, ddof=1), 1, "std"),
54+
(lambda v: np.var(v, ddof=1), 1, "var"),
55+
(np.median, None, "median"),
56+
]
57+
58+
59+
@pytest.fixture(scope="session")
60+
def base_functions():
61+
"""Fixture for base functions.
62+
63+
Returns
64+
-------
65+
List of tuples: (applied function, require_min_periods, name of applied function)
66+
"""
67+
return base_functions_list
68+
69+
70+
@pytest.fixture(scope="session")
71+
def no_nan_functions():
72+
"""Fixture for no nan functions.
73+
74+
Returns
75+
-------
76+
List of tuples: (applied function, require_min_periods, name of applied function)
77+
"""
78+
return no_nan_functions_list

pandas/tests/window/moments/test_moments_expanding.py

+45-43
Original file line numberDiff line numberDiff line change
@@ -145,50 +145,52 @@ def _check_expanding_has_min_periods(self, func, static_comp, has_min_periods):
145145
result = func(ser)
146146
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
147147

148-
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
149-
def test_expanding_apply_consistency(self, consistency_data, min_periods):
150-
x, is_constant, no_nans = consistency_data
151-
with warnings.catch_warnings():
152-
warnings.filterwarnings(
153-
"ignore",
154-
message=".*(empty slice|0 for slice).*",
155-
category=RuntimeWarning,
156-
)
157-
# test consistency between expanding_xyz() and either (a)
158-
# expanding_apply of Series.xyz(), or (b) expanding_apply of
159-
# np.nanxyz()
160-
functions = self.base_functions
161-
162-
# GH 8269
163-
if no_nans:
164-
functions = self.base_functions + self.no_nan_functions
165-
for (f, require_min_periods, name) in functions:
166-
expanding_f = getattr(x.expanding(min_periods=min_periods), name)
167-
168-
if (
169-
require_min_periods
170-
and (min_periods is not None)
171-
and (min_periods < require_min_periods)
172-
):
173-
continue
174-
175-
if name == "count":
176-
expanding_f_result = expanding_f()
177-
expanding_apply_f_result = x.expanding(min_periods=0).apply(
178-
func=f, raw=True
179-
)
148+
149+
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
150+
def test_expanding_apply_consistency(
151+
consistency_data, base_functions, no_nan_functions, min_periods
152+
):
153+
x, is_constant, no_nans = consistency_data
154+
155+
with warnings.catch_warnings():
156+
warnings.filterwarnings(
157+
"ignore", message=".*(empty slice|0 for slice).*", category=RuntimeWarning,
158+
)
159+
# test consistency between expanding_xyz() and either (a)
160+
# expanding_apply of Series.xyz(), or (b) expanding_apply of
161+
# np.nanxyz()
162+
functions = base_functions
163+
164+
# GH 8269
165+
if no_nans:
166+
functions = base_functions + no_nan_functions
167+
for (f, require_min_periods, name) in functions:
168+
expanding_f = getattr(x.expanding(min_periods=min_periods), name)
169+
170+
if (
171+
require_min_periods
172+
and (min_periods is not None)
173+
and (min_periods < require_min_periods)
174+
):
175+
continue
176+
177+
if name == "count":
178+
expanding_f_result = expanding_f()
179+
expanding_apply_f_result = x.expanding(min_periods=0).apply(
180+
func=f, raw=True
181+
)
182+
else:
183+
if name in ["cov", "corr"]:
184+
expanding_f_result = expanding_f(pairwise=False)
180185
else:
181-
if name in ["cov", "corr"]:
182-
expanding_f_result = expanding_f(pairwise=False)
183-
else:
184-
expanding_f_result = expanding_f()
185-
expanding_apply_f_result = x.expanding(
186-
min_periods=min_periods
187-
).apply(func=f, raw=True)
188-
189-
# GH 9422
190-
if name in ["sum", "prod"]:
191-
tm.assert_equal(expanding_f_result, expanding_apply_f_result)
186+
expanding_f_result = expanding_f()
187+
expanding_apply_f_result = x.expanding(min_periods=min_periods).apply(
188+
func=f, raw=True
189+
)
190+
191+
# GH 9422
192+
if name in ["sum", "prod"]:
193+
tm.assert_equal(expanding_f_result, expanding_apply_f_result)
192194

193195

194196
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])

pandas/tests/window/moments/test_moments_rolling.py

+55-52
Original file line numberDiff line numberDiff line change
@@ -946,58 +946,6 @@ class TestRollingMomentsConsistency(ConsistencyBase):
946946
def setup_method(self, method):
947947
self._create_data()
948948

949-
@pytest.mark.parametrize(
950-
"window,min_periods,center", list(_rolling_consistency_cases())
951-
)
952-
def test_rolling_apply_consistency(
953-
self, consistency_data, window, min_periods, center
954-
):
955-
x, is_constant, no_nans = consistency_data
956-
with warnings.catch_warnings():
957-
warnings.filterwarnings(
958-
"ignore",
959-
message=".*(empty slice|0 for slice).*",
960-
category=RuntimeWarning,
961-
)
962-
# test consistency between rolling_xyz() and either (a)
963-
# rolling_apply of Series.xyz(), or (b) rolling_apply of
964-
# np.nanxyz()
965-
functions = self.base_functions
966-
967-
# GH 8269
968-
if no_nans:
969-
functions = self.base_functions + self.no_nan_functions
970-
for (f, require_min_periods, name) in functions:
971-
rolling_f = getattr(
972-
x.rolling(window=window, center=center, min_periods=min_periods),
973-
name,
974-
)
975-
976-
if (
977-
require_min_periods
978-
and (min_periods is not None)
979-
and (min_periods < require_min_periods)
980-
):
981-
continue
982-
983-
if name == "count":
984-
rolling_f_result = rolling_f()
985-
rolling_apply_f_result = x.rolling(
986-
window=window, min_periods=min_periods, center=center
987-
).apply(func=f, raw=True)
988-
else:
989-
if name in ["cov", "corr"]:
990-
rolling_f_result = rolling_f(pairwise=False)
991-
else:
992-
rolling_f_result = rolling_f()
993-
rolling_apply_f_result = x.rolling(
994-
window=window, min_periods=min_periods, center=center
995-
).apply(func=f, raw=True)
996-
997-
# GH 9422
998-
if name in ["sum", "prod"]:
999-
tm.assert_equal(rolling_f_result, rolling_apply_f_result)
1000-
1001949
# binary moments
1002950
def test_rolling_cov(self):
1003951
A = self.series
@@ -1052,6 +1000,58 @@ def test_flex_binary_frame(self, method):
10521000
tm.assert_frame_equal(res3, exp)
10531001

10541002

1003+
@pytest.mark.slow
1004+
@pytest.mark.parametrize(
1005+
"window,min_periods,center", list(_rolling_consistency_cases())
1006+
)
1007+
def test_rolling_apply_consistency(
1008+
consistency_data, base_functions, no_nan_functions, window, min_periods, center
1009+
):
1010+
x, is_constant, no_nans = consistency_data
1011+
1012+
with warnings.catch_warnings():
1013+
warnings.filterwarnings(
1014+
"ignore", message=".*(empty slice|0 for slice).*", category=RuntimeWarning,
1015+
)
1016+
# test consistency between rolling_xyz() and either (a)
1017+
# rolling_apply of Series.xyz(), or (b) rolling_apply of
1018+
# np.nanxyz()
1019+
functions = base_functions
1020+
1021+
# GH 8269
1022+
if no_nans:
1023+
functions = no_nan_functions + base_functions
1024+
for (f, require_min_periods, name) in functions:
1025+
rolling_f = getattr(
1026+
x.rolling(window=window, center=center, min_periods=min_periods), name,
1027+
)
1028+
1029+
if (
1030+
require_min_periods
1031+
and (min_periods is not None)
1032+
and (min_periods < require_min_periods)
1033+
):
1034+
continue
1035+
1036+
if name == "count":
1037+
rolling_f_result = rolling_f()
1038+
rolling_apply_f_result = x.rolling(
1039+
window=window, min_periods=min_periods, center=center
1040+
).apply(func=f, raw=True)
1041+
else:
1042+
if name in ["cov", "corr"]:
1043+
rolling_f_result = rolling_f(pairwise=False)
1044+
else:
1045+
rolling_f_result = rolling_f()
1046+
rolling_apply_f_result = x.rolling(
1047+
window=window, min_periods=min_periods, center=center
1048+
).apply(func=f, raw=True)
1049+
1050+
# GH 9422
1051+
if name in ["sum", "prod"]:
1052+
tm.assert_equal(rolling_f_result, rolling_apply_f_result)
1053+
1054+
10551055
@pytest.mark.parametrize("window", range(7))
10561056
def test_rolling_corr_with_zero_variance(window):
10571057
# GH 18430
@@ -1431,6 +1431,7 @@ def test_moment_functions_zero_length_pairwise():
14311431
tm.assert_frame_equal(df2_result, df2_expected)
14321432

14331433

1434+
@pytest.mark.slow
14341435
@pytest.mark.parametrize(
14351436
"window,min_periods,center", list(_rolling_consistency_cases())
14361437
)
@@ -1455,6 +1456,7 @@ def test_rolling_consistency_var(consistency_data, window, min_periods, center):
14551456
)
14561457

14571458

1459+
@pytest.mark.slow
14581460
@pytest.mark.parametrize(
14591461
"window,min_periods,center", list(_rolling_consistency_cases())
14601462
)
@@ -1477,6 +1479,7 @@ def test_rolling_consistency_std(consistency_data, window, min_periods, center):
14771479
)
14781480

14791481

1482+
@pytest.mark.slow
14801483
@pytest.mark.parametrize(
14811484
"window,min_periods,center", list(_rolling_consistency_cases())
14821485
)

0 commit comments

Comments
 (0)