Skip to content

Commit cc54943

Browse files
authored
REF: window/test_dtypes.py with pytest idioms (#35918)
1 parent 42289d0 commit cc54943

File tree

2 files changed

+118
-228
lines changed

2 files changed

+118
-228
lines changed

pandas/tests/window/conftest.py

+31
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,34 @@ def which(request):
308308
def halflife_with_times(request):
309309
"""Halflife argument for EWM when times is specified."""
310310
return request.param
311+
312+
313+
@pytest.fixture(
314+
params=[
315+
"object",
316+
"category",
317+
"int8",
318+
"int16",
319+
"int32",
320+
"int64",
321+
"uint8",
322+
"uint16",
323+
"uint32",
324+
"uint64",
325+
"float16",
326+
"float32",
327+
"float64",
328+
"m8[ns]",
329+
"M8[ns]",
330+
pytest.param(
331+
"datetime64[ns, UTC]",
332+
marks=pytest.mark.skip(
333+
"direct creation of extension dtype datetime64[ns, UTC] "
334+
"is not supported ATM"
335+
),
336+
),
337+
]
338+
)
339+
def dtypes(request):
340+
"""Dtypes for window tests"""
341+
return request.param

pandas/tests/window/test_dtypes.py

+87-228
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from itertools import product
2-
31
import numpy as np
42
import pytest
53

@@ -10,234 +8,95 @@
108
# gh-12373 : rolling functions error on float32 data
119
# make sure rolling functions works for different dtypes
1210
#
13-
# NOTE that these are yielded tests and so _create_data
14-
# is explicitly called.
15-
#
1611
# further note that we are only checking rolling for fully dtype
1712
# compliance (though both expanding and ewm inherit)
1813

1914

20-
class Dtype:
21-
window = 2
22-
23-
funcs = {
24-
"count": lambda v: v.count(),
25-
"max": lambda v: v.max(),
26-
"min": lambda v: v.min(),
27-
"sum": lambda v: v.sum(),
28-
"mean": lambda v: v.mean(),
29-
"std": lambda v: v.std(),
30-
"var": lambda v: v.var(),
31-
"median": lambda v: v.median(),
32-
}
33-
34-
def get_expects(self):
35-
expects = {
36-
"sr1": {
37-
"count": Series([1, 2, 2, 2, 2], dtype="float64"),
38-
"max": Series([np.nan, 1, 2, 3, 4], dtype="float64"),
39-
"min": Series([np.nan, 0, 1, 2, 3], dtype="float64"),
40-
"sum": Series([np.nan, 1, 3, 5, 7], dtype="float64"),
41-
"mean": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"),
42-
"std": Series([np.nan] + [np.sqrt(0.5)] * 4, dtype="float64"),
43-
"var": Series([np.nan, 0.5, 0.5, 0.5, 0.5], dtype="float64"),
44-
"median": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"),
15+
def get_dtype(dtype, coerce_int=None):
16+
if coerce_int is False and "int" in dtype:
17+
return None
18+
if dtype != "category":
19+
return np.dtype(dtype)
20+
return dtype
21+
22+
23+
@pytest.mark.parametrize(
24+
"method, data, expected_data, coerce_int",
25+
[
26+
("count", np.arange(5), [1, 2, 2, 2, 2], True),
27+
("count", np.arange(10, 0, -2), [1, 2, 2, 2, 2], True),
28+
("count", [0, 1, 2, np.nan, 4], [1, 2, 2, 1, 1], False),
29+
("max", np.arange(5), [np.nan, 1, 2, 3, 4], True),
30+
("max", np.arange(10, 0, -2), [np.nan, 10, 8, 6, 4], True),
31+
("max", [0, 1, 2, np.nan, 4], [np.nan, 1, 2, np.nan, np.nan], False),
32+
("min", np.arange(5), [np.nan, 0, 1, 2, 3], True),
33+
("min", np.arange(10, 0, -2), [np.nan, 8, 6, 4, 2], True),
34+
("min", [0, 1, 2, np.nan, 4], [np.nan, 0, 1, np.nan, np.nan], False),
35+
("sum", np.arange(5), [np.nan, 1, 3, 5, 7], True),
36+
("sum", np.arange(10, 0, -2), [np.nan, 18, 14, 10, 6], True),
37+
("sum", [0, 1, 2, np.nan, 4], [np.nan, 1, 3, np.nan, np.nan], False),
38+
("mean", np.arange(5), [np.nan, 0.5, 1.5, 2.5, 3.5], True),
39+
("mean", np.arange(10, 0, -2), [np.nan, 9, 7, 5, 3], True),
40+
("mean", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 1.5, np.nan, np.nan], False),
41+
("std", np.arange(5), [np.nan] + [np.sqrt(0.5)] * 4, True),
42+
("std", np.arange(10, 0, -2), [np.nan] + [np.sqrt(2)] * 4, True),
43+
(
44+
"std",
45+
[0, 1, 2, np.nan, 4],
46+
[np.nan] + [np.sqrt(0.5)] * 2 + [np.nan] * 2,
47+
False,
48+
),
49+
("var", np.arange(5), [np.nan, 0.5, 0.5, 0.5, 0.5], True),
50+
("var", np.arange(10, 0, -2), [np.nan, 2, 2, 2, 2], True),
51+
("var", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 0.5, np.nan, np.nan], False),
52+
("median", np.arange(5), [np.nan, 0.5, 1.5, 2.5, 3.5], True),
53+
("median", np.arange(10, 0, -2), [np.nan, 9, 7, 5, 3], True),
54+
("median", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 1.5, np.nan, np.nan], False),
55+
],
56+
)
57+
def test_series_dtypes(method, data, expected_data, coerce_int, dtypes):
58+
s = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int))
59+
if dtypes in ("m8[ns]", "M8[ns]") and method != "count":
60+
msg = "No numeric types to aggregate"
61+
with pytest.raises(DataError, match=msg):
62+
getattr(s.rolling(2), method)()
63+
else:
64+
result = getattr(s.rolling(2), method)()
65+
expected = Series(expected_data, dtype="float64")
66+
tm.assert_almost_equal(result, expected)
67+
68+
69+
@pytest.mark.parametrize(
70+
"method, expected_data",
71+
[
72+
("count", {0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])}),
73+
("max", {0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])}),
74+
("min", {0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])}),
75+
(
76+
"sum",
77+
{0: Series([np.nan, 2, 6, 10, 14]), 1: Series([np.nan, 4, 8, 12, 16])},
78+
),
79+
("mean", {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}),
80+
(
81+
"std",
82+
{
83+
0: Series([np.nan] + [np.sqrt(2)] * 4),
84+
1: Series([np.nan] + [np.sqrt(2)] * 4),
4585
},
46-
"sr2": {
47-
"count": Series([1, 2, 2, 2, 2], dtype="float64"),
48-
"max": Series([np.nan, 10, 8, 6, 4], dtype="float64"),
49-
"min": Series([np.nan, 8, 6, 4, 2], dtype="float64"),
50-
"sum": Series([np.nan, 18, 14, 10, 6], dtype="float64"),
51-
"mean": Series([np.nan, 9, 7, 5, 3], dtype="float64"),
52-
"std": Series([np.nan] + [np.sqrt(2)] * 4, dtype="float64"),
53-
"var": Series([np.nan, 2, 2, 2, 2], dtype="float64"),
54-
"median": Series([np.nan, 9, 7, 5, 3], dtype="float64"),
55-
},
56-
"sr3": {
57-
"count": Series([1, 2, 2, 1, 1], dtype="float64"),
58-
"max": Series([np.nan, 1, 2, np.nan, np.nan], dtype="float64"),
59-
"min": Series([np.nan, 0, 1, np.nan, np.nan], dtype="float64"),
60-
"sum": Series([np.nan, 1, 3, np.nan, np.nan], dtype="float64"),
61-
"mean": Series([np.nan, 0.5, 1.5, np.nan, np.nan], dtype="float64"),
62-
"std": Series(
63-
[np.nan] + [np.sqrt(0.5)] * 2 + [np.nan] * 2, dtype="float64"
64-
),
65-
"var": Series([np.nan, 0.5, 0.5, np.nan, np.nan], dtype="float64"),
66-
"median": Series([np.nan, 0.5, 1.5, np.nan, np.nan], dtype="float64"),
67-
},
68-
"df": {
69-
"count": DataFrame(
70-
{0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])},
71-
dtype="float64",
72-
),
73-
"max": DataFrame(
74-
{0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])},
75-
dtype="float64",
76-
),
77-
"min": DataFrame(
78-
{0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])},
79-
dtype="float64",
80-
),
81-
"sum": DataFrame(
82-
{
83-
0: Series([np.nan, 2, 6, 10, 14]),
84-
1: Series([np.nan, 4, 8, 12, 16]),
85-
},
86-
dtype="float64",
87-
),
88-
"mean": DataFrame(
89-
{0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])},
90-
dtype="float64",
91-
),
92-
"std": DataFrame(
93-
{
94-
0: Series([np.nan] + [np.sqrt(2)] * 4),
95-
1: Series([np.nan] + [np.sqrt(2)] * 4),
96-
},
97-
dtype="float64",
98-
),
99-
"var": DataFrame(
100-
{0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])},
101-
dtype="float64",
102-
),
103-
"median": DataFrame(
104-
{0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])},
105-
dtype="float64",
106-
),
107-
},
108-
}
109-
return expects
110-
111-
def _create_dtype_data(self, dtype):
112-
sr1 = Series(np.arange(5), dtype=dtype)
113-
sr2 = Series(np.arange(10, 0, -2), dtype=dtype)
114-
sr3 = sr1.copy()
115-
sr3[3] = np.NaN
116-
df = DataFrame(np.arange(10).reshape((5, 2)), dtype=dtype)
117-
118-
data = {"sr1": sr1, "sr2": sr2, "sr3": sr3, "df": df}
119-
120-
return data
121-
122-
def _create_data(self):
123-
self.data = self._create_dtype_data(self.dtype)
124-
self.expects = self.get_expects()
125-
126-
def test_dtypes(self):
127-
self._create_data()
128-
for f_name, d_name in product(self.funcs.keys(), self.data.keys()):
129-
130-
f = self.funcs[f_name]
131-
d = self.data[d_name]
132-
exp = self.expects[d_name][f_name]
133-
self.check_dtypes(f, f_name, d, d_name, exp)
134-
135-
def check_dtypes(self, f, f_name, d, d_name, exp):
136-
roll = d.rolling(window=self.window)
137-
result = f(roll)
138-
tm.assert_almost_equal(result, exp)
139-
140-
141-
class TestDtype_object(Dtype):
142-
dtype = object
143-
144-
145-
class Dtype_integer(Dtype):
146-
pass
147-
148-
149-
class TestDtype_int8(Dtype_integer):
150-
dtype = np.int8
151-
152-
153-
class TestDtype_int16(Dtype_integer):
154-
dtype = np.int16
155-
156-
157-
class TestDtype_int32(Dtype_integer):
158-
dtype = np.int32
159-
160-
161-
class TestDtype_int64(Dtype_integer):
162-
dtype = np.int64
163-
164-
165-
class Dtype_uinteger(Dtype):
166-
pass
167-
168-
169-
class TestDtype_uint8(Dtype_uinteger):
170-
dtype = np.uint8
171-
172-
173-
class TestDtype_uint16(Dtype_uinteger):
174-
dtype = np.uint16
175-
176-
177-
class TestDtype_uint32(Dtype_uinteger):
178-
dtype = np.uint32
179-
180-
181-
class TestDtype_uint64(Dtype_uinteger):
182-
dtype = np.uint64
183-
184-
185-
class Dtype_float(Dtype):
186-
pass
187-
188-
189-
class TestDtype_float16(Dtype_float):
190-
dtype = np.float16
191-
192-
193-
class TestDtype_float32(Dtype_float):
194-
dtype = np.float32
195-
196-
197-
class TestDtype_float64(Dtype_float):
198-
dtype = np.float64
199-
200-
201-
class TestDtype_category(Dtype):
202-
dtype = "category"
203-
include_df = False
204-
205-
def _create_dtype_data(self, dtype):
206-
sr1 = Series(range(5), dtype=dtype)
207-
sr2 = Series(range(10, 0, -2), dtype=dtype)
208-
209-
data = {"sr1": sr1, "sr2": sr2}
210-
211-
return data
212-
213-
214-
class DatetimeLike(Dtype):
215-
def check_dtypes(self, f, f_name, d, d_name, exp):
216-
217-
roll = d.rolling(window=self.window)
218-
if f_name == "count":
219-
result = f(roll)
220-
tm.assert_almost_equal(result, exp)
221-
222-
else:
223-
msg = "No numeric types to aggregate"
224-
with pytest.raises(DataError, match=msg):
225-
f(roll)
226-
227-
228-
class TestDtype_timedelta(DatetimeLike):
229-
dtype = np.dtype("m8[ns]")
230-
231-
232-
class TestDtype_datetime(DatetimeLike):
233-
dtype = np.dtype("M8[ns]")
234-
235-
236-
class TestDtype_datetime64UTC(DatetimeLike):
237-
dtype = "datetime64[ns, UTC]"
238-
239-
def _create_data(self):
240-
pytest.skip(
241-
"direct creation of extension dtype "
242-
"datetime64[ns, UTC] is not supported ATM"
243-
)
86+
),
87+
("var", {0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])}),
88+
("median", {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}),
89+
],
90+
)
91+
def test_dataframe_dtypes(method, expected_data, dtypes):
92+
if dtypes == "category":
93+
pytest.skip("Category dataframe testing not implemented.")
94+
df = DataFrame(np.arange(10).reshape((5, 2)), dtype=get_dtype(dtypes))
95+
if dtypes in ("m8[ns]", "M8[ns]") and method != "count":
96+
msg = "No numeric types to aggregate"
97+
with pytest.raises(DataError, match=msg):
98+
getattr(df.rolling(2), method)()
99+
else:
100+
result = getattr(df.rolling(2), method)()
101+
expected = DataFrame(expected_data, dtype="float64")
102+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)