Skip to content

Commit 73a5213

Browse files
mroeschkeMatt Roeschke
and
Matt Roeschke
authored
CLN: More tests/window/* (#37959)
* Move tests to appripriate locations, use less pd. * Consolidate fixtures and use frame_or_series * Rename test * Add TestRolling and TestExpanding to groupby tests Co-authored-by: Matt Roeschke <[email protected]>
1 parent b32febd commit 73a5213

9 files changed

+160
-184
lines changed

pandas/tests/window/conftest.py

+8-24
Original file line numberDiff line numberDiff line change
@@ -275,40 +275,24 @@ def consistency_data(request):
275275
return request.param
276276

277277

278-
def _create_series():
279-
"""Internal function to mock Series."""
280-
arr = np.random.randn(100)
281-
locs = np.arange(20, 40)
282-
arr[locs] = np.NaN
283-
series = Series(arr, index=bdate_range(datetime(2009, 1, 1), periods=100))
284-
return series
285-
286-
287-
def _create_frame():
288-
"""Internal function to mock DataFrame."""
278+
@pytest.fixture
279+
def frame():
280+
"""Make mocked frame as fixture."""
289281
return DataFrame(
290282
np.random.randn(100, 10),
291283
index=bdate_range(datetime(2009, 1, 1), periods=100),
292284
columns=np.arange(10),
293285
)
294286

295287

296-
@pytest.fixture
297-
def frame():
298-
"""Make mocked frame as fixture."""
299-
return _create_frame()
300-
301-
302288
@pytest.fixture
303289
def series():
304290
"""Make mocked series as fixture."""
305-
return _create_series()
306-
307-
308-
@pytest.fixture(params=[_create_series(), _create_frame()])
309-
def which(request):
310-
"""Turn parametrized which as fixture for series and frame"""
311-
return request.param
291+
arr = np.random.randn(100)
292+
locs = np.arange(20, 40)
293+
arr[locs] = np.NaN
294+
series = Series(arr, index=bdate_range(datetime(2009, 1, 1), periods=100))
295+
return series
312296

313297

314298
@pytest.fixture(params=["1 day", timedelta(days=1)])

pandas/tests/window/test_api.py

-26
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
import pandas.util._test_decorators as td
5-
64
import pandas as pd
75
from pandas import DataFrame, Index, Series, Timestamp, concat
86
import pandas._testing as tm
@@ -238,30 +236,6 @@ def test_count_nonnumeric_types():
238236
tm.assert_frame_equal(result, expected)
239237

240238

241-
@td.skip_if_no_scipy
242-
@pytest.mark.filterwarnings("ignore:can't resolve:ImportWarning")
243-
def test_window_with_args():
244-
# make sure that we are aggregating window functions correctly with arg
245-
r = Series(np.random.randn(100)).rolling(
246-
window=10, min_periods=1, win_type="gaussian"
247-
)
248-
expected = concat([r.mean(std=10), r.mean(std=0.01)], axis=1)
249-
expected.columns = ["<lambda>", "<lambda>"]
250-
result = r.aggregate([lambda x: x.mean(std=10), lambda x: x.mean(std=0.01)])
251-
tm.assert_frame_equal(result, expected)
252-
253-
def a(x):
254-
return x.mean(std=10)
255-
256-
def b(x):
257-
return x.mean(std=0.01)
258-
259-
expected = concat([r.mean(std=10), r.mean(std=0.01)], axis=1)
260-
expected.columns = ["a", "b"]
261-
result = r.aggregate([a, b])
262-
tm.assert_frame_equal(result, expected)
263-
264-
265239
def test_preserve_metadata():
266240
# GH 10565
267241
s = Series(np.arange(100), name="foo")

pandas/tests/window/test_apply.py

-11
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas.errors import NumbaUtilError
5-
import pandas.util._test_decorators as td
6-
74
from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range
85
import pandas._testing as tm
96

@@ -133,14 +130,6 @@ def test_invalid_raw_numba():
133130
Series(range(1)).rolling(1).apply(lambda x: x, raw=False, engine="numba")
134131

135132

136-
@td.skip_if_no("numba")
137-
def test_invalid_kwargs_nopython():
138-
with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"):
139-
Series(range(1)).rolling(1).apply(
140-
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
141-
)
142-
143-
144133
@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]])
145134
def test_rolling_apply_args_kwargs(args_kwargs):
146135
# GH 33433

pandas/tests/window/test_ewm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def test_doc_string():
1515
df.ewm(com=0.5).mean()
1616

1717

18-
def test_constructor(which):
18+
def test_constructor(frame_or_series):
1919

20-
c = which.ewm
20+
c = frame_or_series(range(5)).ewm
2121

2222
# valid
2323
c(com=0.5)

pandas/tests/window/test_expanding.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def test_doc_string():
1919
@pytest.mark.filterwarnings(
2020
"ignore:The `center` argument on `expanding` will be removed in the future"
2121
)
22-
def test_constructor(which):
22+
def test_constructor(frame_or_series):
2323
# GH 12669
2424

25-
c = which.expanding
25+
c = frame_or_series(range(5)).expanding
2626

2727
# valid
2828
c(min_periods=1)
@@ -34,10 +34,10 @@ def test_constructor(which):
3434
@pytest.mark.filterwarnings(
3535
"ignore:The `center` argument on `expanding` will be removed in the future"
3636
)
37-
def test_constructor_invalid(which, w):
37+
def test_constructor_invalid(frame_or_series, w):
3838
# not valid
3939

40-
c = which.expanding
40+
c = frame_or_series(range(5)).expanding
4141
msg = "min_periods must be an integer"
4242
with pytest.raises(ValueError, match=msg):
4343
c(min_periods=w)
@@ -118,30 +118,27 @@ def test_expanding_axis(axis_frame):
118118
tm.assert_frame_equal(result, expected)
119119

120120

121-
@pytest.mark.parametrize("constructor", [Series, DataFrame])
122-
def test_expanding_count_with_min_periods(constructor):
121+
def test_expanding_count_with_min_periods(frame_or_series):
123122
# GH 26996
124-
result = constructor(range(5)).expanding(min_periods=3).count()
125-
expected = constructor([np.nan, np.nan, 3.0, 4.0, 5.0])
123+
result = frame_or_series(range(5)).expanding(min_periods=3).count()
124+
expected = frame_or_series([np.nan, np.nan, 3.0, 4.0, 5.0])
126125
tm.assert_equal(result, expected)
127126

128127

129-
@pytest.mark.parametrize("constructor", [Series, DataFrame])
130-
def test_expanding_count_default_min_periods_with_null_values(constructor):
128+
def test_expanding_count_default_min_periods_with_null_values(frame_or_series):
131129
# GH 26996
132130
values = [1, 2, 3, np.nan, 4, 5, 6]
133131
expected_counts = [1.0, 2.0, 3.0, 3.0, 4.0, 5.0, 6.0]
134132

135-
result = constructor(values).expanding().count()
136-
expected = constructor(expected_counts)
133+
result = frame_or_series(values).expanding().count()
134+
expected = frame_or_series(expected_counts)
137135
tm.assert_equal(result, expected)
138136

139137

140-
@pytest.mark.parametrize("constructor", [Series, DataFrame])
141-
def test_expanding_count_with_min_periods_exceeding_series_length(constructor):
138+
def test_expanding_count_with_min_periods_exceeding_series_length(frame_or_series):
142139
# GH 25857
143-
result = constructor(range(5)).expanding(min_periods=6).count()
144-
expected = constructor([np.nan, np.nan, np.nan, np.nan, np.nan])
140+
result = frame_or_series(range(5)).expanding(min_periods=6).count()
141+
expected = frame_or_series([np.nan, np.nan, np.nan, np.nan, np.nan])
145142
tm.assert_equal(result, expected)
146143

147144

@@ -246,10 +243,9 @@ def test_center_deprecate_warning():
246243
df.expanding()
247244

248245

249-
@pytest.mark.parametrize("constructor", ["DataFrame", "Series"])
250-
def test_expanding_sem(constructor):
246+
def test_expanding_sem(frame_or_series):
251247
# GH: 26476
252-
obj = getattr(pd, constructor)([0, 1, 2])
248+
obj = frame_or_series([0, 1, 2])
253249
result = obj.expanding().sem()
254250
if isinstance(result, DataFrame):
255251
result = Series(result[0].values)

pandas/tests/window/test_groupby.py

+68-64
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from pandas.core.groupby.groupby import get_groupby
88

99

10-
class TestGrouperGrouping:
10+
class TestRolling:
1111
def setup_method(self):
12-
self.series = Series(np.arange(10))
1312
self.frame = DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)})
1413

1514
def test_mutated(self):
@@ -152,68 +151,6 @@ def test_rolling_apply_mutability(self):
152151
result = g.rolling(window=2).sum()
153152
tm.assert_frame_equal(result, expected)
154153

155-
@pytest.mark.parametrize(
156-
"f", ["sum", "mean", "min", "max", "count", "kurt", "skew"]
157-
)
158-
def test_expanding(self, f):
159-
g = self.frame.groupby("A")
160-
r = g.expanding()
161-
162-
result = getattr(r, f)()
163-
expected = g.apply(lambda x: getattr(x.expanding(), f)())
164-
tm.assert_frame_equal(result, expected)
165-
166-
@pytest.mark.parametrize("f", ["std", "var"])
167-
def test_expanding_ddof(self, f):
168-
g = self.frame.groupby("A")
169-
r = g.expanding()
170-
171-
result = getattr(r, f)(ddof=0)
172-
expected = g.apply(lambda x: getattr(x.expanding(), f)(ddof=0))
173-
tm.assert_frame_equal(result, expected)
174-
175-
@pytest.mark.parametrize(
176-
"interpolation", ["linear", "lower", "higher", "midpoint", "nearest"]
177-
)
178-
def test_expanding_quantile(self, interpolation):
179-
g = self.frame.groupby("A")
180-
r = g.expanding()
181-
result = r.quantile(0.4, interpolation=interpolation)
182-
expected = g.apply(
183-
lambda x: x.expanding().quantile(0.4, interpolation=interpolation)
184-
)
185-
tm.assert_frame_equal(result, expected)
186-
187-
@pytest.mark.parametrize("f", ["corr", "cov"])
188-
def test_expanding_corr_cov(self, f):
189-
g = self.frame.groupby("A")
190-
r = g.expanding()
191-
192-
result = getattr(r, f)(self.frame)
193-
194-
def func(x):
195-
return getattr(x.expanding(), f)(self.frame)
196-
197-
expected = g.apply(func)
198-
tm.assert_frame_equal(result, expected)
199-
200-
result = getattr(r.B, f)(pairwise=True)
201-
202-
def func(x):
203-
return getattr(x.B.expanding(), f)(pairwise=True)
204-
205-
expected = g.apply(func)
206-
tm.assert_series_equal(result, expected)
207-
208-
def test_expanding_apply(self, raw):
209-
g = self.frame.groupby("A")
210-
r = g.expanding()
211-
212-
# reduction
213-
result = r.apply(lambda x: x.sum(), raw=raw)
214-
expected = g.apply(lambda x: x.expanding().apply(lambda y: y.sum(), raw=raw))
215-
tm.assert_frame_equal(result, expected)
216-
217154
@pytest.mark.parametrize("expected_value,raw_value", [[1.0, True], [0.0, False]])
218155
def test_groupby_rolling(self, expected_value, raw_value):
219156
# GH 31754
@@ -633,6 +570,73 @@ def test_groupby_rolling_index_level_and_column_label(self):
633570
tm.assert_frame_equal(result, expected)
634571

635572

573+
class TestExpanding:
574+
def setup_method(self):
575+
self.frame = DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)})
576+
577+
@pytest.mark.parametrize(
578+
"f", ["sum", "mean", "min", "max", "count", "kurt", "skew"]
579+
)
580+
def test_expanding(self, f):
581+
g = self.frame.groupby("A")
582+
r = g.expanding()
583+
584+
result = getattr(r, f)()
585+
expected = g.apply(lambda x: getattr(x.expanding(), f)())
586+
tm.assert_frame_equal(result, expected)
587+
588+
@pytest.mark.parametrize("f", ["std", "var"])
589+
def test_expanding_ddof(self, f):
590+
g = self.frame.groupby("A")
591+
r = g.expanding()
592+
593+
result = getattr(r, f)(ddof=0)
594+
expected = g.apply(lambda x: getattr(x.expanding(), f)(ddof=0))
595+
tm.assert_frame_equal(result, expected)
596+
597+
@pytest.mark.parametrize(
598+
"interpolation", ["linear", "lower", "higher", "midpoint", "nearest"]
599+
)
600+
def test_expanding_quantile(self, interpolation):
601+
g = self.frame.groupby("A")
602+
r = g.expanding()
603+
result = r.quantile(0.4, interpolation=interpolation)
604+
expected = g.apply(
605+
lambda x: x.expanding().quantile(0.4, interpolation=interpolation)
606+
)
607+
tm.assert_frame_equal(result, expected)
608+
609+
@pytest.mark.parametrize("f", ["corr", "cov"])
610+
def test_expanding_corr_cov(self, f):
611+
g = self.frame.groupby("A")
612+
r = g.expanding()
613+
614+
result = getattr(r, f)(self.frame)
615+
616+
def func(x):
617+
return getattr(x.expanding(), f)(self.frame)
618+
619+
expected = g.apply(func)
620+
tm.assert_frame_equal(result, expected)
621+
622+
result = getattr(r.B, f)(pairwise=True)
623+
624+
def func(x):
625+
return getattr(x.B.expanding(), f)(pairwise=True)
626+
627+
expected = g.apply(func)
628+
tm.assert_series_equal(result, expected)
629+
630+
def test_expanding_apply(self, raw):
631+
g = self.frame.groupby("A")
632+
r = g.expanding()
633+
634+
# reduction
635+
result = r.apply(lambda x: x.sum(), raw=raw)
636+
expected = g.apply(lambda x: x.expanding().apply(lambda y: y.sum(), raw=raw))
637+
tm.assert_frame_equal(result, expected)
638+
639+
636640
class TestEWM:
637641
@pytest.mark.parametrize(
638642
"method, expected_data",

pandas/tests/window/test_numba.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.errors import NumbaUtilError
45
import pandas.util._test_decorators as td
56

67
from pandas import DataFrame, Series, option_context
@@ -112,3 +113,11 @@ def f(x):
112113
result = s.rolling(2).apply(f, engine=None, raw=True)
113114
expected = s.rolling(2).apply(f, engine="numba", raw=True)
114115
tm.assert_series_equal(expected, result)
116+
117+
118+
@td.skip_if_no("numba", "0.46.0")
119+
def test_invalid_kwargs_nopython():
120+
with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"):
121+
Series(range(1)).rolling(1).apply(
122+
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
123+
)

0 commit comments

Comments
 (0)