Skip to content

Commit 7ecd9af

Browse files
charlesdong1991jreback
authored andcommitted
CLN: Clean test moments for expanding (pandas-dev#30566)
1 parent ac3715b commit 7ecd9af

File tree

1 file changed

+45
-38
lines changed

1 file changed

+45
-38
lines changed

pandas/tests/window/moments/test_moments_expanding.py

+45-38
Original file line numberDiff line numberDiff line change
@@ -173,19 +173,24 @@ def test_expanding_corr_pairwise_diff_length(self):
173173
tm.assert_frame_equal(result3, expected)
174174
tm.assert_frame_equal(result4, expected)
175175

176+
@pytest.mark.parametrize("has_min_periods", [True, False])
176177
@pytest.mark.parametrize(
177178
"func,static_comp",
178179
[("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
179180
ids=["sum", "mean", "max", "min"],
180181
)
181-
def test_expanding_func(self, func, static_comp):
182+
def test_expanding_func(self, func, static_comp, has_min_periods):
182183
def expanding_func(x, min_periods=1, center=False, axis=0):
183184
exp = x.expanding(min_periods=min_periods, center=center, axis=axis)
184185
return getattr(exp, func)()
185186

186187
self._check_expanding(expanding_func, static_comp, preserve_nan=False)
188+
self._check_expanding_has_min_periods(
189+
expanding_func, static_comp, has_min_periods
190+
)
187191

188-
def test_expanding_apply(self, raw):
192+
@pytest.mark.parametrize("has_min_periods", [True, False])
193+
def test_expanding_apply(self, raw, has_min_periods):
189194
def expanding_mean(x, min_periods=1):
190195

191196
exp = x.expanding(min_periods=min_periods)
@@ -195,19 +200,20 @@ def expanding_mean(x, min_periods=1):
195200
# TODO(jreback), needed to add preserve_nan=False
196201
# here to make this pass
197202
self._check_expanding(expanding_mean, np.mean, preserve_nan=False)
203+
self._check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods)
198204

205+
def test_expanding_apply_empty_series(self, raw):
199206
ser = Series([], dtype=np.float64)
200207
tm.assert_series_equal(ser, ser.expanding().apply(lambda x: x.mean(), raw=raw))
201208

209+
def test_expanding_apply_min_periods_0(self, raw):
202210
# GH 8080
203211
s = Series([None, None, None])
204212
result = s.expanding(min_periods=0).apply(lambda x: len(x), raw=raw)
205213
expected = Series([1.0, 2.0, 3.0])
206214
tm.assert_series_equal(result, expected)
207215

208-
def _check_expanding(
209-
self, func, static_comp, has_min_periods=True, preserve_nan=True
210-
):
216+
def _check_expanding(self, func, static_comp, preserve_nan=True):
211217

212218
series_result = func(self.series)
213219
assert isinstance(series_result, Series)
@@ -220,6 +226,7 @@ def _check_expanding(
220226
if preserve_nan:
221227
assert result.iloc[self._nan_locs].isna().all()
222228

229+
def _check_expanding_has_min_periods(self, func, static_comp, has_min_periods):
223230
ser = Series(randn(50))
224231

225232
if has_min_periods:
@@ -245,17 +252,9 @@ def _check_expanding(
245252
result = func(ser)
246253
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
247254

248-
def test_moment_functions_zero_length(self):
249-
# GH 8056
250-
s = Series(dtype=np.float64)
251-
s_expected = s
252-
df1 = DataFrame()
253-
df1_expected = df1
254-
df2 = DataFrame(columns=["a"])
255-
df2["a"] = df2["a"].astype("float64")
256-
df2_expected = df2
257-
258-
functions = [
255+
@pytest.mark.parametrize(
256+
"f",
257+
[
259258
lambda x: x.expanding().count(),
260259
lambda x: x.expanding(min_periods=5).cov(x, pairwise=False),
261260
lambda x: x.expanding(min_periods=5).corr(x, pairwise=False),
@@ -271,23 +270,35 @@ def test_moment_functions_zero_length(self):
271270
lambda x: x.expanding(min_periods=5).median(),
272271
lambda x: x.expanding(min_periods=5).apply(sum, raw=False),
273272
lambda x: x.expanding(min_periods=5).apply(sum, raw=True),
274-
]
275-
for f in functions:
276-
try:
277-
s_result = f(s)
278-
tm.assert_series_equal(s_result, s_expected)
273+
],
274+
)
275+
def test_moment_functions_zero_length(self, f):
276+
# GH 8056
277+
s = Series(dtype=np.float64)
278+
s_expected = s
279+
df1 = DataFrame()
280+
df1_expected = df1
281+
df2 = DataFrame(columns=["a"])
282+
df2["a"] = df2["a"].astype("float64")
283+
df2_expected = df2
279284

280-
df1_result = f(df1)
281-
tm.assert_frame_equal(df1_result, df1_expected)
285+
s_result = f(s)
286+
tm.assert_series_equal(s_result, s_expected)
282287

283-
df2_result = f(df2)
284-
tm.assert_frame_equal(df2_result, df2_expected)
285-
except (ImportError):
288+
df1_result = f(df1)
289+
tm.assert_frame_equal(df1_result, df1_expected)
286290

287-
# scipy needed for rolling_window
288-
continue
291+
df2_result = f(df2)
292+
tm.assert_frame_equal(df2_result, df2_expected)
289293

290-
def test_moment_functions_zero_length_pairwise(self):
294+
@pytest.mark.parametrize(
295+
"f",
296+
[
297+
lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)),
298+
lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)),
299+
],
300+
)
301+
def test_moment_functions_zero_length_pairwise(self, f):
291302

292303
df1 = DataFrame()
293304
df2 = DataFrame(columns=Index(["a"], name="foo"), index=Index([], name="bar"))
@@ -303,16 +314,12 @@ def test_moment_functions_zero_length_pairwise(self):
303314
columns=Index(["a"], name="foo"),
304315
dtype="float64",
305316
)
306-
functions = [
307-
lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)),
308-
lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)),
309-
]
310-
for f in functions:
311-
df1_result = f(df1)
312-
tm.assert_frame_equal(df1_result, df1_expected)
313317

314-
df2_result = f(df2)
315-
tm.assert_frame_equal(df2_result, df2_expected)
318+
df1_result = f(df1)
319+
tm.assert_frame_equal(df1_result, df1_expected)
320+
321+
df2_result = f(df2)
322+
tm.assert_frame_equal(df2_result, df2_expected)
316323

317324
@pytest.mark.slow
318325
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])

0 commit comments

Comments
 (0)