Skip to content

Commit 50fb400

Browse files
mroeschkejreback
authored andcommitted
CLN: Split test_window.py (#27305)
1 parent c0c1c9a commit 50fb400

File tree

6 files changed

+1201
-1169
lines changed

6 files changed

+1201
-1169
lines changed

pandas/tests/window/__init__.py

Whitespace-only changes.

pandas/tests/window/conftest.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
3+
4+
@pytest.fixture(params=[True, False])
5+
def raw(request):
6+
return request.param
7+
8+
9+
@pytest.fixture(
10+
params=[
11+
"triang",
12+
"blackman",
13+
"hamming",
14+
"bartlett",
15+
"bohman",
16+
"blackmanharris",
17+
"nuttall",
18+
"barthann",
19+
]
20+
)
21+
def win_types(request):
22+
return request.param
23+
24+
25+
@pytest.fixture(params=["kaiser", "gaussian", "general_gaussian", "exponential"])
26+
def win_types_special(request):
27+
return request.param
28+
29+
30+
@pytest.fixture(
31+
params=["sum", "mean", "median", "max", "min", "var", "std", "kurt", "skew"]
32+
)
33+
def arithmetic_win_operators(request):
34+
return request.param
35+
36+
37+
@pytest.fixture(params=["right", "left", "both", "neither"])
38+
def closed(request):
39+
return request.param
40+
41+
42+
@pytest.fixture(params=[True, False])
43+
def center(request):
44+
return request.param
45+
46+
47+
@pytest.fixture(params=[None, 1])
48+
def min_periods(request):
49+
return request.param

pandas/tests/window/test_dtypes.py

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
from itertools import product
2+
3+
import numpy as np
4+
import pytest
5+
6+
from pandas import DataFrame, Series
7+
from pandas.core.base import DataError
8+
import pandas.util.testing as tm
9+
10+
# gh-12373 : rolling functions error on float32 data
11+
# make sure rolling functions works for different dtypes
12+
#
13+
# NOTE that these are yielded tests and so _create_data
14+
# is explicitly called.
15+
#
16+
# further note that we are only checking rolling for fully dtype
17+
# compliance (though both expanding and ewm inherit)
18+
19+
20+
class Dtype:
21+
window = 2
22+
23+
funcs = {
24+
"count": lambda v: v.count(),
25+
"max": lambda v: v.max(),
26+
"min": lambda v: v.min(),
27+
"sum": lambda v: v.sum(),
28+
"mean": lambda v: v.mean(),
29+
"std": lambda v: v.std(),
30+
"var": lambda v: v.var(),
31+
"median": lambda v: v.median(),
32+
}
33+
34+
def get_expects(self):
35+
expects = {
36+
"sr1": {
37+
"count": Series([1, 2, 2, 2, 2], dtype="float64"),
38+
"max": Series([np.nan, 1, 2, 3, 4], dtype="float64"),
39+
"min": Series([np.nan, 0, 1, 2, 3], dtype="float64"),
40+
"sum": Series([np.nan, 1, 3, 5, 7], dtype="float64"),
41+
"mean": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"),
42+
"std": Series([np.nan] + [np.sqrt(0.5)] * 4, dtype="float64"),
43+
"var": Series([np.nan, 0.5, 0.5, 0.5, 0.5], dtype="float64"),
44+
"median": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"),
45+
},
46+
"sr2": {
47+
"count": Series([1, 2, 2, 2, 2], dtype="float64"),
48+
"max": Series([np.nan, 10, 8, 6, 4], dtype="float64"),
49+
"min": Series([np.nan, 8, 6, 4, 2], dtype="float64"),
50+
"sum": Series([np.nan, 18, 14, 10, 6], dtype="float64"),
51+
"mean": Series([np.nan, 9, 7, 5, 3], dtype="float64"),
52+
"std": Series([np.nan] + [np.sqrt(2)] * 4, dtype="float64"),
53+
"var": Series([np.nan, 2, 2, 2, 2], dtype="float64"),
54+
"median": Series([np.nan, 9, 7, 5, 3], dtype="float64"),
55+
},
56+
"df": {
57+
"count": DataFrame(
58+
{0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])},
59+
dtype="float64",
60+
),
61+
"max": DataFrame(
62+
{0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])},
63+
dtype="float64",
64+
),
65+
"min": DataFrame(
66+
{0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])},
67+
dtype="float64",
68+
),
69+
"sum": DataFrame(
70+
{
71+
0: Series([np.nan, 2, 6, 10, 14]),
72+
1: Series([np.nan, 4, 8, 12, 16]),
73+
},
74+
dtype="float64",
75+
),
76+
"mean": DataFrame(
77+
{0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])},
78+
dtype="float64",
79+
),
80+
"std": DataFrame(
81+
{
82+
0: Series([np.nan] + [np.sqrt(2)] * 4),
83+
1: Series([np.nan] + [np.sqrt(2)] * 4),
84+
},
85+
dtype="float64",
86+
),
87+
"var": DataFrame(
88+
{0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])},
89+
dtype="float64",
90+
),
91+
"median": DataFrame(
92+
{0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])},
93+
dtype="float64",
94+
),
95+
},
96+
}
97+
return expects
98+
99+
def _create_dtype_data(self, dtype):
100+
sr1 = Series(np.arange(5), dtype=dtype)
101+
sr2 = Series(np.arange(10, 0, -2), dtype=dtype)
102+
df = DataFrame(np.arange(10).reshape((5, 2)), dtype=dtype)
103+
104+
data = {"sr1": sr1, "sr2": sr2, "df": df}
105+
106+
return data
107+
108+
def _create_data(self):
109+
self.data = self._create_dtype_data(self.dtype)
110+
self.expects = self.get_expects()
111+
112+
def test_dtypes(self):
113+
self._create_data()
114+
for f_name, d_name in product(self.funcs.keys(), self.data.keys()):
115+
116+
f = self.funcs[f_name]
117+
d = self.data[d_name]
118+
exp = self.expects[d_name][f_name]
119+
self.check_dtypes(f, f_name, d, d_name, exp)
120+
121+
def check_dtypes(self, f, f_name, d, d_name, exp):
122+
roll = d.rolling(window=self.window)
123+
result = f(roll)
124+
tm.assert_almost_equal(result, exp)
125+
126+
127+
class TestDtype_object(Dtype):
128+
dtype = object
129+
130+
131+
class Dtype_integer(Dtype):
132+
pass
133+
134+
135+
class TestDtype_int8(Dtype_integer):
136+
dtype = np.int8
137+
138+
139+
class TestDtype_int16(Dtype_integer):
140+
dtype = np.int16
141+
142+
143+
class TestDtype_int32(Dtype_integer):
144+
dtype = np.int32
145+
146+
147+
class TestDtype_int64(Dtype_integer):
148+
dtype = np.int64
149+
150+
151+
class Dtype_uinteger(Dtype):
152+
pass
153+
154+
155+
class TestDtype_uint8(Dtype_uinteger):
156+
dtype = np.uint8
157+
158+
159+
class TestDtype_uint16(Dtype_uinteger):
160+
dtype = np.uint16
161+
162+
163+
class TestDtype_uint32(Dtype_uinteger):
164+
dtype = np.uint32
165+
166+
167+
class TestDtype_uint64(Dtype_uinteger):
168+
dtype = np.uint64
169+
170+
171+
class Dtype_float(Dtype):
172+
pass
173+
174+
175+
class TestDtype_float16(Dtype_float):
176+
dtype = np.float16
177+
178+
179+
class TestDtype_float32(Dtype_float):
180+
dtype = np.float32
181+
182+
183+
class TestDtype_float64(Dtype_float):
184+
dtype = np.float64
185+
186+
187+
class TestDtype_category(Dtype):
188+
dtype = "category"
189+
include_df = False
190+
191+
def _create_dtype_data(self, dtype):
192+
sr1 = Series(range(5), dtype=dtype)
193+
sr2 = Series(range(10, 0, -2), dtype=dtype)
194+
195+
data = {"sr1": sr1, "sr2": sr2}
196+
197+
return data
198+
199+
200+
class DatetimeLike(Dtype):
201+
def check_dtypes(self, f, f_name, d, d_name, exp):
202+
203+
roll = d.rolling(window=self.window)
204+
if f_name == "count":
205+
result = f(roll)
206+
tm.assert_almost_equal(result, exp)
207+
208+
else:
209+
with pytest.raises(DataError):
210+
f(roll)
211+
212+
213+
class TestDtype_timedelta(DatetimeLike):
214+
dtype = np.dtype("m8[ns]")
215+
216+
217+
class TestDtype_datetime(DatetimeLike):
218+
dtype = np.dtype("M8[ns]")
219+
220+
221+
class TestDtype_datetime64UTC(DatetimeLike):
222+
dtype = "datetime64[ns, UTC]"
223+
224+
def _create_data(self):
225+
pytest.skip(
226+
"direct creation of extension dtype "
227+
"datetime64[ns, UTC] is not supported ATM"
228+
)

0 commit comments

Comments
 (0)