Skip to content

Commit ae3fc3c

Browse files
authored
TST: clean ups in extension tests (#54719)
* REF: misplaced test_accumulate special-casing
* REF: skip rmod test with string dtype
* typo fixup
* REF: remove unnecessary na_value fixture
1 parent 2bf1df9 commit ae3fc3c

13 files changed

+85
-72
lines changed

pandas/tests/extension/base/accumulate.py

-8
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,6 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
2323
alt = ser.astype(object)
2424

2525
result = getattr(ser, op_name)(skipna=skipna)
26-
27-
if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna:
28-
# TODO: avoid special-casing here
29-
pytest.skip(
30-
f"Float32 precision lead to large differences with op {op_name} "
31-
f"and skipna={skipna}"
32-
)
33-
3426
expected = getattr(alt, op_name)(skipna=skipna)
3527
tm.assert_series_equal(result, expected, check_dtype=False)
3628

pandas/tests/extension/base/constructors.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def test_series_constructor(self, data):
3535
if hasattr(result._mgr, "blocks"):
3636
assert isinstance(result2._mgr.blocks[0], EABackedBlock)
3737

38-
def test_series_constructor_no_data_with_index(self, dtype, na_value):
38+
def test_series_constructor_no_data_with_index(self, dtype):
39+
na_value = dtype.na_value
3940
result = pd.Series(index=[1, 2, 3], dtype=dtype)
4041
expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
4142
tm.assert_series_equal(result, expected)
@@ -45,7 +46,8 @@ def test_series_constructor_no_data_with_index(self, dtype, na_value):
4546
expected = pd.Series([], index=pd.Index([], dtype="object"), dtype=dtype)
4647
tm.assert_series_equal(result, expected)
4748

48-
def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
49+
def test_series_constructor_scalar_na_with_index(self, dtype):
50+
na_value = dtype.na_value
4951
result = pd.Series(na_value, index=[1, 2, 3], dtype=dtype)
5052
expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
5153
tm.assert_series_equal(result, expected)

pandas/tests/extension/base/getitem.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def test_getitem_invalid(self, data):
148148
with pytest.raises(IndexError, match=msg):
149149
data[-ub - 1]
150150

151-
def test_getitem_scalar_na(self, data_missing, na_cmp, na_value):
151+
def test_getitem_scalar_na(self, data_missing, na_cmp):
152+
na_value = data_missing.dtype.na_value
152153
result = data_missing[0]
153154
assert na_cmp(result, na_value)
154155

@@ -348,7 +349,8 @@ def test_take_sequence(self, data):
348349
assert result.iloc[1] == data[1]
349350
assert result.iloc[2] == data[3]
350351

351-
def test_take(self, data, na_value, na_cmp):
352+
def test_take(self, data, na_cmp):
353+
na_value = data.dtype.na_value
352354
result = data.take([0, -1])
353355
assert result.dtype == data.dtype
354356
assert result[0] == data[0]
@@ -361,7 +363,8 @@ def test_take(self, data, na_value, na_cmp):
361363
with pytest.raises(IndexError, match="out of bounds"):
362364
data.take([len(data) + 1])
363365

364-
def test_take_empty(self, data, na_value, na_cmp):
366+
def test_take_empty(self, data, na_cmp):
367+
na_value = data.dtype.na_value
365368
empty = data[:0]
366369

367370
result = empty.take([-1], allow_fill=True)
@@ -393,7 +396,8 @@ def test_take_non_na_fill_value(self, data_missing):
393396
expected = arr.take([1, 1])
394397
tm.assert_extension_array_equal(result, expected)
395398

396-
def test_take_pandas_style_negative_raises(self, data, na_value):
399+
def test_take_pandas_style_negative_raises(self, data):
400+
na_value = data.dtype.na_value
397401
with pytest.raises(ValueError, match=""):
398402
data.take([0, -2], fill_value=na_value, allow_fill=True)
399403

@@ -413,7 +417,8 @@ def test_take_series(self, data):
413417
)
414418
tm.assert_series_equal(result, expected)
415419

416-
def test_reindex(self, data, na_value):
420+
def test_reindex(self, data):
421+
na_value = data.dtype.na_value
417422
s = pd.Series(data)
418423
result = s.reindex([0, 1, 3])
419424
expected = pd.Series(data.take([0, 1, 3]), index=[0, 1, 3])

pandas/tests/extension/base/methods.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_argsort_missing(self, data_missing_for_sorting):
121121
expected = pd.Series(np.array([1, -1, 0], dtype=np.intp))
122122
tm.assert_series_equal(result, expected)
123123

124-
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value):
124+
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
125125
# GH 24382
126126
is_bool = data_for_sorting.dtype._is_boolean
127127

@@ -154,9 +154,10 @@ def test_argmin_argmax_empty_array(self, method, data):
154154
getattr(data[:0], method)()
155155

156156
@pytest.mark.parametrize("method", ["argmax", "argmin"])
157-
def test_argmin_argmax_all_na(self, method, data, na_value):
157+
def test_argmin_argmax_all_na(self, method, data):
158158
# all missing with skipna=True is the same as empty
159159
err_msg = "attempt to get"
160+
na_value = data.dtype.na_value
160161
data_na = type(data)._from_sequence([na_value, na_value], dtype=data.dtype)
161162
with pytest.raises(ValueError, match=err_msg):
162163
getattr(data_na, method)()
@@ -543,7 +544,8 @@ def _test_searchsorted_bool_dtypes(self, data_for_sorting, as_series):
543544
sorter = np.array([1, 0])
544545
assert data_for_sorting.searchsorted(a, sorter=sorter) == 0
545546

546-
def test_where_series(self, data, na_value, as_frame):
547+
def test_where_series(self, data, as_frame):
548+
na_value = data.dtype.na_value
547549
assert data[0] != data[1]
548550
cls = type(data)
549551
a, b = data[:2]
@@ -670,7 +672,8 @@ def test_insert_invalid_loc(self, data):
670672
data.insert(1.5, data[0])
671673

672674
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
673-
def test_equals(self, data, na_value, as_series, box):
675+
def test_equals(self, data, as_series, box):
676+
na_value = data.dtype.na_value
674677
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)
675678
data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype)
676679

pandas/tests/extension/base/ops.py

+8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import pytest
77

8+
from pandas.core.dtypes.common import is_string_dtype
9+
810
import pandas as pd
911
import pandas._testing as tm
1012
from pandas.core import ops
@@ -128,12 +130,18 @@ class BaseArithmeticOpsTests(BaseOpsUtil):
128130

129131
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
130132
# series & scalar
133+
if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype):
134+
pytest.skip("Skip testing Python string formatting")
135+
131136
op_name = all_arithmetic_operators
132137
ser = pd.Series(data)
133138
self.check_opname(ser, op_name, ser.iloc[0])
134139

135140
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
136141
# frame & scalar
142+
if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype):
143+
pytest.skip("Skip testing Python string formatting")
144+
137145
op_name = all_arithmetic_operators
138146
df = pd.DataFrame({"A": data})
139147
self.check_opname(df, op_name, data[0])

pandas/tests/extension/base/reshaping.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def test_concat_mixed_dtypes(self, data):
7272
expected = pd.concat([df1["A"].astype("object"), df2["A"].astype("object")])
7373
tm.assert_series_equal(result, expected)
7474

75-
def test_concat_columns(self, data, na_value):
75+
def test_concat_columns(self, data):
76+
na_value = data.dtype.na_value
7677
df1 = pd.DataFrame({"A": data[:3]})
7778
df2 = pd.DataFrame({"B": [1, 2, 3]})
7879

@@ -96,8 +97,9 @@ def test_concat_columns(self, data, na_value):
9697
result = pd.concat([df1["A"], df2["B"]], axis=1)
9798
tm.assert_frame_equal(result, expected)
9899

99-
def test_concat_extension_arrays_copy_false(self, data, na_value):
100+
def test_concat_extension_arrays_copy_false(self, data):
100101
# GH 20756
102+
na_value = data.dtype.na_value
101103
df1 = pd.DataFrame({"A": data[:3]})
102104
df2 = pd.DataFrame({"B": data[3:7]})
103105
expected = pd.DataFrame(
@@ -122,7 +124,8 @@ def test_concat_with_reindex(self, data):
122124
)
123125
tm.assert_frame_equal(result, expected)
124126

125-
def test_align(self, data, na_value):
127+
def test_align(self, data):
128+
na_value = data.dtype.na_value
126129
a = data[:3]
127130
b = data[2:5]
128131
r1, r2 = pd.Series(a).align(pd.Series(b, index=[1, 2, 3]))
@@ -133,7 +136,8 @@ def test_align(self, data, na_value):
133136
tm.assert_series_equal(r1, e1)
134137
tm.assert_series_equal(r2, e2)
135138

136-
def test_align_frame(self, data, na_value):
139+
def test_align_frame(self, data):
140+
na_value = data.dtype.na_value
137141
a = data[:3]
138142
b = data[2:5]
139143
r1, r2 = pd.DataFrame({"A": a}).align(pd.DataFrame({"A": b}, index=[1, 2, 3]))
@@ -148,8 +152,9 @@ def test_align_frame(self, data, na_value):
148152
tm.assert_frame_equal(r1, e1)
149153
tm.assert_frame_equal(r2, e2)
150154

151-
def test_align_series_frame(self, data, na_value):
155+
def test_align_series_frame(self, data):
152156
# https://github.com/pandas-dev/pandas/issues/20576
157+
na_value = data.dtype.na_value
153158
ser = pd.Series(data, name="a")
154159
df = pd.DataFrame({"col": np.arange(len(ser) + 1)})
155160
r1, r2 = ser.align(df)
@@ -180,7 +185,7 @@ def test_set_frame_overwrite_object(self, data):
180185
df["A"] = data
181186
assert df.dtypes["A"] == data.dtype
182187

183-
def test_merge(self, data, na_value):
188+
def test_merge(self, data):
184189
# GH-20743
185190
df1 = pd.DataFrame({"ext": data[:3], "int1": [1, 2, 3], "key": [0, 1, 2]})
186191
df2 = pd.DataFrame({"int2": [1, 2, 3, 4], "key": [0, 0, 1, 3]})
@@ -205,7 +210,8 @@ def test_merge(self, data, na_value):
205210
"int2": [1, 2, 3, np.nan, 4],
206211
"key": [0, 0, 1, 2, 3],
207212
"ext": data._from_sequence(
208-
[data[0], data[0], data[1], data[2], na_value], dtype=data.dtype
213+
[data[0], data[0], data[1], data[2], data.dtype.na_value],
214+
dtype=data.dtype,
209215
),
210216
}
211217
)

pandas/tests/extension/base/setitem.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):
359359

360360
tm.assert_frame_equal(result, expected)
361361

362-
def test_setitem_with_expansion_row(self, data, na_value):
362+
def test_setitem_with_expansion_row(self, data):
363+
na_value = data.dtype.na_value
363364
df = pd.DataFrame({"data": data[:1]})
364365

365366
df.loc[1, "data"] = data[1]

pandas/tests/extension/conftest.py

-6
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,6 @@ def na_cmp():
116116
return operator.is_
117117

118118

119-
@pytest.fixture
120-
def na_value(dtype):
121-
"""The scalar missing value for this type. Default dtype.na_value"""
122-
return dtype.na_value
123-
124-
125119
@pytest.fixture
126120
def data_for_grouping():
127121
"""

pandas/tests/extension/json/test_json.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -97,24 +97,24 @@ def test_from_dtype(self, data):
9797
super().test_from_dtype(data)
9898

9999
@pytest.mark.xfail(reason="RecursionError, GH-33900")
100-
def test_series_constructor_no_data_with_index(self, dtype, na_value):
100+
def test_series_constructor_no_data_with_index(self, dtype):
101101
# RecursionError: maximum recursion depth exceeded in comparison
102102
rec_limit = sys.getrecursionlimit()
103103
try:
104104
# Limit to avoid stack overflow on Windows CI
105105
sys.setrecursionlimit(100)
106-
super().test_series_constructor_no_data_with_index(dtype, na_value)
106+
super().test_series_constructor_no_data_with_index(dtype)
107107
finally:
108108
sys.setrecursionlimit(rec_limit)
109109

110110
@pytest.mark.xfail(reason="RecursionError, GH-33900")
111-
def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
111+
def test_series_constructor_scalar_na_with_index(self, dtype):
112112
# RecursionError: maximum recursion depth exceeded in comparison
113113
rec_limit = sys.getrecursionlimit()
114114
try:
115115
# Limit to avoid stack overflow on Windows CI
116116
sys.setrecursionlimit(100)
117-
super().test_series_constructor_scalar_na_with_index(dtype, na_value)
117+
super().test_series_constructor_scalar_na_with_index(dtype)
118118
finally:
119119
sys.setrecursionlimit(rec_limit)
120120

@@ -214,19 +214,19 @@ def test_combine_first(self, data):
214214
super().test_combine_first(data)
215215

216216
@pytest.mark.xfail(reason="broadcasting error")
217-
def test_where_series(self, data, na_value):
217+
def test_where_series(self, data):
218218
# Fails with
219219
# *** ValueError: operands could not be broadcast together
220220
# with shapes (4,) (4,) (0,)
221-
super().test_where_series(data, na_value)
221+
super().test_where_series(data)
222222

223223
@pytest.mark.xfail(reason="Can't compare dicts.")
224224
def test_searchsorted(self, data_for_sorting):
225225
super().test_searchsorted(data_for_sorting)
226226

227227
@pytest.mark.xfail(reason="Can't compare dicts.")
228-
def test_equals(self, data, na_value, as_series):
229-
super().test_equals(data, na_value, as_series)
228+
def test_equals(self, data, as_series):
229+
super().test_equals(data, as_series)
230230

231231
@pytest.mark.skip("fill-value is interpreted as a dict of values")
232232
def test_fillna_copy_frame(self, data_missing):

pandas/tests/extension/test_arrow.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, reque
518518
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
519519

520520
@pytest.mark.parametrize("skipna", [True, False])
521-
def test_reduce_series_boolean(
522-
self, data, all_boolean_reductions, skipna, na_value, request
523-
):
521+
def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna, request):
524522
pa_dtype = data.dtype.pyarrow_dtype
525523
xfail_mark = pytest.mark.xfail(
526524
raises=TypeError,
@@ -753,9 +751,7 @@ def test_value_counts_returns_pyarrow_int64(self, data):
753751
result = data.value_counts()
754752
assert result.dtype == ArrowDtype(pa.int64())
755753

756-
def test_argmin_argmax(
757-
self, data_for_sorting, data_missing_for_sorting, na_value, request
758-
):
754+
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, request):
759755
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
760756
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
761757
request.node.add_marker(
@@ -764,7 +760,7 @@ def test_argmin_argmax(
764760
raises=pa.ArrowNotImplementedError,
765761
)
766762
)
767-
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
763+
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting)
768764

769765
@pytest.mark.parametrize(
770766
"op_name, skipna, expected",
@@ -1033,9 +1029,7 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
10331029
def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
10341030
pa_dtype = data.dtype.pyarrow_dtype
10351031

1036-
if all_arithmetic_operators == "__rmod__" and (
1037-
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
1038-
):
1032+
if all_arithmetic_operators == "__rmod__" and (pa.types.is_binary(pa_dtype)):
10391033
pytest.skip("Skip testing Python string formatting")
10401034

10411035
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)

pandas/tests/extension/test_categorical.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,14 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
165165
)
166166
super().test_arith_series_with_scalar(data, op_name)
167167

168-
def _compare_other(self, s, data, op, other):
168+
def _compare_other(self, ser: pd.Series, data, op, other):
169169
op_name = f"__{op.__name__}__"
170170
if op_name not in ["__eq__", "__ne__"]:
171171
msg = "Unordered Categoricals can only compare equality or not"
172172
with pytest.raises(TypeError, match=msg):
173173
op(data, other)
174174
else:
175-
return super()._compare_other(s, data, op, other)
175+
return super()._compare_other(ser, data, op, other)
176176

177177
@pytest.mark.xfail(reason="Categorical overrides __repr__")
178178
@pytest.mark.parametrize("size", ["big", "small"])

pandas/tests/extension/test_masked.py

+7
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,13 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
323323
else:
324324
expected_dtype = f"Int{length}"
325325

326+
if expected_dtype == "Float32" and op_name == "cumprod" and skipna:
327+
# TODO: xfail?
328+
pytest.skip(
329+
f"Float32 precision lead to large differences with op {op_name} "
330+
f"and skipna={skipna}"
331+
)
332+
326333
if op_name == "cumsum":
327334
result = getattr(ser, op_name)(skipna=skipna)
328335
expected = pd.Series(

0 commit comments

Comments
 (0)