Skip to content

TST/CLN: Use more shared fixtures #56708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,11 +1116,10 @@ def test_ufunc_compat(self, holder, dtype):
tm.assert_equal(result, expected)

# TODO: add more dtypes
@pytest.mark.parametrize("holder", [Index, Series])
@pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64])
def test_ufunc_coercions(self, holder, dtype):
idx = holder([1, 2, 3, 4, 5], dtype=dtype, name="x")
box = Series if holder is Series else Index
def test_ufunc_coercions(self, index_or_series, dtype):
idx = index_or_series([1, 2, 3, 4, 5], dtype=dtype, name="x")
box = index_or_series

result = np.sqrt(idx)
assert result.dtype == "f8" and isinstance(result, box)
Expand Down
12 changes: 6 additions & 6 deletions pandas/tests/copy_view/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,13 @@ def test_dataframe_from_dict_of_series_with_reindex(dtype):
assert np.shares_memory(arr_before, arr_after)


@pytest.mark.parametrize("cons", [Series, Index])
@pytest.mark.parametrize(
"data, dtype", [([1, 2], None), ([1, 2], "int64"), (["a", "b"], None)]
)
def test_dataframe_from_series_or_index(
using_copy_on_write, warn_copy_on_write, data, dtype, cons
using_copy_on_write, warn_copy_on_write, data, dtype, index_or_series
):
obj = cons(data, dtype=dtype)
obj = index_or_series(data, dtype=dtype)
obj_orig = obj.copy()
df = DataFrame(obj, dtype=dtype)
assert np.shares_memory(get_array(obj), get_array(df, 0))
Expand All @@ -303,9 +302,10 @@ def test_dataframe_from_series_or_index(
tm.assert_equal(obj, obj_orig)


@pytest.mark.parametrize("cons", [Series, Index])
def test_dataframe_from_series_or_index_different_dtype(using_copy_on_write, cons):
obj = cons([1, 2], dtype="int64")
def test_dataframe_from_series_or_index_different_dtype(
using_copy_on_write, index_or_series
):
obj = index_or_series([1, 2], dtype="int64")
df = DataFrame(obj, dtype="int32")
assert not np.shares_memory(get_array(obj), get_array(df, 0))
if using_copy_on_write:
Expand Down
12 changes: 6 additions & 6 deletions pandas/tests/dtypes/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,19 +1701,19 @@ def test_interval_mismatched_subtype(self):
arr = np.array([first, flt_interval], dtype=object)
assert lib.infer_dtype(arr, skipna=False) == "interval"

@pytest.mark.parametrize("klass", [pd.array, Series])
@pytest.mark.parametrize("data", [["a", "b", "c"], ["a", "b", pd.NA]])
def test_string_dtype(self, data, skipna, klass, nullable_string_dtype):
def test_string_dtype(
self, data, skipna, index_or_series_or_array, nullable_string_dtype
):
# StringArray
val = klass(data, dtype=nullable_string_dtype)
val = index_or_series_or_array(data, dtype=nullable_string_dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice on improved coverage

inferred = lib.infer_dtype(val, skipna=skipna)
assert inferred == "string"

@pytest.mark.parametrize("klass", [pd.array, Series])
@pytest.mark.parametrize("data", [[True, False, True], [True, False, pd.NA]])
def test_boolean_dtype(self, data, skipna, klass):
def test_boolean_dtype(self, data, skipna, index_or_series_or_array):
# BooleanArray
val = klass(data, dtype="boolean")
val = index_or_series_or_array(data, dtype="boolean")
inferred = lib.infer_dtype(val, skipna=skipna)
assert inferred == "boolean"

Expand Down
8 changes: 3 additions & 5 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ def test_dataframe_constructor_with_dtype():
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("frame", [True, False])
def test_astype_dispatches(frame):
def test_astype_dispatches(frame_or_series):
# This is a dtype-specific test that ensures Series[decimal].astype
# gets all the way through to ExtensionArray.astype
# Designing a reliable smoke test that works for arbitrary data types
Expand All @@ -312,12 +311,11 @@ def test_astype_dispatches(frame):
ctx = decimal.Context()
ctx.prec = 5

if frame:
data = data.to_frame()
data = frame_or_series(data)

result = data.astype(DecimalDtype(ctx))

if frame:
if frame_or_series is pd.DataFrame:
result = result["a"]

assert result.dtype.context.prec == ctx.prec
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/frame/methods/test_set_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,8 @@ def test_set_index_raise_keys(self, frame_of_index_cols, drop, append):

@pytest.mark.parametrize("append", [True, False])
@pytest.mark.parametrize("drop", [True, False])
@pytest.mark.parametrize("box", [set], ids=["set"])
def test_set_index_raise_on_type(self, frame_of_index_cols, box, drop, append):
def test_set_index_raise_on_type(self, frame_of_index_cols, drop, append):
box = set
df = frame_of_index_cols

msg = 'The parameter "keys" may be a column key, .*'
Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def incorrect_function(values, index):
@pytest.mark.filterwarnings("ignore")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
def test_numba_vs_cython(jit, frame_or_series, nogil, parallel, nopython, as_index):
pytest.importorskip("numba")

def func_numba(values, index):
Expand All @@ -70,7 +69,7 @@ def func_numba(values, index):
)
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
grouped = data.groupby(0, as_index=as_index)
if pandas_obj == "Series":
if frame_or_series is Series:
grouped = grouped[1]

result = grouped.agg(func_numba, engine="numba", engine_kwargs=engine_kwargs)
Expand All @@ -82,8 +81,7 @@ def func_numba(values, index):
@pytest.mark.filterwarnings("ignore")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
def test_cache(jit, pandas_obj, nogil, parallel, nopython):
def test_cache(jit, frame_or_series, nogil, parallel, nopython):
# Test that the functions are cached correctly if we switch functions
pytest.importorskip("numba")

Expand All @@ -104,7 +102,7 @@ def func_2(values, index):
)
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
grouped = data.groupby(0)
if pandas_obj == "Series":
if frame_or_series is Series:
grouped = grouped[1]

result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/groupby/transform/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def incorrect_function(values, index):
@pytest.mark.filterwarnings("ignore")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
def test_numba_vs_cython(jit, frame_or_series, nogil, parallel, nopython, as_index):
pytest.importorskip("numba")

def func(values, index):
Expand All @@ -68,7 +67,7 @@ def func(values, index):
)
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
grouped = data.groupby(0, as_index=as_index)
if pandas_obj == "Series":
if frame_or_series is Series:
grouped = grouped[1]

result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs)
Expand All @@ -80,8 +79,7 @@ def func(values, index):
@pytest.mark.filterwarnings("ignore")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
def test_cache(jit, pandas_obj, nogil, parallel, nopython):
def test_cache(jit, frame_or_series, nogil, parallel, nopython):
# Test that the functions are cached correctly if we switch functions
pytest.importorskip("numba")

Expand All @@ -102,7 +100,7 @@ def func_2(values, index):
)
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
grouped = data.groupby(0)
if pandas_obj == "Series":
if frame_or_series is Series:
grouped = grouped[1]

result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs)
Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/indexing/test_iloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ def test_iloc_setitem_fullcol_categorical(self, indexer, key):
expected = DataFrame({0: Series(cat.astype(object), dtype=object), 1: range(3)})
tm.assert_frame_equal(df, expected)

@pytest.mark.parametrize("box", [array, Series])
def test_iloc_setitem_ea_inplace(self, frame_or_series, box, using_copy_on_write):
def test_iloc_setitem_ea_inplace(
self, frame_or_series, index_or_series_or_array, using_copy_on_write
):
# GH#38952 Case with not setting a full column
# IntegerArray without NAs
arr = array([1, 2, 3, 4])
Expand All @@ -119,9 +120,9 @@ def test_iloc_setitem_ea_inplace(self, frame_or_series, box, using_copy_on_write
values = obj._mgr.arrays[0]

if frame_or_series is Series:
obj.iloc[:2] = box(arr[2:])
obj.iloc[:2] = index_or_series_or_array(arr[2:])
else:
obj.iloc[:2, 0] = box(arr[2:])
obj.iloc[:2, 0] = index_or_series_or_array(arr[2:])

expected = frame_or_series(np.array([3, 4, 3, 4], dtype="i8"))
tm.assert_equal(obj, expected)
Expand Down
17 changes: 7 additions & 10 deletions pandas/tests/reductions/test_stat_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,27 @@


class TestDatetimeLikeStatReductions:
@pytest.mark.parametrize("box", [Series, pd.Index, pd.array])
def test_dt64_mean(self, tz_naive_fixture, box):
def test_dt64_mean(self, tz_naive_fixture, index_or_series_or_array):
tz = tz_naive_fixture

dti = date_range("2001-01-01", periods=11, tz=tz)
# shuffle so that we are not just working with monotonically increasing values
dti = dti.take([4, 1, 3, 10, 9, 7, 8, 5, 0, 2, 6])
dtarr = dti._data

obj = box(dtarr)
obj = index_or_series_or_array(dtarr)
assert obj.mean() == pd.Timestamp("2001-01-06", tz=tz)
assert obj.mean(skipna=False) == pd.Timestamp("2001-01-06", tz=tz)

# dtarr[-2] is 2001-01-03 (position 2 of the unshuffled range)
dtarr[-2] = pd.NaT

obj = box(dtarr)
obj = index_or_series_or_array(dtarr)
assert obj.mean() == pd.Timestamp("2001-01-06 07:12:00", tz=tz)
assert obj.mean(skipna=False) is pd.NaT

@pytest.mark.parametrize("box", [Series, pd.Index, pd.array])
@pytest.mark.parametrize("freq", ["s", "h", "D", "W", "B"])
def test_period_mean(self, box, freq):
def test_period_mean(self, index_or_series_or_array, freq):
# GH#24757
dti = date_range("2001-01-01", periods=11)
# shuffle so that we are not just working with monotonically increasing values
Expand All @@ -48,7 +46,7 @@ def test_period_mean(self, box, freq):
msg = r"PeriodDtype\[B\] is deprecated"
with tm.assert_produces_warning(warn, match=msg):
parr = dti._data.to_period(freq)
obj = box(parr)
obj = index_or_series_or_array(parr)
with pytest.raises(TypeError, match="ambiguous"):
obj.mean()
with pytest.raises(TypeError, match="ambiguous"):
Expand All @@ -62,13 +60,12 @@ def test_period_mean(self, box, freq):
with pytest.raises(TypeError, match="ambiguous"):
obj.mean(skipna=True)

@pytest.mark.parametrize("box", [Series, pd.Index, pd.array])
def test_td64_mean(self, box):
def test_td64_mean(self, index_or_series_or_array):
m8values = np.array([0, 3, -2, -7, 1, 2, -1, 3, 5, -2, 4], "m8[D]")
tdi = pd.TimedeltaIndex(m8values).as_unit("ns")

tdarr = tdi._data
obj = box(tdarr, copy=False)
obj = index_or_series_or_array(tdarr, copy=False)

result = obj.mean()
expected = np.array(tdarr).mean()
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/resample/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"),
],
)
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_asfreq(klass, index, freq):
obj = klass(range(len(index)), index=index)
def test_asfreq(frame_or_series, index, freq):
obj = frame_or_series(range(len(index)), index=index)
idx_range = date_range if isinstance(index, DatetimeIndex) else timedelta_range

result = obj.resample(freq).asfreq()
Expand Down
20 changes: 8 additions & 12 deletions pandas/tests/resample/test_period_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ def _simple_period_range_series(start, end, freq="D"):
class TestPeriodIndex:
@pytest.mark.parametrize("freq", ["2D", "1h", "2h"])
@pytest.mark.parametrize("kind", ["period", None, "timestamp"])
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_asfreq(self, klass, freq, kind):
def test_asfreq(self, frame_or_series, freq, kind):
# GH 12884, 15944
# make sure .asfreq() returns PeriodIndex (except kind='timestamp')

obj = klass(range(5), index=period_range("2020-01-01", periods=5))
obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5))
if kind == "timestamp":
expected = obj.to_timestamp().resample(freq).asfreq()
else:
Expand Down Expand Up @@ -1007,12 +1006,11 @@ def test_resample_t_l_deprecated(self):
offsets.BusinessHour(2),
],
)
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_asfreq_invalid_period_freq(self, offset, klass):
def test_asfreq_invalid_period_freq(self, offset, frame_or_series):
# GH#9586
msg = f"Invalid offset: '{offset.base}' for converting time series "

obj = klass(range(5), index=period_range("2020-01-01", periods=5))
obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5))
with pytest.raises(ValueError, match=msg):
obj.asfreq(freq=offset)

Expand All @@ -1027,12 +1025,11 @@ def test_asfreq_invalid_period_freq(self, offset, klass):
("2Y-MAR", "2YE-MAR"),
],
)
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_resample_frequency_ME_QE_YE_error_message(klass, freq, freq_depr):
def test_resample_frequency_ME_QE_YE_error_message(frame_or_series, freq, freq_depr):
# GH#9586
msg = f"for Period, please use '{freq[1:]}' instead of '{freq_depr[1:]}'"

obj = klass(range(5), index=period_range("2020-01-01", periods=5))
obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5))
with pytest.raises(ValueError, match=msg):
obj.resample(freq_depr)

Expand All @@ -1057,11 +1054,10 @@ def test_corner_cases_period(simple_period_range_series):
"2BYE-MAR",
],
)
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_resample_frequency_invalid_freq(klass, freq_depr):
def test_resample_frequency_invalid_freq(frame_or_series, freq_depr):
# GH#9586
msg = f"Invalid frequency: {freq_depr[1:]}"

obj = klass(range(5), index=period_range("2020-01-01", periods=5))
obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5))
with pytest.raises(ValueError, match=msg):
obj.resample(freq_depr)
11 changes: 5 additions & 6 deletions pandas/tests/reshape/concat/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,13 @@ def test_concat_no_unnecessary_upcast(float_numpy_dtype, frame_or_series):
assert x.values.dtype == dt


@pytest.mark.parametrize("pdt", [Series, DataFrame])
def test_concat_will_upcast(pdt, any_signed_int_numpy_dtype):
def test_concat_will_upcast(frame_or_series, any_signed_int_numpy_dtype):
dt = any_signed_int_numpy_dtype
dims = pdt().ndim
dims = frame_or_series().ndim
dfs = [
pdt(np.array([1], dtype=dt, ndmin=dims)),
pdt(np.array([np.nan], ndmin=dims)),
pdt(np.array([5], dtype=dt, ndmin=dims)),
frame_or_series(np.array([1], dtype=dt, ndmin=dims)),
frame_or_series(np.array([np.nan], ndmin=dims)),
frame_or_series(np.array([5], dtype=dt, ndmin=dims)),
]
x = concat(dfs)
assert x.values.dtype == "float64"
Expand Down
7 changes: 2 additions & 5 deletions pandas/tests/series/methods/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import pytest

from pandas import (
Index,
Series,
array,
date_range,
)
import pandas._testing as tm
Expand Down Expand Up @@ -47,11 +45,10 @@ def test_view_tz(self):
@pytest.mark.parametrize(
"second", ["m8[ns]", "M8[ns]", "M8[ns, US/Central]", "period[D]"]
)
@pytest.mark.parametrize("box", [Series, Index, array])
def test_view_between_datetimelike(self, first, second, box):
def test_view_between_datetimelike(self, first, second, index_or_series_or_array):
dti = date_range("2016-01-01", periods=3)

orig = box(dti)
orig = index_or_series_or_array(dti)
obj = orig.view(first)
assert obj.dtype == first
tm.assert_numpy_array_equal(np.asarray(obj.view("i8")), dti.asi8)
Expand Down
Loading