Skip to content

Commit b20bb70

Browse files
mroeschkephofl
authored andcommitted
TST: Use fixtures instead of setup_method (pandas-dev#45842)
1 parent eaee093 commit b20bb70

File tree

5 files changed

+169
-149
lines changed

5 files changed

+169
-149
lines changed

pandas/tests/arrays/sparse/test_array.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@
1616
)
1717

1818

19-
class TestSparseArray:
20-
def setup_method(self):
21-
self.arr_data = np.array([np.nan, np.nan, 1, 2, 3, np.nan, 4, 5, np.nan, 6])
22-
self.arr = SparseArray(self.arr_data)
23-
self.zarr = SparseArray([0, 0, 1, 2, 3, 0, 4, 5, 0, 6], fill_value=0)
19+
@pytest.fixture
20+
def arr_data():
21+
return np.array([np.nan, np.nan, 1, 2, 3, np.nan, 4, 5, np.nan, 6])
22+
23+
24+
@pytest.fixture
25+
def arr(arr_data):
26+
return SparseArray(arr_data)
27+
2428

29+
@pytest.fixture
30+
def zarr():
31+
return SparseArray([0, 0, 1, 2, 3, 0, 4, 5, 0, 6], fill_value=0)
32+
33+
34+
class TestSparseArray:
2535
@pytest.mark.parametrize("fill_value", [0, None, np.nan])
2636
def test_shift_fill_value(self, fill_value):
2737
# GH #24128
@@ -79,13 +89,13 @@ def test_set_fill_invalid_non_scalar(self, val):
7989
with pytest.raises(ValueError, match=msg):
8090
arr.fill_value = val
8191

82-
def test_copy(self):
83-
arr2 = self.arr.copy()
84-
assert arr2.sp_values is not self.arr.sp_values
85-
assert arr2.sp_index is self.arr.sp_index
92+
def test_copy(self, arr):
93+
arr2 = arr.copy()
94+
assert arr2.sp_values is not arr.sp_values
95+
assert arr2.sp_index is arr.sp_index
8696

87-
def test_values_asarray(self):
88-
tm.assert_almost_equal(self.arr.to_dense(), self.arr_data)
97+
def test_values_asarray(self, arr_data, arr):
98+
tm.assert_almost_equal(arr.to_dense(), arr_data)
8999

90100
@pytest.mark.parametrize(
91101
"data,shape,dtype",
@@ -121,13 +131,11 @@ def test_dense_repr(self, vals, fill_value):
121131

122132
tm.assert_numpy_array_equal(res2, vals)
123133

124-
def test_pickle(self):
125-
def _check_roundtrip(obj):
126-
unpickled = tm.round_trip_pickle(obj)
127-
tm.assert_sp_array_equal(unpickled, obj)
128-
129-
_check_roundtrip(self.arr)
130-
_check_roundtrip(self.zarr)
134+
@pytest.mark.parametrize("fix", ["arr", "zarr"])
135+
def test_pickle(self, fix, request):
136+
obj = request.getfixturevalue(fix)
137+
unpickled = tm.round_trip_pickle(obj)
138+
tm.assert_sp_array_equal(unpickled, obj)
131139

132140
def test_generator_warnings(self):
133141
sp_arr = SparseArray([1, 2, 3])

pandas/tests/base/test_constructors.py

-3
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ class Delegate(PandasDelegate, PandasObject):
5454
def __init__(self, obj):
5555
self.obj = obj
5656

57-
def setup_method(self, method):
58-
pass
59-
6057
def test_invalid_delegation(self):
6158
# these show that in order for the delegation to work
6259
# the _delegate_* methods need to be overridden to not raise

pandas/tests/indexes/test_frozen.py

+45-37
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
from pandas.core.indexes.frozen import FrozenList
66

77

8-
class TestFrozenList:
8+
@pytest.fixture
9+
def lst():
10+
return [1, 2, 3, 4, 5]
11+
12+
13+
@pytest.fixture
14+
def container(lst):
15+
return FrozenList(lst)
916

10-
unicode_container = FrozenList(["\u05d0", "\u05d1", "c"])
1117

12-
def setup_method(self, _):
13-
self.lst = [1, 2, 3, 4, 5]
14-
self.container = FrozenList(self.lst)
18+
@pytest.fixture
19+
def unicode_container():
20+
return FrozenList(["\u05d0", "\u05d1", "c"])
1521

22+
23+
class TestFrozenList:
1624
def check_mutable_error(self, *args, **kwargs):
1725
# Pass whatever function you normally would to pytest.raises
1826
# (after the Exception kind).
@@ -21,75 +29,75 @@ def check_mutable_error(self, *args, **kwargs):
2129
with pytest.raises(TypeError, match=msg):
2230
mutable_regex(*args, **kwargs)
2331

24-
def test_no_mutable_funcs(self):
32+
def test_no_mutable_funcs(self, container):
2533
def setitem():
26-
self.container[0] = 5
34+
container[0] = 5
2735

2836
self.check_mutable_error(setitem)
2937

3038
def setslice():
31-
self.container[1:2] = 3
39+
container[1:2] = 3
3240

3341
self.check_mutable_error(setslice)
3442

3543
def delitem():
36-
del self.container[0]
44+
del container[0]
3745

3846
self.check_mutable_error(delitem)
3947

4048
def delslice():
41-
del self.container[0:3]
49+
del container[0:3]
4250

4351
self.check_mutable_error(delslice)
4452

4553
mutable_methods = ("extend", "pop", "remove", "insert")
4654

4755
for meth in mutable_methods:
48-
self.check_mutable_error(getattr(self.container, meth))
56+
self.check_mutable_error(getattr(container, meth))
4957

50-
def test_slicing_maintains_type(self):
51-
result = self.container[1:2]
52-
expected = self.lst[1:2]
58+
def test_slicing_maintains_type(self, container, lst):
59+
result = container[1:2]
60+
expected = lst[1:2]
5361
self.check_result(result, expected)
5462

5563
def check_result(self, result, expected):
5664
assert isinstance(result, FrozenList)
5765
assert result == expected
5866

59-
def test_string_methods_dont_fail(self):
60-
repr(self.container)
61-
str(self.container)
62-
bytes(self.container)
67+
def test_string_methods_dont_fail(self, container):
68+
repr(container)
69+
str(container)
70+
bytes(container)
6371

64-
def test_tricky_container(self):
65-
repr(self.unicode_container)
66-
str(self.unicode_container)
72+
def test_tricky_container(self, unicode_container):
73+
repr(unicode_container)
74+
str(unicode_container)
6775

68-
def test_add(self):
69-
result = self.container + (1, 2, 3)
70-
expected = FrozenList(self.lst + [1, 2, 3])
76+
def test_add(self, container, lst):
77+
result = container + (1, 2, 3)
78+
expected = FrozenList(lst + [1, 2, 3])
7179
self.check_result(result, expected)
7280

73-
result = (1, 2, 3) + self.container
74-
expected = FrozenList([1, 2, 3] + self.lst)
81+
result = (1, 2, 3) + container
82+
expected = FrozenList([1, 2, 3] + lst)
7583
self.check_result(result, expected)
7684

77-
def test_iadd(self):
78-
q = r = self.container
85+
def test_iadd(self, container, lst):
86+
q = r = container
7987

8088
q += [5]
81-
self.check_result(q, self.lst + [5])
89+
self.check_result(q, lst + [5])
8290

8391
# Other shouldn't be mutated.
84-
self.check_result(r, self.lst)
92+
self.check_result(r, lst)
8593

86-
def test_union(self):
87-
result = self.container.union((1, 2, 3))
88-
expected = FrozenList(self.lst + [1, 2, 3])
94+
def test_union(self, container, lst):
95+
result = container.union((1, 2, 3))
96+
expected = FrozenList(lst + [1, 2, 3])
8997
self.check_result(result, expected)
9098

91-
def test_difference(self):
92-
result = self.container.difference([2])
99+
def test_difference(self, container):
100+
result = container.difference([2])
93101
expected = FrozenList([1, 3, 4, 5])
94102
self.check_result(result, expected)
95103

@@ -98,8 +106,8 @@ def test_difference_dupe(self):
98106
expected = FrozenList([1, 3])
99107
self.check_result(result, expected)
100108

101-
def test_tricky_container_to_bytes_raises(self):
109+
def test_tricky_container_to_bytes_raises(self, unicode_container):
102110
# GH 26447
103111
msg = "^'str' object cannot be interpreted as an integer$"
104112
with pytest.raises(TypeError, match=msg):
105-
bytes(self.unicode_container)
113+
bytes(unicode_container)

pandas/tests/window/test_groupby.py

+36-34
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,26 @@ def times_frame():
4040
)
4141

4242

43-
class TestRolling:
44-
def setup_method(self):
45-
self.frame = DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)})
43+
@pytest.fixture
44+
def roll_frame():
45+
return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)})
46+
4647

47-
def test_mutated(self):
48+
class TestRolling:
49+
def test_mutated(self, roll_frame):
4850

4951
msg = r"groupby\(\) got an unexpected keyword argument 'foo'"
5052
with pytest.raises(TypeError, match=msg):
51-
self.frame.groupby("A", foo=1)
53+
roll_frame.groupby("A", foo=1)
5254

53-
g = self.frame.groupby("A")
55+
g = roll_frame.groupby("A")
5456
assert not g.mutated
55-
g = get_groupby(self.frame, by="A", mutated=True)
57+
g = get_groupby(roll_frame, by="A", mutated=True)
5658
assert g.mutated
5759

58-
def test_getitem(self):
59-
g = self.frame.groupby("A")
60-
g_mutated = get_groupby(self.frame, by="A", mutated=True)
60+
def test_getitem(self, roll_frame):
61+
g = roll_frame.groupby("A")
62+
g_mutated = get_groupby(roll_frame, by="A", mutated=True)
6163

6264
expected = g_mutated.B.apply(lambda x: x.rolling(2).mean())
6365

@@ -70,15 +72,15 @@ def test_getitem(self):
7072
result = g.B.rolling(2).mean()
7173
tm.assert_series_equal(result, expected)
7274

73-
result = self.frame.B.groupby(self.frame.A).rolling(2).mean()
75+
result = roll_frame.B.groupby(roll_frame.A).rolling(2).mean()
7476
tm.assert_series_equal(result, expected)
7577

76-
def test_getitem_multiple(self):
78+
def test_getitem_multiple(self, roll_frame):
7779

7880
# GH 13174
79-
g = self.frame.groupby("A")
81+
g = roll_frame.groupby("A")
8082
r = g.rolling(2, min_periods=0)
81-
g_mutated = get_groupby(self.frame, by="A", mutated=True)
83+
g_mutated = get_groupby(roll_frame, by="A", mutated=True)
8284
expected = g_mutated.B.apply(lambda x: x.rolling(2, min_periods=0).count())
8385

8486
result = r.B.count()
@@ -102,38 +104,38 @@ def test_getitem_multiple(self):
102104
"skew",
103105
],
104106
)
105-
def test_rolling(self, f):
106-
g = self.frame.groupby("A")
107+
def test_rolling(self, f, roll_frame):
108+
g = roll_frame.groupby("A")
107109
r = g.rolling(window=4)
108110

109111
result = getattr(r, f)()
110112
expected = g.apply(lambda x: getattr(x.rolling(4), f)())
111113
# groupby.apply doesn't drop the grouped-by column
112114
expected = expected.drop("A", axis=1)
113115
# GH 39732
114-
expected_index = MultiIndex.from_arrays([self.frame["A"], range(40)])
116+
expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)])
115117
expected.index = expected_index
116118
tm.assert_frame_equal(result, expected)
117119

118120
@pytest.mark.parametrize("f", ["std", "var"])
119-
def test_rolling_ddof(self, f):
120-
g = self.frame.groupby("A")
121+
def test_rolling_ddof(self, f, roll_frame):
122+
g = roll_frame.groupby("A")
121123
r = g.rolling(window=4)
122124

123125
result = getattr(r, f)(ddof=1)
124126
expected = g.apply(lambda x: getattr(x.rolling(4), f)(ddof=1))
125127
# groupby.apply doesn't drop the grouped-by column
126128
expected = expected.drop("A", axis=1)
127129
# GH 39732
128-
expected_index = MultiIndex.from_arrays([self.frame["A"], range(40)])
130+
expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)])
129131
expected.index = expected_index
130132
tm.assert_frame_equal(result, expected)
131133

132134
@pytest.mark.parametrize(
133135
"interpolation", ["linear", "lower", "higher", "midpoint", "nearest"]
134136
)
135-
def test_rolling_quantile(self, interpolation):
136-
g = self.frame.groupby("A")
137+
def test_rolling_quantile(self, interpolation, roll_frame):
138+
g = roll_frame.groupby("A")
137139
r = g.rolling(window=4)
138140

139141
result = r.quantile(0.4, interpolation=interpolation)
@@ -143,7 +145,7 @@ def test_rolling_quantile(self, interpolation):
143145
# groupby.apply doesn't drop the grouped-by column
144146
expected = expected.drop("A", axis=1)
145147
# GH 39732
146-
expected_index = MultiIndex.from_arrays([self.frame["A"], range(40)])
148+
expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)])
147149
expected.index = expected_index
148150
tm.assert_frame_equal(result, expected)
149151

@@ -173,14 +175,14 @@ def test_rolling_corr_cov_other_same_size_as_groups(self, f, expected_val):
173175
tm.assert_frame_equal(result, expected)
174176

175177
@pytest.mark.parametrize("f", ["corr", "cov"])
176-
def test_rolling_corr_cov_other_diff_size_as_groups(self, f):
177-
g = self.frame.groupby("A")
178+
def test_rolling_corr_cov_other_diff_size_as_groups(self, f, roll_frame):
179+
g = roll_frame.groupby("A")
178180
r = g.rolling(window=4)
179181

180-
result = getattr(r, f)(self.frame)
182+
result = getattr(r, f)(roll_frame)
181183

182184
def func(x):
183-
return getattr(x.rolling(4), f)(self.frame)
185+
return getattr(x.rolling(4), f)(roll_frame)
184186

185187
expected = g.apply(func)
186188
# GH 39591: The grouped column should be all np.nan
@@ -189,8 +191,8 @@ def func(x):
189191
tm.assert_frame_equal(result, expected)
190192

191193
@pytest.mark.parametrize("f", ["corr", "cov"])
192-
def test_rolling_corr_cov_pairwise(self, f):
193-
g = self.frame.groupby("A")
194+
def test_rolling_corr_cov_pairwise(self, f, roll_frame):
195+
g = roll_frame.groupby("A")
194196
r = g.rolling(window=4)
195197

196198
result = getattr(r.B, f)(pairwise=True)
@@ -237,8 +239,8 @@ def test_rolling_corr_cov_unordered(self, func, expected_values):
237239
)
238240
tm.assert_frame_equal(result, expected)
239241

240-
def test_rolling_apply(self, raw):
241-
g = self.frame.groupby("A")
242+
def test_rolling_apply(self, raw, roll_frame):
243+
g = roll_frame.groupby("A")
242244
r = g.rolling(window=4)
243245

244246
# reduction
@@ -247,7 +249,7 @@ def test_rolling_apply(self, raw):
247249
# groupby.apply doesn't drop the grouped-by column
248250
expected = expected.drop("A", axis=1)
249251
# GH 39732
250-
expected_index = MultiIndex.from_arrays([self.frame["A"], range(40)])
252+
expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)])
251253
expected.index = expected_index
252254
tm.assert_frame_equal(result, expected)
253255

@@ -778,9 +780,9 @@ def test_groupby_rolling_resulting_multiindex3(self):
778780
)
779781
tm.assert_index_equal(result.index, expected_index, exact="equiv")
780782

781-
def test_groupby_rolling_object_doesnt_affect_groupby_apply(self):
783+
def test_groupby_rolling_object_doesnt_affect_groupby_apply(self, roll_frame):
782784
# GH 39732
783-
g = self.frame.groupby("A")
785+
g = roll_frame.groupby("A")
784786
expected = g.apply(lambda x: x.rolling(4).sum()).index
785787
_ = g.rolling(window=4)
786788
result = g.apply(lambda x: x.rolling(4).sum()).index

0 commit comments

Comments
 (0)