Skip to content

Commit 9c40e06

Browse files
charlesdong1991jreback
authored andcommitted
REF: Refactor window/test_moments.py (#30542)
1 parent 11284f5 commit 9c40e06

File tree

4 files changed

+1221
-1138
lines changed

4 files changed

+1221
-1138
lines changed

pandas/tests/window/common.py

+328-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44
from numpy.random import randn
55

6-
from pandas import DataFrame, Series, bdate_range
6+
from pandas import DataFrame, Series, bdate_range, notna
7+
import pandas.util.testing as tm
78

89
N, K = 100, 10
910

@@ -21,3 +22,329 @@ def _create_data(self):
2122
self.rng = bdate_range(datetime(2009, 1, 1), periods=N)
2223
self.series = Series(arr.copy(), index=self.rng)
2324
self.frame = DataFrame(randn(N, K), index=self.rng, columns=np.arange(K))
25+
26+
27+
# create the data only once as we are not setting it
28+
def _create_consistency_data():
29+
def create_series():
30+
return [
31+
Series(dtype=object),
32+
Series([np.nan]),
33+
Series([np.nan, np.nan]),
34+
Series([3.0]),
35+
Series([np.nan, 3.0]),
36+
Series([3.0, np.nan]),
37+
Series([1.0, 3.0]),
38+
Series([2.0, 2.0]),
39+
Series([3.0, 1.0]),
40+
Series(
41+
[5.0, 5.0, 5.0, 5.0, np.nan, np.nan, np.nan, 5.0, 5.0, np.nan, np.nan]
42+
),
43+
Series(
44+
[
45+
np.nan,
46+
5.0,
47+
5.0,
48+
5.0,
49+
np.nan,
50+
np.nan,
51+
np.nan,
52+
5.0,
53+
5.0,
54+
np.nan,
55+
np.nan,
56+
]
57+
),
58+
Series(
59+
[
60+
np.nan,
61+
np.nan,
62+
5.0,
63+
5.0,
64+
np.nan,
65+
np.nan,
66+
np.nan,
67+
5.0,
68+
5.0,
69+
np.nan,
70+
np.nan,
71+
]
72+
),
73+
Series(
74+
[
75+
np.nan,
76+
3.0,
77+
np.nan,
78+
3.0,
79+
4.0,
80+
5.0,
81+
6.0,
82+
np.nan,
83+
np.nan,
84+
7.0,
85+
12.0,
86+
13.0,
87+
14.0,
88+
15.0,
89+
]
90+
),
91+
Series(
92+
[
93+
np.nan,
94+
5.0,
95+
np.nan,
96+
2.0,
97+
4.0,
98+
0.0,
99+
9.0,
100+
np.nan,
101+
np.nan,
102+
3.0,
103+
12.0,
104+
13.0,
105+
14.0,
106+
15.0,
107+
]
108+
),
109+
Series(
110+
[
111+
2.0,
112+
3.0,
113+
np.nan,
114+
3.0,
115+
4.0,
116+
5.0,
117+
6.0,
118+
np.nan,
119+
np.nan,
120+
7.0,
121+
12.0,
122+
13.0,
123+
14.0,
124+
15.0,
125+
]
126+
),
127+
Series(
128+
[
129+
2.0,
130+
5.0,
131+
np.nan,
132+
2.0,
133+
4.0,
134+
0.0,
135+
9.0,
136+
np.nan,
137+
np.nan,
138+
3.0,
139+
12.0,
140+
13.0,
141+
14.0,
142+
15.0,
143+
]
144+
),
145+
Series(range(10)),
146+
Series(range(20, 0, -2)),
147+
]
148+
149+
def create_dataframes():
150+
return [
151+
DataFrame(),
152+
DataFrame(columns=["a"]),
153+
DataFrame(columns=["a", "a"]),
154+
DataFrame(columns=["a", "b"]),
155+
DataFrame(np.arange(10).reshape((5, 2))),
156+
DataFrame(np.arange(25).reshape((5, 5))),
157+
DataFrame(np.arange(25).reshape((5, 5)), columns=["a", "b", 99, "d", "d"]),
158+
] + [DataFrame(s) for s in create_series()]
159+
160+
def is_constant(x):
161+
values = x.values.ravel()
162+
return len(set(values[notna(values)])) == 1
163+
164+
def no_nans(x):
165+
return x.notna().all().all()
166+
167+
# data is a tuple(object, is_constant, no_nans)
168+
data = create_series() + create_dataframes()
169+
170+
return [(x, is_constant(x), no_nans(x)) for x in data]
171+
172+
173+
_consistency_data = _create_consistency_data()
174+
175+
176+
class ConsistencyBase(Base):
177+
base_functions = [
178+
(lambda v: Series(v).count(), None, "count"),
179+
(lambda v: Series(v).max(), None, "max"),
180+
(lambda v: Series(v).min(), None, "min"),
181+
(lambda v: Series(v).sum(), None, "sum"),
182+
(lambda v: Series(v).mean(), None, "mean"),
183+
(lambda v: Series(v).std(), 1, "std"),
184+
(lambda v: Series(v).cov(Series(v)), None, "cov"),
185+
(lambda v: Series(v).corr(Series(v)), None, "corr"),
186+
(lambda v: Series(v).var(), 1, "var"),
187+
# restore once GH 8086 is fixed
188+
# lambda v: Series(v).skew(), 3, 'skew'),
189+
# (lambda v: Series(v).kurt(), 4, 'kurt'),
190+
# restore once GH 8084 is fixed
191+
# lambda v: Series(v).quantile(0.3), None, 'quantile'),
192+
(lambda v: Series(v).median(), None, "median"),
193+
(np.nanmax, 1, "max"),
194+
(np.nanmin, 1, "min"),
195+
(np.nansum, 1, "sum"),
196+
(np.nanmean, 1, "mean"),
197+
(lambda v: np.nanstd(v, ddof=1), 1, "std"),
198+
(lambda v: np.nanvar(v, ddof=1), 1, "var"),
199+
(np.nanmedian, 1, "median"),
200+
]
201+
no_nan_functions = [
202+
(np.max, None, "max"),
203+
(np.min, None, "min"),
204+
(np.sum, None, "sum"),
205+
(np.mean, None, "mean"),
206+
(lambda v: np.std(v, ddof=1), 1, "std"),
207+
(lambda v: np.var(v, ddof=1), 1, "var"),
208+
(np.median, None, "median"),
209+
]
210+
211+
def _create_data(self):
212+
super()._create_data()
213+
self.data = _consistency_data
214+
215+
def _test_moments_consistency(
216+
self,
217+
min_periods,
218+
count,
219+
mean,
220+
mock_mean,
221+
corr,
222+
var_unbiased=None,
223+
std_unbiased=None,
224+
cov_unbiased=None,
225+
var_biased=None,
226+
std_biased=None,
227+
cov_biased=None,
228+
var_debiasing_factors=None,
229+
):
230+
def _non_null_values(x):
231+
values = x.values.ravel()
232+
return set(values[notna(values)].tolist())
233+
234+
for (x, is_constant, no_nans) in self.data:
235+
count_x = count(x)
236+
mean_x = mean(x)
237+
238+
if mock_mean:
239+
# check that mean equals mock_mean
240+
expected = mock_mean(x)
241+
tm.assert_equal(mean_x, expected.astype("float64"))
242+
243+
# check that correlation of a series with itself is either 1 or NaN
244+
corr_x_x = corr(x, x)
245+
246+
# assert _non_null_values(corr_x_x).issubset(set([1.]))
247+
# restore once rolling_cov(x, x) is identically equal to var(x)
248+
249+
if is_constant:
250+
exp = x.max() if isinstance(x, Series) else x.max().max()
251+
252+
# check mean of constant series
253+
expected = x * np.nan
254+
expected[count_x >= max(min_periods, 1)] = exp
255+
tm.assert_equal(mean_x, expected)
256+
257+
# check correlation of constant series with itself is NaN
258+
expected[:] = np.nan
259+
tm.assert_equal(corr_x_x, expected)
260+
261+
if var_unbiased and var_biased and var_debiasing_factors:
262+
# check variance debiasing factors
263+
var_unbiased_x = var_unbiased(x)
264+
var_biased_x = var_biased(x)
265+
var_debiasing_factors_x = var_debiasing_factors(x)
266+
tm.assert_equal(var_unbiased_x, var_biased_x * var_debiasing_factors_x)
267+
268+
for (std, var, cov) in [
269+
(std_biased, var_biased, cov_biased),
270+
(std_unbiased, var_unbiased, cov_unbiased),
271+
]:
272+
273+
# check that var(x), std(x), and cov(x) are all >= 0
274+
var_x = var(x)
275+
std_x = std(x)
276+
assert not (var_x < 0).any().any()
277+
assert not (std_x < 0).any().any()
278+
if cov:
279+
cov_x_x = cov(x, x)
280+
assert not (cov_x_x < 0).any().any()
281+
282+
# check that var(x) == cov(x, x)
283+
tm.assert_equal(var_x, cov_x_x)
284+
285+
# check that var(x) == std(x)^2
286+
tm.assert_equal(var_x, std_x * std_x)
287+
288+
if var is var_biased:
289+
# check that biased var(x) == mean(x^2) - mean(x)^2
290+
mean_x2 = mean(x * x)
291+
tm.assert_equal(var_x, mean_x2 - (mean_x * mean_x))
292+
293+
if is_constant:
294+
# check that variance of constant series is identically 0
295+
assert not (var_x > 0).any().any()
296+
expected = x * np.nan
297+
expected[count_x >= max(min_periods, 1)] = 0.0
298+
if var is var_unbiased:
299+
expected[count_x < 2] = np.nan
300+
tm.assert_equal(var_x, expected)
301+
302+
if isinstance(x, Series):
303+
for (y, is_constant, no_nans) in self.data:
304+
if not x.isna().equals(y.isna()):
305+
# can only easily test two Series with similar
306+
# structure
307+
continue
308+
309+
# check that cor(x, y) is symmetric
310+
corr_x_y = corr(x, y)
311+
corr_y_x = corr(y, x)
312+
tm.assert_equal(corr_x_y, corr_y_x)
313+
314+
if cov:
315+
# check that cov(x, y) is symmetric
316+
cov_x_y = cov(x, y)
317+
cov_y_x = cov(y, x)
318+
tm.assert_equal(cov_x_y, cov_y_x)
319+
320+
# check that cov(x, y) == (var(x+y) - var(x) -
321+
# var(y)) / 2
322+
var_x_plus_y = var(x + y)
323+
var_y = var(y)
324+
tm.assert_equal(
325+
cov_x_y, 0.5 * (var_x_plus_y - var_x - var_y)
326+
)
327+
328+
# check that corr(x, y) == cov(x, y) / (std(x) *
329+
# std(y))
330+
std_y = std(y)
331+
tm.assert_equal(corr_x_y, cov_x_y / (std_x * std_y))
332+
333+
if cov is cov_biased:
334+
# check that biased cov(x, y) == mean(x*y) -
335+
# mean(x)*mean(y)
336+
mean_y = mean(y)
337+
mean_x_times_y = mean(x * y)
338+
tm.assert_equal(
339+
cov_x_y, mean_x_times_y - (mean_x * mean_y)
340+
)
341+
342+
def _check_pairwise_moment(self, dispatch, name, **kwargs):
343+
def get_result(obj, obj2=None):
344+
return getattr(getattr(obj, dispatch)(**kwargs), name)(obj2)
345+
346+
result = get_result(self.frame)
347+
result = result.loc[(slice(None), 1), 5]
348+
result.index = result.index.droplevel(1)
349+
expected = get_result(self.frame[1], self.frame[5])
350+
tm.assert_series_equal(result, expected, check_names=False)

0 commit comments

Comments
 (0)