Skip to content

TST/CLN: Use more frame_or_series fixture #48926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pandas/tests/apply/test_invalid_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,11 @@ def test_map_datetimetz_na_action():
s.map(lambda x: x, na_action="ignore")


@pytest.mark.parametrize("box", [DataFrame, Series])
@pytest.mark.parametrize("method", ["apply", "agg", "transform"])
@pytest.mark.parametrize("func", [{"A": {"B": "sum"}}, {"A": {"B": ["sum"]}}])
def test_nested_renamer(box, method, func):
def test_nested_renamer(frame_or_series, method, func):
# GH 35964
obj = box({"A": [1]})
obj = frame_or_series({"A": [1]})
match = "nested renamer is not supported"
with pytest.raises(SpecificationError, match=match):
getattr(obj, method)(func)
Expand Down
16 changes: 8 additions & 8 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,13 @@ def test_add_series_with_extension_array(self, data):
expected = pd.Series(data + data)
self.assert_series_equal(result, expected)

@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im on board with this, but in the past @jorisvandenbossche has objected bc it would require library authors to re-implement the fixture

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good point. I didn't realize this was a base extension test so I'll exclude this one.

def test_direct_arith_with_ndframe_returns_not_implemented(
self, request, data, box
self, request, data, frame_or_series
):
# EAs should return NotImplemented for ops with Series/DataFrame
# Pandas takes care of unboxing the series and calling the EA's op.
other = pd.Series(data)
if box is pd.DataFrame:
if frame_or_series is pd.DataFrame:
other = other.to_frame()
if not hasattr(data, "__add__"):
request.node.add_marker(
Expand Down Expand Up @@ -167,25 +166,26 @@ def test_compare_array(self, data, comparison_op):
other = pd.Series([data[0]] * len(data))
self._compare_other(ser, data, comparison_op, other)

@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
def test_direct_arith_with_ndframe_returns_not_implemented(
self, data, frame_or_series
):
# EAs should return NotImplemented for ops with Series/DataFrame
# Pandas takes care of unboxing the series and calling the EA's op.
other = pd.Series(data)
if box is pd.DataFrame:
if frame_or_series is pd.DataFrame:
other = other.to_frame()

if hasattr(data, "__eq__"):
result = data.__eq__(other)
assert result is NotImplemented
else:
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")
pytest.skip(f"{type(data).__name__} does not implement __eq__")

if hasattr(data, "__ne__"):
result = data.__ne__(other)
assert result is NotImplemented
else:
raise pytest.skip(f"{type(data).__name__} does not implement __ne__")
pytest.skip(f"{type(data).__name__} does not implement __ne__")


class BaseUnaryOpsTests(BaseOpsUtil):
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,12 @@ def test_add_series_with_extension_array(self, data):
with pytest.raises(TypeError, match=msg):
s + data

@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
def test_direct_arith_with_ndframe_returns_not_implemented(
self, data, frame_or_series
):
# Override to use __sub__ instead of __add__
other = pd.Series(data)
if box is pd.DataFrame:
if frame_or_series is pd.DataFrame:
other = other.to_frame()

result = data.__sub__(other)
Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/frame/indexing/test_xs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,7 @@ def test_xs_loc_equality(self, multiindex_dataframe_random_data):
expected = df.loc[("bar", "two")]
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_xs_IndexSlice_argument_not_implemented(self, klass):
def test_xs_IndexSlice_argument_not_implemented(self, frame_or_series):
# GH#35301

index = MultiIndex(
Expand All @@ -334,7 +333,7 @@ def test_xs_IndexSlice_argument_not_implemented(self, klass):
)

obj = DataFrame(np.random.randn(6, 4), index=index)
if klass is Series:
if frame_or_series is Series:
obj = obj[0]

expected = obj.iloc[-2:].droplevel(0)
Expand All @@ -345,10 +344,9 @@ def test_xs_IndexSlice_argument_not_implemented(self, klass):
result = obj.loc[IndexSlice[("foo", "qux", 0), :]]
tm.assert_equal(result, expected)

@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_xs_levels_raises(self, klass):
def test_xs_levels_raises(self, frame_or_series):
obj = DataFrame({"A": [1, 2, 3]})
if klass is Series:
if frame_or_series is Series:
obj = obj["A"]

msg = "Index must be a MultiIndex"
Expand Down
7 changes: 3 additions & 4 deletions pandas/tests/frame/methods/test_drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,16 @@ def test_drop_level_nonunique_datetime(self):
expected = df.loc[idx != 4]
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("box", [Series, DataFrame])
def test_drop_tz_aware_timestamp_across_dst(self, box):
def test_drop_tz_aware_timestamp_across_dst(self, frame_or_series):
# GH#21761
start = Timestamp("2017-10-29", tz="Europe/Berlin")
end = Timestamp("2017-10-29 04:00:00", tz="Europe/Berlin")
index = pd.date_range(start, end, freq="15min")
data = box(data=[1] * len(index), index=index)
data = frame_or_series(data=[1] * len(index), index=index)
result = data.drop(start)
expected_start = Timestamp("2017-10-29 00:15:00", tz="Europe/Berlin")
expected_idx = pd.date_range(expected_start, end, freq="15min")
expected = box(data=[1] * len(expected_idx), index=expected_idx)
expected = frame_or_series(data=[1] * len(expected_idx), index=expected_idx)
tm.assert_equal(result, expected)

def test_drop_preserve_names(self):
Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/frame/methods/test_pct_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ class TestDataFramePctChange:
(-1, "bfill", 1, [np.nan, 0, -0.5, -0.5, -0.6, np.nan, np.nan, np.nan]),
],
)
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_pct_change_with_nas(self, periods, fill_method, limit, exp, klass):
def test_pct_change_with_nas(
self, periods, fill_method, limit, exp, frame_or_series
):
vals = [np.nan, np.nan, 1, 2, 4, 10, np.nan, np.nan]
obj = klass(vals)
obj = frame_or_series(vals)

res = obj.pct_change(periods=periods, fill_method=fill_method, limit=limit)
tm.assert_equal(res, klass(exp))
tm.assert_equal(res, frame_or_series(exp))

def test_pct_change_numeric(self):
# GH#11150
Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/frame/methods/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
DataFrame,
Index,
MultiIndex,
Series,
merge,
)
import pandas._testing as tm
Expand All @@ -32,9 +31,8 @@ def test_rename_signature(self):
"errors",
}

@pytest.mark.parametrize("klass", [Series, DataFrame])
def test_rename_mi(self, klass):
obj = klass(
def test_rename_mi(self, frame_or_series):
obj = frame_or_series(
[11, 21, 31],
index=MultiIndex.from_tuples([("A", x) for x in ["a", "B", "c"]]),
)
Expand Down
9 changes: 4 additions & 5 deletions pandas/tests/frame/methods/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@


class TestSample:
@pytest.fixture(params=[Series, DataFrame])
def obj(self, request):
klass = request.param
if klass is Series:
@pytest.fixture
def obj(self, frame_or_series):
if frame_or_series is Series:
arr = np.random.randn(10)
else:
arr = np.random.randn(10, 10)
return klass(arr, dtype=None)
return frame_or_series(arr, dtype=None)

@pytest.mark.parametrize("test", list(range(10)))
def test_sample(self, test, obj):
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/io/formats/test_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,9 @@ def test_to_csv_multi_index(self):
),
],
)
@pytest.mark.parametrize("klass", [DataFrame, pd.Series])
def test_to_csv_single_level_multi_index(self, ind, expected, klass):
def test_to_csv_single_level_multi_index(self, ind, expected, frame_or_series):
# see gh-19589
obj = klass(pd.Series([1], ind, name="data"))
obj = frame_or_series(pd.Series([1], ind, name="data"))

with tm.assert_produces_warning(FutureWarning, match="lineterminator"):
# GH#9568 standardize on lineterminator matching stdlib
Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/resample/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def empty_frame_dti(series):
return DataFrame(index=index)


@pytest.fixture(params=[Series, DataFrame])
def series_and_frame(request, series, frame):
@pytest.fixture
def series_and_frame(frame_or_series, series, frame):
"""
Fixture for parametrization of Series and DataFrame with date_range,
period_range and timedelta_range indexes
"""
if request.param == Series:
if frame_or_series == Series:
return series
if request.param == DataFrame:
if frame_or_series == DataFrame:
return frame
11 changes: 5 additions & 6 deletions pandas/tests/reshape/concat/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,15 @@ def test_concat_duplicate_indices_raise(self):
concat([df1, df2], axis=1)


@pytest.mark.parametrize("pdt", [Series, DataFrame])
@pytest.mark.parametrize("dt", np.sctypes["float"])
def test_concat_no_unnecessary_upcast(dt, pdt):
def test_concat_no_unnecessary_upcast(dt, frame_or_series):
# GH 13247
dims = pdt(dtype=object).ndim
dims = frame_or_series(dtype=object).ndim

dfs = [
pdt(np.array([1], dtype=dt, ndmin=dims)),
pdt(np.array([np.nan], dtype=dt, ndmin=dims)),
pdt(np.array([5], dtype=dt, ndmin=dims)),
frame_or_series(np.array([1], dtype=dt, ndmin=dims)),
frame_or_series(np.array([np.nan], dtype=dt, ndmin=dims)),
frame_or_series(np.array([5], dtype=dt, ndmin=dims)),
]
x = concat(dfs)
assert x.values.dtype == dt
Expand Down
26 changes: 13 additions & 13 deletions pandas/tests/window/test_base_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def get_window_bounds(self, num_values, min_periods, center, closed, step):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("constructor", [Series, DataFrame])
@pytest.mark.parametrize(
"func,np_func,expected,np_kwargs",
[
Expand Down Expand Up @@ -149,7 +148,9 @@ def get_window_bounds(self, num_values, min_periods, center, closed, step):
],
)
@pytest.mark.filterwarnings("ignore:min_periods:FutureWarning")
def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs, step):
def test_rolling_forward_window(
frame_or_series, func, np_func, expected, np_kwargs, step
):
# GH 32865
values = np.arange(10.0)
values[5] = 100.0
Expand All @@ -158,47 +159,46 @@ def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs,

match = "Forward-looking windows can't have center=True"
with pytest.raises(ValueError, match=match):
rolling = constructor(values).rolling(window=indexer, center=True)
rolling = frame_or_series(values).rolling(window=indexer, center=True)
getattr(rolling, func)()

match = "Forward-looking windows don't support setting the closed argument"
with pytest.raises(ValueError, match=match):
rolling = constructor(values).rolling(window=indexer, closed="right")
rolling = frame_or_series(values).rolling(window=indexer, closed="right")
getattr(rolling, func)()

rolling = constructor(values).rolling(window=indexer, min_periods=2, step=step)
rolling = frame_or_series(values).rolling(window=indexer, min_periods=2, step=step)
result = getattr(rolling, func)()

# Check that the function output matches the explicitly provided array
expected = constructor(expected)[::step]
expected = frame_or_series(expected)[::step]
tm.assert_equal(result, expected)

# Check that the rolling function output matches applying an alternative
# function to the rolling window object
expected2 = constructor(rolling.apply(lambda x: np_func(x, **np_kwargs)))
expected2 = frame_or_series(rolling.apply(lambda x: np_func(x, **np_kwargs)))
tm.assert_equal(result, expected2)

# Check that the function output matches applying an alternative function
# if min_periods isn't specified
# GH 39604: After count-min_periods deprecation, apply(lambda x: len(x))
# is equivalent to count after setting min_periods=0
min_periods = 0 if func == "count" else None
rolling3 = constructor(values).rolling(window=indexer, min_periods=min_periods)
rolling3 = frame_or_series(values).rolling(window=indexer, min_periods=min_periods)
result3 = getattr(rolling3, func)()
expected3 = constructor(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
expected3 = frame_or_series(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
tm.assert_equal(result3, expected3)


@pytest.mark.parametrize("constructor", [Series, DataFrame])
def test_rolling_forward_skewness(constructor, step):
def test_rolling_forward_skewness(frame_or_series, step):
values = np.arange(10.0)
values[5] = 100.0

indexer = FixedForwardWindowIndexer(window_size=5)
rolling = constructor(values).rolling(window=indexer, min_periods=3, step=step)
rolling = frame_or_series(values).rolling(window=indexer, min_periods=3, step=step)
result = rolling.skew()

expected = constructor(
expected = frame_or_series(
[
0.0,
2.232396,
Expand Down