Skip to content

Commit e25aa9d

Browse files
authored
TST/CLN: Use more frame_or_series fixture (pandas-dev#48926)
* TST/CLN: Use more frame_or_series fixture * Revert for base ext tests
1 parent ff9a1dc commit e25aa9d

File tree

12 files changed

+50
-57
lines changed

12 files changed

+50
-57
lines changed

pandas/tests/apply/test_invalid_arg.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,11 @@ def test_map_datetimetz_na_action():
9090
s.map(lambda x: x, na_action="ignore")
9191

9292

93-
@pytest.mark.parametrize("box", [DataFrame, Series])
9493
@pytest.mark.parametrize("method", ["apply", "agg", "transform"])
9594
@pytest.mark.parametrize("func", [{"A": {"B": "sum"}}, {"A": {"B": ["sum"]}}])
96-
def test_nested_renamer(box, method, func):
95+
def test_nested_renamer(frame_or_series, method, func):
9796
# GH 35964
98-
obj = box({"A": [1]})
97+
obj = frame_or_series({"A": [1]})
9998
match = "nested renamer is not supported"
10099
with pytest.raises(SpecificationError, match=match):
101100
getattr(obj, method)(func)

pandas/tests/extension/base/ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,13 @@ def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
179179
result = data.__eq__(other)
180180
assert result is NotImplemented
181181
else:
182-
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")
182+
pytest.skip(f"{type(data).__name__} does not implement __eq__")
183183

184184
if hasattr(data, "__ne__"):
185185
result = data.__ne__(other)
186186
assert result is NotImplemented
187187
else:
188-
raise pytest.skip(f"{type(data).__name__} does not implement __ne__")
188+
pytest.skip(f"{type(data).__name__} does not implement __ne__")
189189

190190

191191
class BaseUnaryOpsTests(BaseOpsUtil):

pandas/tests/extension/test_period.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,12 @@ def test_add_series_with_extension_array(self, data):
142142
with pytest.raises(TypeError, match=msg):
143143
s + data
144144

145-
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
146-
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
145+
def test_direct_arith_with_ndframe_returns_not_implemented(
146+
self, data, frame_or_series
147+
):
147148
# Override to use __sub__ instead of __add__
148149
other = pd.Series(data)
149-
if box is pd.DataFrame:
150+
if frame_or_series is pd.DataFrame:
150151
other = other.to_frame()
151152

152153
result = data.__sub__(other)

pandas/tests/frame/indexing/test_xs.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,7 @@ def test_xs_loc_equality(self, multiindex_dataframe_random_data):
324324
expected = df.loc[("bar", "two")]
325325
tm.assert_series_equal(result, expected)
326326

327-
@pytest.mark.parametrize("klass", [DataFrame, Series])
328-
def test_xs_IndexSlice_argument_not_implemented(self, klass):
327+
def test_xs_IndexSlice_argument_not_implemented(self, frame_or_series):
329328
# GH#35301
330329

331330
index = MultiIndex(
@@ -334,7 +333,7 @@ def test_xs_IndexSlice_argument_not_implemented(self, klass):
334333
)
335334

336335
obj = DataFrame(np.random.randn(6, 4), index=index)
337-
if klass is Series:
336+
if frame_or_series is Series:
338337
obj = obj[0]
339338

340339
expected = obj.iloc[-2:].droplevel(0)
@@ -345,10 +344,9 @@ def test_xs_IndexSlice_argument_not_implemented(self, klass):
345344
result = obj.loc[IndexSlice[("foo", "qux", 0), :]]
346345
tm.assert_equal(result, expected)
347346

348-
@pytest.mark.parametrize("klass", [DataFrame, Series])
349-
def test_xs_levels_raises(self, klass):
347+
def test_xs_levels_raises(self, frame_or_series):
350348
obj = DataFrame({"A": [1, 2, 3]})
351-
if klass is Series:
349+
if frame_or_series is Series:
352350
obj = obj["A"]
353351

354352
msg = "Index must be a MultiIndex"

pandas/tests/frame/methods/test_drop.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -422,17 +422,16 @@ def test_drop_level_nonunique_datetime(self):
422422
expected = df.loc[idx != 4]
423423
tm.assert_frame_equal(result, expected)
424424

425-
@pytest.mark.parametrize("box", [Series, DataFrame])
426-
def test_drop_tz_aware_timestamp_across_dst(self, box):
425+
def test_drop_tz_aware_timestamp_across_dst(self, frame_or_series):
427426
# GH#21761
428427
start = Timestamp("2017-10-29", tz="Europe/Berlin")
429428
end = Timestamp("2017-10-29 04:00:00", tz="Europe/Berlin")
430429
index = pd.date_range(start, end, freq="15min")
431-
data = box(data=[1] * len(index), index=index)
430+
data = frame_or_series(data=[1] * len(index), index=index)
432431
result = data.drop(start)
433432
expected_start = Timestamp("2017-10-29 00:15:00", tz="Europe/Berlin")
434433
expected_idx = pd.date_range(expected_start, end, freq="15min")
435-
expected = box(data=[1] * len(expected_idx), index=expected_idx)
434+
expected = frame_or_series(data=[1] * len(expected_idx), index=expected_idx)
436435
tm.assert_equal(result, expected)
437436

438437
def test_drop_preserve_names(self):

pandas/tests/frame/methods/test_pct_change.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ class TestDataFramePctChange:
2222
(-1, "bfill", 1, [np.nan, 0, -0.5, -0.5, -0.6, np.nan, np.nan, np.nan]),
2323
],
2424
)
25-
@pytest.mark.parametrize("klass", [DataFrame, Series])
26-
def test_pct_change_with_nas(self, periods, fill_method, limit, exp, klass):
25+
def test_pct_change_with_nas(
26+
self, periods, fill_method, limit, exp, frame_or_series
27+
):
2728
vals = [np.nan, np.nan, 1, 2, 4, 10, np.nan, np.nan]
28-
obj = klass(vals)
29+
obj = frame_or_series(vals)
2930

3031
res = obj.pct_change(periods=periods, fill_method=fill_method, limit=limit)
31-
tm.assert_equal(res, klass(exp))
32+
tm.assert_equal(res, frame_or_series(exp))
3233

3334
def test_pct_change_numeric(self):
3435
# GH#11150

pandas/tests/frame/methods/test_rename.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
DataFrame,
1111
Index,
1212
MultiIndex,
13-
Series,
1413
merge,
1514
)
1615
import pandas._testing as tm
@@ -32,9 +31,8 @@ def test_rename_signature(self):
3231
"errors",
3332
}
3433

35-
@pytest.mark.parametrize("klass", [Series, DataFrame])
36-
def test_rename_mi(self, klass):
37-
obj = klass(
34+
def test_rename_mi(self, frame_or_series):
35+
obj = frame_or_series(
3836
[11, 21, 31],
3937
index=MultiIndex.from_tuples([("A", x) for x in ["a", "B", "c"]]),
4038
)

pandas/tests/frame/methods/test_sample.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111

1212

1313
class TestSample:
14-
@pytest.fixture(params=[Series, DataFrame])
15-
def obj(self, request):
16-
klass = request.param
17-
if klass is Series:
14+
@pytest.fixture
15+
def obj(self, frame_or_series):
16+
if frame_or_series is Series:
1817
arr = np.random.randn(10)
1918
else:
2019
arr = np.random.randn(10, 10)
21-
return klass(arr, dtype=None)
20+
return frame_or_series(arr, dtype=None)
2221

2322
@pytest.mark.parametrize("test", list(range(10)))
2423
def test_sample(self, test, obj):

pandas/tests/io/formats/test_to_csv.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -385,10 +385,9 @@ def test_to_csv_multi_index(self):
385385
),
386386
],
387387
)
388-
@pytest.mark.parametrize("klass", [DataFrame, pd.Series])
389-
def test_to_csv_single_level_multi_index(self, ind, expected, klass):
388+
def test_to_csv_single_level_multi_index(self, ind, expected, frame_or_series):
390389
# see gh-19589
391-
obj = klass(pd.Series([1], ind, name="data"))
390+
obj = frame_or_series(pd.Series([1], ind, name="data"))
392391

393392
with tm.assert_produces_warning(FutureWarning, match="lineterminator"):
394393
# GH#9568 standardize on lineterminator matching stdlib

pandas/tests/resample/conftest.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,13 @@ def empty_frame_dti(series):
159159
return DataFrame(index=index)
160160

161161

162-
@pytest.fixture(params=[Series, DataFrame])
163-
def series_and_frame(request, series, frame):
162+
@pytest.fixture
163+
def series_and_frame(frame_or_series, series, frame):
164164
"""
165165
Fixture for parametrization of Series and DataFrame with date_range,
166166
period_range and timedelta_range indexes
167167
"""
168-
if request.param == Series:
168+
if frame_or_series == Series:
169169
return series
170-
if request.param == DataFrame:
170+
if frame_or_series == DataFrame:
171171
return frame

pandas/tests/reshape/concat/test_concat.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -505,16 +505,15 @@ def test_concat_duplicate_indices_raise(self):
505505
concat([df1, df2], axis=1)
506506

507507

508-
@pytest.mark.parametrize("pdt", [Series, DataFrame])
509508
@pytest.mark.parametrize("dt", np.sctypes["float"])
510-
def test_concat_no_unnecessary_upcast(dt, pdt):
509+
def test_concat_no_unnecessary_upcast(dt, frame_or_series):
511510
# GH 13247
512-
dims = pdt(dtype=object).ndim
511+
dims = frame_or_series(dtype=object).ndim
513512

514513
dfs = [
515-
pdt(np.array([1], dtype=dt, ndmin=dims)),
516-
pdt(np.array([np.nan], dtype=dt, ndmin=dims)),
517-
pdt(np.array([5], dtype=dt, ndmin=dims)),
514+
frame_or_series(np.array([1], dtype=dt, ndmin=dims)),
515+
frame_or_series(np.array([np.nan], dtype=dt, ndmin=dims)),
516+
frame_or_series(np.array([5], dtype=dt, ndmin=dims)),
518517
]
519518
x = concat(dfs)
520519
assert x.values.dtype == dt

pandas/tests/window/test_base_indexer.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def get_window_bounds(self, num_values, min_periods, center, closed, step):
9494
tm.assert_frame_equal(result, expected)
9595

9696

97-
@pytest.mark.parametrize("constructor", [Series, DataFrame])
9897
@pytest.mark.parametrize(
9998
"func,np_func,expected,np_kwargs",
10099
[
@@ -149,7 +148,9 @@ def get_window_bounds(self, num_values, min_periods, center, closed, step):
149148
],
150149
)
151150
@pytest.mark.filterwarnings("ignore:min_periods:FutureWarning")
152-
def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs, step):
151+
def test_rolling_forward_window(
152+
frame_or_series, func, np_func, expected, np_kwargs, step
153+
):
153154
# GH 32865
154155
values = np.arange(10.0)
155156
values[5] = 100.0
@@ -158,47 +159,46 @@ def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs,
158159

159160
match = "Forward-looking windows can't have center=True"
160161
with pytest.raises(ValueError, match=match):
161-
rolling = constructor(values).rolling(window=indexer, center=True)
162+
rolling = frame_or_series(values).rolling(window=indexer, center=True)
162163
getattr(rolling, func)()
163164

164165
match = "Forward-looking windows don't support setting the closed argument"
165166
with pytest.raises(ValueError, match=match):
166-
rolling = constructor(values).rolling(window=indexer, closed="right")
167+
rolling = frame_or_series(values).rolling(window=indexer, closed="right")
167168
getattr(rolling, func)()
168169

169-
rolling = constructor(values).rolling(window=indexer, min_periods=2, step=step)
170+
rolling = frame_or_series(values).rolling(window=indexer, min_periods=2, step=step)
170171
result = getattr(rolling, func)()
171172

172173
# Check that the function output matches the explicitly provided array
173-
expected = constructor(expected)[::step]
174+
expected = frame_or_series(expected)[::step]
174175
tm.assert_equal(result, expected)
175176

176177
# Check that the rolling function output matches applying an alternative
177178
# function to the rolling window object
178-
expected2 = constructor(rolling.apply(lambda x: np_func(x, **np_kwargs)))
179+
expected2 = frame_or_series(rolling.apply(lambda x: np_func(x, **np_kwargs)))
179180
tm.assert_equal(result, expected2)
180181

181182
# Check that the function output matches applying an alternative function
182183
# if min_periods isn't specified
183184
# GH 39604: After count-min_periods deprecation, apply(lambda x: len(x))
184185
# is equivalent to count after setting min_periods=0
185186
min_periods = 0 if func == "count" else None
186-
rolling3 = constructor(values).rolling(window=indexer, min_periods=min_periods)
187+
rolling3 = frame_or_series(values).rolling(window=indexer, min_periods=min_periods)
187188
result3 = getattr(rolling3, func)()
188-
expected3 = constructor(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
189+
expected3 = frame_or_series(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
189190
tm.assert_equal(result3, expected3)
190191

191192

192-
@pytest.mark.parametrize("constructor", [Series, DataFrame])
193-
def test_rolling_forward_skewness(constructor, step):
193+
def test_rolling_forward_skewness(frame_or_series, step):
194194
values = np.arange(10.0)
195195
values[5] = 100.0
196196

197197
indexer = FixedForwardWindowIndexer(window_size=5)
198-
rolling = constructor(values).rolling(window=indexer, min_periods=3, step=step)
198+
rolling = frame_or_series(values).rolling(window=indexer, min_periods=3, step=step)
199199
result = rolling.skew()
200200

201-
expected = constructor(
201+
expected = frame_or_series(
202202
[
203203
0.0,
204204
2.232396,

0 commit comments

Comments
 (0)