Skip to content

TST: clean ups in extension tests #54719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions pandas/tests/extension/base/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
alt = ser.astype(object)

result = getattr(ser, op_name)(skipna=skipna)

if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna:
# TODO: avoid special-casing here
pytest.skip(
f"Float32 precision lead to large differences with op {op_name} "
f"and skipna={skipna}"
)

expected = getattr(alt, op_name)(skipna=skipna)
tm.assert_series_equal(result, expected, check_dtype=False)

Expand Down
6 changes: 4 additions & 2 deletions pandas/tests/extension/base/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_series_constructor(self, data):
if hasattr(result._mgr, "blocks"):
assert isinstance(result2._mgr.blocks[0], EABackedBlock)

def test_series_constructor_no_data_with_index(self, dtype, na_value):
def test_series_constructor_no_data_with_index(self, dtype):
na_value = dtype.na_value
result = pd.Series(index=[1, 2, 3], dtype=dtype)
expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
tm.assert_series_equal(result, expected)
Expand All @@ -45,7 +46,8 @@ def test_series_constructor_no_data_with_index(self, dtype, na_value):
expected = pd.Series([], index=pd.Index([], dtype="object"), dtype=dtype)
tm.assert_series_equal(result, expected)

def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
def test_series_constructor_scalar_na_with_index(self, dtype):
na_value = dtype.na_value
result = pd.Series(na_value, index=[1, 2, 3], dtype=dtype)
expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
tm.assert_series_equal(result, expected)
Expand Down
15 changes: 10 additions & 5 deletions pandas/tests/extension/base/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def test_getitem_invalid(self, data):
with pytest.raises(IndexError, match=msg):
data[-ub - 1]

def test_getitem_scalar_na(self, data_missing, na_cmp, na_value):
def test_getitem_scalar_na(self, data_missing, na_cmp):
na_value = data_missing.dtype.na_value
result = data_missing[0]
assert na_cmp(result, na_value)

Expand Down Expand Up @@ -348,7 +349,8 @@ def test_take_sequence(self, data):
assert result.iloc[1] == data[1]
assert result.iloc[2] == data[3]

def test_take(self, data, na_value, na_cmp):
def test_take(self, data, na_cmp):
na_value = data.dtype.na_value
result = data.take([0, -1])
assert result.dtype == data.dtype
assert result[0] == data[0]
Expand All @@ -361,7 +363,8 @@ def test_take(self, data, na_value, na_cmp):
with pytest.raises(IndexError, match="out of bounds"):
data.take([len(data) + 1])

def test_take_empty(self, data, na_value, na_cmp):
def test_take_empty(self, data, na_cmp):
na_value = data.dtype.na_value
empty = data[:0]

result = empty.take([-1], allow_fill=True)
Expand Down Expand Up @@ -393,7 +396,8 @@ def test_take_non_na_fill_value(self, data_missing):
expected = arr.take([1, 1])
tm.assert_extension_array_equal(result, expected)

def test_take_pandas_style_negative_raises(self, data, na_value):
def test_take_pandas_style_negative_raises(self, data):
na_value = data.dtype.na_value
with pytest.raises(ValueError, match=""):
data.take([0, -2], fill_value=na_value, allow_fill=True)

Expand All @@ -413,7 +417,8 @@ def test_take_series(self, data):
)
tm.assert_series_equal(result, expected)

def test_reindex(self, data, na_value):
def test_reindex(self, data):
na_value = data.dtype.na_value
s = pd.Series(data)
result = s.reindex([0, 1, 3])
expected = pd.Series(data.take([0, 1, 3]), index=[0, 1, 3])
Expand Down
11 changes: 7 additions & 4 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_argsort_missing(self, data_missing_for_sorting):
expected = pd.Series(np.array([1, -1, 0], dtype=np.intp))
tm.assert_series_equal(result, expected)

def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value):
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
# GH 24382
is_bool = data_for_sorting.dtype._is_boolean

Expand Down Expand Up @@ -154,9 +154,10 @@ def test_argmin_argmax_empty_array(self, method, data):
getattr(data[:0], method)()

@pytest.mark.parametrize("method", ["argmax", "argmin"])
def test_argmin_argmax_all_na(self, method, data, na_value):
def test_argmin_argmax_all_na(self, method, data):
# all missing with skipna=True is the same as empty
err_msg = "attempt to get"
na_value = data.dtype.na_value
data_na = type(data)._from_sequence([na_value, na_value], dtype=data.dtype)
with pytest.raises(ValueError, match=err_msg):
getattr(data_na, method)()
Expand Down Expand Up @@ -543,7 +544,8 @@ def _test_searchsorted_bool_dtypes(self, data_for_sorting, as_series):
sorter = np.array([1, 0])
assert data_for_sorting.searchsorted(a, sorter=sorter) == 0

def test_where_series(self, data, na_value, as_frame):
def test_where_series(self, data, as_frame):
na_value = data.dtype.na_value
assert data[0] != data[1]
cls = type(data)
a, b = data[:2]
Expand Down Expand Up @@ -670,7 +672,8 @@ def test_insert_invalid_loc(self, data):
data.insert(1.5, data[0])

@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
def test_equals(self, data, as_series, box):
na_value = data.dtype.na_value
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)
data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype)

Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_string_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core import ops
Expand Down Expand Up @@ -128,12 +130,18 @@ class BaseArithmeticOpsTests(BaseOpsUtil):

def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
# series & scalar
if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype):
pytest.skip("Skip testing Python string formatting")

op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(ser, op_name, ser.iloc[0])

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype):
pytest.skip("Skip testing Python string formatting")

op_name = all_arithmetic_operators
df = pd.DataFrame({"A": data})
self.check_opname(df, op_name, data[0])
Expand Down
20 changes: 13 additions & 7 deletions pandas/tests/extension/base/reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def test_concat_mixed_dtypes(self, data):
expected = pd.concat([df1["A"].astype("object"), df2["A"].astype("object")])
tm.assert_series_equal(result, expected)

def test_concat_columns(self, data, na_value):
def test_concat_columns(self, data):
na_value = data.dtype.na_value
df1 = pd.DataFrame({"A": data[:3]})
df2 = pd.DataFrame({"B": [1, 2, 3]})

Expand All @@ -96,8 +97,9 @@ def test_concat_columns(self, data, na_value):
result = pd.concat([df1["A"], df2["B"]], axis=1)
tm.assert_frame_equal(result, expected)

def test_concat_extension_arrays_copy_false(self, data, na_value):
def test_concat_extension_arrays_copy_false(self, data):
# GH 20756
na_value = data.dtype.na_value
df1 = pd.DataFrame({"A": data[:3]})
df2 = pd.DataFrame({"B": data[3:7]})
expected = pd.DataFrame(
Expand All @@ -122,7 +124,8 @@ def test_concat_with_reindex(self, data):
)
tm.assert_frame_equal(result, expected)

def test_align(self, data, na_value):
def test_align(self, data):
na_value = data.dtype.na_value
a = data[:3]
b = data[2:5]
r1, r2 = pd.Series(a).align(pd.Series(b, index=[1, 2, 3]))
Expand All @@ -133,7 +136,8 @@ def test_align(self, data, na_value):
tm.assert_series_equal(r1, e1)
tm.assert_series_equal(r2, e2)

def test_align_frame(self, data, na_value):
def test_align_frame(self, data):
na_value = data.dtype.na_value
a = data[:3]
b = data[2:5]
r1, r2 = pd.DataFrame({"A": a}).align(pd.DataFrame({"A": b}, index=[1, 2, 3]))
Expand All @@ -148,8 +152,9 @@ def test_align_frame(self, data, na_value):
tm.assert_frame_equal(r1, e1)
tm.assert_frame_equal(r2, e2)

def test_align_series_frame(self, data, na_value):
def test_align_series_frame(self, data):
# https://github.com/pandas-dev/pandas/issues/20576
na_value = data.dtype.na_value
ser = pd.Series(data, name="a")
df = pd.DataFrame({"col": np.arange(len(ser) + 1)})
r1, r2 = ser.align(df)
Expand Down Expand Up @@ -180,7 +185,7 @@ def test_set_frame_overwrite_object(self, data):
df["A"] = data
assert df.dtypes["A"] == data.dtype

def test_merge(self, data, na_value):
def test_merge(self, data):
# GH-20743
df1 = pd.DataFrame({"ext": data[:3], "int1": [1, 2, 3], "key": [0, 1, 2]})
df2 = pd.DataFrame({"int2": [1, 2, 3, 4], "key": [0, 0, 1, 3]})
Expand All @@ -205,7 +210,8 @@ def test_merge(self, data, na_value):
"int2": [1, 2, 3, np.nan, 4],
"key": [0, 0, 1, 2, 3],
"ext": data._from_sequence(
[data[0], data[0], data[1], data[2], na_value], dtype=data.dtype
[data[0], data[0], data[1], data[2], data.dtype.na_value],
dtype=data.dtype,
),
}
)
Expand Down
3 changes: 2 additions & 1 deletion pandas/tests/extension/base/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):

tm.assert_frame_equal(result, expected)

def test_setitem_with_expansion_row(self, data, na_value):
def test_setitem_with_expansion_row(self, data):
na_value = data.dtype.na_value
df = pd.DataFrame({"data": data[:1]})

df.loc[1, "data"] = data[1]
Expand Down
6 changes: 0 additions & 6 deletions pandas/tests/extension/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ def na_cmp():
return operator.is_


@pytest.fixture
def na_value(dtype):
"""The scalar missing value for this type. Default dtype.na_value"""
return dtype.na_value


@pytest.fixture
def data_for_grouping():
"""
Expand Down
16 changes: 8 additions & 8 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,24 @@ def test_from_dtype(self, data):
super().test_from_dtype(data)

@pytest.mark.xfail(reason="RecursionError, GH-33900")
def test_series_constructor_no_data_with_index(self, dtype, na_value):
def test_series_constructor_no_data_with_index(self, dtype):
# RecursionError: maximum recursion depth exceeded in comparison
rec_limit = sys.getrecursionlimit()
try:
# Limit to avoid stack overflow on Windows CI
sys.setrecursionlimit(100)
super().test_series_constructor_no_data_with_index(dtype, na_value)
super().test_series_constructor_no_data_with_index(dtype)
finally:
sys.setrecursionlimit(rec_limit)

@pytest.mark.xfail(reason="RecursionError, GH-33900")
def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
def test_series_constructor_scalar_na_with_index(self, dtype):
# RecursionError: maximum recursion depth exceeded in comparison
rec_limit = sys.getrecursionlimit()
try:
# Limit to avoid stack overflow on Windows CI
sys.setrecursionlimit(100)
super().test_series_constructor_scalar_na_with_index(dtype, na_value)
super().test_series_constructor_scalar_na_with_index(dtype)
finally:
sys.setrecursionlimit(rec_limit)

Expand Down Expand Up @@ -214,19 +214,19 @@ def test_combine_first(self, data):
super().test_combine_first(data)

@pytest.mark.xfail(reason="broadcasting error")
def test_where_series(self, data, na_value):
def test_where_series(self, data):
# Fails with
# *** ValueError: operands could not be broadcast together
# with shapes (4,) (4,) (0,)
super().test_where_series(data, na_value)
super().test_where_series(data)

@pytest.mark.xfail(reason="Can't compare dicts.")
def test_searchsorted(self, data_for_sorting):
super().test_searchsorted(data_for_sorting)

@pytest.mark.xfail(reason="Can't compare dicts.")
def test_equals(self, data, na_value, as_series):
super().test_equals(data, na_value, as_series)
def test_equals(self, data, as_series):
super().test_equals(data, as_series)

@pytest.mark.skip("fill-value is interpreted as a dict of values")
def test_fillna_copy_frame(self, data_missing):
Expand Down
14 changes: 4 additions & 10 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, reque
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)

@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_series_boolean(
self, data, all_boolean_reductions, skipna, na_value, request
):
def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna, request):
pa_dtype = data.dtype.pyarrow_dtype
xfail_mark = pytest.mark.xfail(
raises=TypeError,
Expand Down Expand Up @@ -753,9 +751,7 @@ def test_value_counts_returns_pyarrow_int64(self, data):
result = data.value_counts()
assert result.dtype == ArrowDtype(pa.int64())

def test_argmin_argmax(
self, data_for_sorting, data_missing_for_sorting, na_value, request
):
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
request.node.add_marker(
Expand All @@ -764,7 +760,7 @@ def test_argmin_argmax(
raises=pa.ArrowNotImplementedError,
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting)

@pytest.mark.parametrize(
"op_name, skipna, expected",
Expand Down Expand Up @@ -1033,9 +1029,7 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
pa_dtype = data.dtype.pyarrow_dtype

if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
if all_arithmetic_operators == "__rmod__" and (pa.types.is_binary(pa_dtype)):
pytest.skip("Skip testing Python string formatting")

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,14 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
)
super().test_arith_series_with_scalar(data, op_name)

def _compare_other(self, s, data, op, other):
def _compare_other(self, ser: pd.Series, data, op, other):
op_name = f"__{op.__name__}__"
if op_name not in ["__eq__", "__ne__"]:
msg = "Unordered Categoricals can only compare equality or not"
with pytest.raises(TypeError, match=msg):
op(data, other)
else:
return super()._compare_other(s, data, op, other)
return super()._compare_other(ser, data, op, other)

@pytest.mark.xfail(reason="Categorical overrides __repr__")
@pytest.mark.parametrize("size", ["big", "small"])
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/extension/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,13 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
else:
expected_dtype = f"Int{length}"

if expected_dtype == "Float32" and op_name == "cumprod" and skipna:
# TODO: xfail?
pytest.skip(
f"Float32 precision lead to large differences with op {op_name} "
f"and skipna={skipna}"
)

if op_name == "cumsum":
result = getattr(ser, op_name)(skipna=skipna)
expected = pd.Series(
Expand Down
Loading