Skip to content

Commit 40e3c7b

Browse files
ihsansecerjreback
authored andcommitted
CLN: Split test_window.py further (#27348)
1 parent 423ca86 commit 40e3c7b

File tree

8 files changed

+3643
-3579
lines changed

8 files changed

+3643
-3579
lines changed

pandas/tests/window/common.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from datetime import datetime
2+
3+
import numpy as np
4+
from numpy.random import randn
5+
6+
from pandas import DataFrame, Series, bdate_range
7+
8+
N, K = 100, 10
9+
10+
11+
class Base:
12+
13+
_nan_locs = np.arange(20, 40)
14+
_inf_locs = np.array([])
15+
16+
def _create_data(self):
17+
arr = randn(N)
18+
arr[self._nan_locs] = np.NaN
19+
20+
self.arr = arr
21+
self.rng = bdate_range(datetime(2009, 1, 1), periods=N)
22+
self.series = Series(arr.copy(), index=self.rng)
23+
self.frame = DataFrame(randn(N, K), index=self.rng, columns=np.arange(K))

pandas/tests/window/test_api.py

+367
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
from collections import OrderedDict
2+
import warnings
3+
from warnings import catch_warnings
4+
5+
import numpy as np
6+
import pytest
7+
8+
import pandas.util._test_decorators as td
9+
10+
import pandas as pd
11+
from pandas import DataFrame, Index, Series, Timestamp, concat
12+
from pandas.core.base import SpecificationError
13+
from pandas.tests.window.common import Base
14+
import pandas.util.testing as tm
15+
16+
17+
class TestApi(Base):
18+
def setup_method(self, method):
19+
self._create_data()
20+
21+
def test_getitem(self):
22+
23+
r = self.frame.rolling(window=5)
24+
tm.assert_index_equal(r._selected_obj.columns, self.frame.columns)
25+
26+
r = self.frame.rolling(window=5)[1]
27+
assert r._selected_obj.name == self.frame.columns[1]
28+
29+
# technically this is allowed
30+
r = self.frame.rolling(window=5)[1, 3]
31+
tm.assert_index_equal(r._selected_obj.columns, self.frame.columns[[1, 3]])
32+
33+
r = self.frame.rolling(window=5)[[1, 3]]
34+
tm.assert_index_equal(r._selected_obj.columns, self.frame.columns[[1, 3]])
35+
36+
def test_select_bad_cols(self):
37+
df = DataFrame([[1, 2]], columns=["A", "B"])
38+
g = df.rolling(window=5)
39+
with pytest.raises(KeyError, match="Columns not found: 'C'"):
40+
g[["C"]]
41+
with pytest.raises(KeyError, match="^[^A]+$"):
42+
# A should not be referenced as a bad column...
43+
# will have to rethink regex if you change message!
44+
g[["A", "C"]]
45+
46+
def test_attribute_access(self):
47+
48+
df = DataFrame([[1, 2]], columns=["A", "B"])
49+
r = df.rolling(window=5)
50+
tm.assert_series_equal(r.A.sum(), r["A"].sum())
51+
msg = "'Rolling' object has no attribute 'F'"
52+
with pytest.raises(AttributeError, match=msg):
53+
r.F
54+
55+
def tests_skip_nuisance(self):
56+
57+
df = DataFrame({"A": range(5), "B": range(5, 10), "C": "foo"})
58+
r = df.rolling(window=3)
59+
result = r[["A", "B"]].sum()
60+
expected = DataFrame(
61+
{"A": [np.nan, np.nan, 3, 6, 9], "B": [np.nan, np.nan, 18, 21, 24]},
62+
columns=list("AB"),
63+
)
64+
tm.assert_frame_equal(result, expected)
65+
66+
def test_skip_sum_object_raises(self):
67+
df = DataFrame({"A": range(5), "B": range(5, 10), "C": "foo"})
68+
r = df.rolling(window=3)
69+
result = r.sum()
70+
expected = DataFrame(
71+
{"A": [np.nan, np.nan, 3, 6, 9], "B": [np.nan, np.nan, 18, 21, 24]},
72+
columns=list("AB"),
73+
)
74+
tm.assert_frame_equal(result, expected)
75+
76+
def test_agg(self):
77+
df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
78+
79+
r = df.rolling(window=3)
80+
a_mean = r["A"].mean()
81+
a_std = r["A"].std()
82+
a_sum = r["A"].sum()
83+
b_mean = r["B"].mean()
84+
b_std = r["B"].std()
85+
b_sum = r["B"].sum()
86+
87+
result = r.aggregate([np.mean, np.std])
88+
expected = concat([a_mean, a_std, b_mean, b_std], axis=1)
89+
expected.columns = pd.MultiIndex.from_product([["A", "B"], ["mean", "std"]])
90+
tm.assert_frame_equal(result, expected)
91+
92+
result = r.aggregate({"A": np.mean, "B": np.std})
93+
94+
expected = concat([a_mean, b_std], axis=1)
95+
tm.assert_frame_equal(result, expected, check_like=True)
96+
97+
result = r.aggregate({"A": ["mean", "std"]})
98+
expected = concat([a_mean, a_std], axis=1)
99+
expected.columns = pd.MultiIndex.from_tuples([("A", "mean"), ("A", "std")])
100+
tm.assert_frame_equal(result, expected)
101+
102+
result = r["A"].aggregate(["mean", "sum"])
103+
expected = concat([a_mean, a_sum], axis=1)
104+
expected.columns = ["mean", "sum"]
105+
tm.assert_frame_equal(result, expected)
106+
107+
with catch_warnings(record=True):
108+
# using a dict with renaming
109+
warnings.simplefilter("ignore", FutureWarning)
110+
result = r.aggregate({"A": {"mean": "mean", "sum": "sum"}})
111+
expected = concat([a_mean, a_sum], axis=1)
112+
expected.columns = pd.MultiIndex.from_tuples([("A", "mean"), ("A", "sum")])
113+
tm.assert_frame_equal(result, expected, check_like=True)
114+
115+
with catch_warnings(record=True):
116+
warnings.simplefilter("ignore", FutureWarning)
117+
result = r.aggregate(
118+
{
119+
"A": {"mean": "mean", "sum": "sum"},
120+
"B": {"mean2": "mean", "sum2": "sum"},
121+
}
122+
)
123+
expected = concat([a_mean, a_sum, b_mean, b_sum], axis=1)
124+
exp_cols = [("A", "mean"), ("A", "sum"), ("B", "mean2"), ("B", "sum2")]
125+
expected.columns = pd.MultiIndex.from_tuples(exp_cols)
126+
tm.assert_frame_equal(result, expected, check_like=True)
127+
128+
result = r.aggregate({"A": ["mean", "std"], "B": ["mean", "std"]})
129+
expected = concat([a_mean, a_std, b_mean, b_std], axis=1)
130+
131+
exp_cols = [("A", "mean"), ("A", "std"), ("B", "mean"), ("B", "std")]
132+
expected.columns = pd.MultiIndex.from_tuples(exp_cols)
133+
tm.assert_frame_equal(result, expected, check_like=True)
134+
135+
def test_agg_apply(self, raw):
136+
137+
# passed lambda
138+
df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
139+
140+
r = df.rolling(window=3)
141+
a_sum = r["A"].sum()
142+
143+
result = r.agg({"A": np.sum, "B": lambda x: np.std(x, ddof=1)})
144+
rcustom = r["B"].apply(lambda x: np.std(x, ddof=1), raw=raw)
145+
expected = concat([a_sum, rcustom], axis=1)
146+
tm.assert_frame_equal(result, expected, check_like=True)
147+
148+
def test_agg_consistency(self):
149+
150+
df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
151+
r = df.rolling(window=3)
152+
153+
result = r.agg([np.sum, np.mean]).columns
154+
expected = pd.MultiIndex.from_product([list("AB"), ["sum", "mean"]])
155+
tm.assert_index_equal(result, expected)
156+
157+
result = r["A"].agg([np.sum, np.mean]).columns
158+
expected = Index(["sum", "mean"])
159+
tm.assert_index_equal(result, expected)
160+
161+
result = r.agg({"A": [np.sum, np.mean]}).columns
162+
expected = pd.MultiIndex.from_tuples([("A", "sum"), ("A", "mean")])
163+
tm.assert_index_equal(result, expected)
164+
165+
def test_agg_nested_dicts(self):
166+
167+
# API change for disallowing these types of nested dicts
168+
df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
169+
r = df.rolling(window=3)
170+
171+
msg = r"cannot perform renaming for (r1|r2) with a nested dictionary"
172+
with pytest.raises(SpecificationError, match=msg):
173+
r.aggregate({"r1": {"A": ["mean", "sum"]}, "r2": {"B": ["mean", "sum"]}})
174+
175+
expected = concat(
176+
[r["A"].mean(), r["A"].std(), r["B"].mean(), r["B"].std()], axis=1
177+
)
178+
expected.columns = pd.MultiIndex.from_tuples(
179+
[("ra", "mean"), ("ra", "std"), ("rb", "mean"), ("rb", "std")]
180+
)
181+
with catch_warnings(record=True):
182+
warnings.simplefilter("ignore", FutureWarning)
183+
result = r[["A", "B"]].agg(
184+
{"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}}
185+
)
186+
tm.assert_frame_equal(result, expected, check_like=True)
187+
188+
with catch_warnings(record=True):
189+
warnings.simplefilter("ignore", FutureWarning)
190+
result = r.agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}})
191+
expected.columns = pd.MultiIndex.from_tuples(
192+
[
193+
("A", "ra", "mean"),
194+
("A", "ra", "std"),
195+
("B", "rb", "mean"),
196+
("B", "rb", "std"),
197+
]
198+
)
199+
tm.assert_frame_equal(result, expected, check_like=True)
200+
201+
def test_count_nonnumeric_types(self):
202+
# GH12541
203+
cols = [
204+
"int",
205+
"float",
206+
"string",
207+
"datetime",
208+
"timedelta",
209+
"periods",
210+
"fl_inf",
211+
"fl_nan",
212+
"str_nan",
213+
"dt_nat",
214+
"periods_nat",
215+
]
216+
217+
df = DataFrame(
218+
{
219+
"int": [1, 2, 3],
220+
"float": [4.0, 5.0, 6.0],
221+
"string": list("abc"),
222+
"datetime": pd.date_range("20170101", periods=3),
223+
"timedelta": pd.timedelta_range("1 s", periods=3, freq="s"),
224+
"periods": [
225+
pd.Period("2012-01"),
226+
pd.Period("2012-02"),
227+
pd.Period("2012-03"),
228+
],
229+
"fl_inf": [1.0, 2.0, np.Inf],
230+
"fl_nan": [1.0, 2.0, np.NaN],
231+
"str_nan": ["aa", "bb", np.NaN],
232+
"dt_nat": [
233+
Timestamp("20170101"),
234+
Timestamp("20170203"),
235+
Timestamp(None),
236+
],
237+
"periods_nat": [
238+
pd.Period("2012-01"),
239+
pd.Period("2012-02"),
240+
pd.Period(None),
241+
],
242+
},
243+
columns=cols,
244+
)
245+
246+
expected = DataFrame(
247+
{
248+
"int": [1.0, 2.0, 2.0],
249+
"float": [1.0, 2.0, 2.0],
250+
"string": [1.0, 2.0, 2.0],
251+
"datetime": [1.0, 2.0, 2.0],
252+
"timedelta": [1.0, 2.0, 2.0],
253+
"periods": [1.0, 2.0, 2.0],
254+
"fl_inf": [1.0, 2.0, 2.0],
255+
"fl_nan": [1.0, 2.0, 1.0],
256+
"str_nan": [1.0, 2.0, 1.0],
257+
"dt_nat": [1.0, 2.0, 1.0],
258+
"periods_nat": [1.0, 2.0, 1.0],
259+
},
260+
columns=cols,
261+
)
262+
263+
result = df.rolling(window=2).count()
264+
tm.assert_frame_equal(result, expected)
265+
266+
result = df.rolling(1).count()
267+
expected = df.notna().astype(float)
268+
tm.assert_frame_equal(result, expected)
269+
270+
@td.skip_if_no_scipy
271+
@pytest.mark.filterwarnings("ignore:can't resolve:ImportWarning")
272+
def test_window_with_args(self):
273+
# make sure that we are aggregating window functions correctly with arg
274+
r = Series(np.random.randn(100)).rolling(
275+
window=10, min_periods=1, win_type="gaussian"
276+
)
277+
expected = concat([r.mean(std=10), r.mean(std=0.01)], axis=1)
278+
expected.columns = ["<lambda>", "<lambda>"]
279+
result = r.aggregate([lambda x: x.mean(std=10), lambda x: x.mean(std=0.01)])
280+
tm.assert_frame_equal(result, expected)
281+
282+
def a(x):
283+
return x.mean(std=10)
284+
285+
def b(x):
286+
return x.mean(std=0.01)
287+
288+
expected = concat([r.mean(std=10), r.mean(std=0.01)], axis=1)
289+
expected.columns = ["a", "b"]
290+
result = r.aggregate([a, b])
291+
tm.assert_frame_equal(result, expected)
292+
293+
def test_preserve_metadata(self):
294+
# GH 10565
295+
s = Series(np.arange(100), name="foo")
296+
297+
s2 = s.rolling(30).sum()
298+
s3 = s.rolling(20).sum()
299+
assert s2.name == "foo"
300+
assert s3.name == "foo"
301+
302+
@pytest.mark.parametrize(
303+
"func,window_size,expected_vals",
304+
[
305+
(
306+
"rolling",
307+
2,
308+
[
309+
[np.nan, np.nan, np.nan, np.nan],
310+
[15.0, 20.0, 25.0, 20.0],
311+
[25.0, 30.0, 35.0, 30.0],
312+
[np.nan, np.nan, np.nan, np.nan],
313+
[20.0, 30.0, 35.0, 30.0],
314+
[35.0, 40.0, 60.0, 40.0],
315+
[60.0, 80.0, 85.0, 80],
316+
],
317+
),
318+
(
319+
"expanding",
320+
None,
321+
[
322+
[10.0, 10.0, 20.0, 20.0],
323+
[15.0, 20.0, 25.0, 20.0],
324+
[20.0, 30.0, 30.0, 20.0],
325+
[10.0, 10.0, 30.0, 30.0],
326+
[20.0, 30.0, 35.0, 30.0],
327+
[26.666667, 40.0, 50.0, 30.0],
328+
[40.0, 80.0, 60.0, 30.0],
329+
],
330+
),
331+
],
332+
)
333+
def test_multiple_agg_funcs(self, func, window_size, expected_vals):
334+
# GH 15072
335+
df = pd.DataFrame(
336+
[
337+
["A", 10, 20],
338+
["A", 20, 30],
339+
["A", 30, 40],
340+
["B", 10, 30],
341+
["B", 30, 40],
342+
["B", 40, 80],
343+
["B", 80, 90],
344+
],
345+
columns=["stock", "low", "high"],
346+
)
347+
348+
f = getattr(df.groupby("stock"), func)
349+
if window_size:
350+
window = f(window_size)
351+
else:
352+
window = f()
353+
354+
index = pd.MultiIndex.from_tuples(
355+
[("A", 0), ("A", 1), ("A", 2), ("B", 3), ("B", 4), ("B", 5), ("B", 6)],
356+
names=["stock", None],
357+
)
358+
columns = pd.MultiIndex.from_tuples(
359+
[("low", "mean"), ("low", "max"), ("high", "mean"), ("high", "min")]
360+
)
361+
expected = pd.DataFrame(expected_vals, index=index, columns=columns)
362+
363+
result = window.agg(
364+
OrderedDict((("low", ["mean", "max"]), ("high", ["mean", "min"])))
365+
)
366+
367+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)