Skip to content

Commit 22daf77

Browse files
authored
REF: de-duplicate some test code (#52228)
* REF: de-duplicate some test code * mypy fixup * mypy fixup
1 parent 5045a99 commit 22daf77

File tree

5 files changed

+270
-380
lines changed

5 files changed

+270
-380
lines changed
+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Shared test code for IntegerArray/FloatingArray/BooleanArray.
3+
"""
4+
import pytest
5+
6+
from pandas.compat import (
7+
IS64,
8+
is_platform_windows,
9+
)
10+
11+
import pandas as pd
12+
import pandas._testing as tm
13+
from pandas.tests.extension import base
14+
15+
16+
class Arithmetic(base.BaseArithmeticOpsTests):
17+
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
18+
# overwriting to indicate ops don't raise an error
19+
super().check_opname(ser, op_name, other, exc=None)
20+
21+
def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
22+
super()._check_divmod_op(ser, op, other, None)
23+
24+
25+
class Comparison(base.BaseComparisonOpsTests):
26+
def _check_op(
27+
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
28+
):
29+
if exc is None:
30+
result = op(ser, other)
31+
# Override to do the astype to boolean
32+
expected = ser.combine(other, op).astype("boolean")
33+
self.assert_series_equal(result, expected)
34+
else:
35+
with pytest.raises(exc):
36+
op(ser, other)
37+
38+
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
39+
super().check_opname(ser, op_name, other, exc=None)
40+
41+
def _compare_other(self, ser: pd.Series, data, op, other):
42+
op_name = f"__{op.__name__}__"
43+
self.check_opname(ser, op_name, other)
44+
45+
46+
class NumericReduce(base.BaseNumericReduceTests):
47+
def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
48+
# overwrite to ensure pd.NA is tested instead of np.nan
49+
# https://github.com/pandas-dev/pandas/issues/30958
50+
51+
cmp_dtype = "int64"
52+
if ser.dtype.kind == "f":
53+
# Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" has
54+
# no attribute "numpy_dtype"
55+
cmp_dtype = ser.dtype.numpy_dtype # type: ignore[union-attr]
56+
57+
if op_name == "count":
58+
result = getattr(ser, op_name)()
59+
expected = getattr(ser.dropna().astype(cmp_dtype), op_name)()
60+
else:
61+
result = getattr(ser, op_name)(skipna=skipna)
62+
expected = getattr(ser.dropna().astype(cmp_dtype), op_name)(skipna=skipna)
63+
if not skipna and ser.isna().any():
64+
expected = pd.NA
65+
tm.assert_almost_equal(result, expected)
66+
67+
68+
class Accumulation(base.BaseAccumulateTests):
69+
@pytest.mark.parametrize("skipna", [True, False])
70+
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
71+
pass
72+
73+
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
74+
# overwrite to ensure pd.NA is tested instead of np.nan
75+
# https://github.com/pandas-dev/pandas/issues/30958
76+
length = 64
77+
if not IS64 or is_platform_windows():
78+
# Item "ExtensionDtype" of "Union[dtype[Any], ExtensionDtype]" has
79+
# no attribute "itemsize"
80+
if not ser.dtype.itemsize == 8: # type: ignore[union-attr]
81+
length = 32
82+
83+
if ser.dtype.name.startswith("U"):
84+
expected_dtype = f"UInt{length}"
85+
elif ser.dtype.name.startswith("I"):
86+
expected_dtype = f"Int{length}"
87+
elif ser.dtype.name.startswith("F"):
88+
# Incompatible types in assignment (expression has type
89+
# "Union[dtype[Any], ExtensionDtype]", variable has type "str")
90+
expected_dtype = ser.dtype # type: ignore[assignment]
91+
92+
if op_name == "cumsum":
93+
result = getattr(ser, op_name)(skipna=skipna)
94+
expected = pd.Series(
95+
pd.array(
96+
getattr(ser.astype("float64"), op_name)(skipna=skipna),
97+
dtype=expected_dtype,
98+
)
99+
)
100+
tm.assert_series_equal(result, expected)
101+
elif op_name in ["cummax", "cummin"]:
102+
result = getattr(ser, op_name)(skipna=skipna)
103+
expected = pd.Series(
104+
pd.array(
105+
getattr(ser.astype("float64"), op_name)(skipna=skipna),
106+
dtype=ser.dtype,
107+
)
108+
)
109+
tm.assert_series_equal(result, expected)
110+
elif op_name == "cumprod":
111+
result = getattr(ser[:12], op_name)(skipna=skipna)
112+
expected = pd.Series(
113+
pd.array(
114+
getattr(ser[:12].astype("float64"), op_name)(skipna=skipna),
115+
dtype=expected_dtype,
116+
)
117+
)
118+
tm.assert_series_equal(result, expected)
119+
120+
else:
121+
raise NotImplementedError(f"{op_name} not supported")

pandas/tests/extension/test_floating.py

+11-46
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
Float32Dtype,
2626
Float64Dtype,
2727
)
28-
from pandas.tests.extension import base
28+
from pandas.tests.extension import (
29+
base,
30+
masked_shared,
31+
)
2932

3033

3134
def make_data():
@@ -92,11 +95,7 @@ class TestDtype(base.BaseDtypeTests):
9295
pass
9396

9497

95-
class TestArithmeticOps(base.BaseArithmeticOpsTests):
96-
def check_opname(self, s, op_name, other, exc=None):
97-
# overwriting to indicate ops don't raise an error
98-
super().check_opname(s, op_name, other, exc=None)
99-
98+
class TestArithmeticOps(masked_shared.Arithmetic):
10099
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
101100
if exc is None:
102101
sdtype = tm.get_dtype(s)
@@ -120,28 +119,9 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
120119
with pytest.raises(exc):
121120
op(s, other)
122121

123-
def _check_divmod_op(self, s, op, other, exc=None):
124-
super()._check_divmod_op(s, op, other, None)
125122

126-
127-
class TestComparisonOps(base.BaseComparisonOpsTests):
128-
# TODO: share with IntegerArray?
129-
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
130-
if exc is None:
131-
result = op(s, other)
132-
# Override to do the astype to boolean
133-
expected = s.combine(other, op).astype("boolean")
134-
self.assert_series_equal(result, expected)
135-
else:
136-
with pytest.raises(exc):
137-
op(s, other)
138-
139-
def check_opname(self, s, op_name, other, exc=None):
140-
super().check_opname(s, op_name, other, exc=None)
141-
142-
def _compare_other(self, s, data, op, other):
143-
op_name = f"__{op.__name__}__"
144-
self.check_opname(s, op_name, other)
123+
class TestComparisonOps(masked_shared.Comparison):
124+
pass
145125

146126

147127
class TestInterface(base.BaseInterfaceTests):
@@ -184,21 +164,8 @@ class TestGroupby(base.BaseGroupbyTests):
184164
pass
185165

186166

187-
class TestNumericReduce(base.BaseNumericReduceTests):
188-
def check_reduce(self, s, op_name, skipna):
189-
# overwrite to ensure pd.NA is tested instead of np.nan
190-
# https://github.com/pandas-dev/pandas/issues/30958
191-
if op_name == "count":
192-
result = getattr(s, op_name)()
193-
expected = getattr(s.dropna().astype(s.dtype.numpy_dtype), op_name)()
194-
else:
195-
result = getattr(s, op_name)(skipna=skipna)
196-
expected = getattr(s.dropna().astype(s.dtype.numpy_dtype), op_name)(
197-
skipna=skipna
198-
)
199-
if not skipna and s.isna().any():
200-
expected = pd.NA
201-
tm.assert_almost_equal(result, expected)
167+
class TestNumericReduce(masked_shared.NumericReduce):
168+
pass
202169

203170

204171
@pytest.mark.skip(reason="Tested in tests/reductions/test_reductions.py")
@@ -219,7 +186,5 @@ class Test2DCompat(base.Dim2CompatTests):
219186
pass
220187

221188

222-
class TestAccumulation(base.BaseAccumulateTests):
223-
@pytest.mark.parametrize("skipna", [True, False])
224-
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
225-
pass
189+
class TestAccumulation(masked_shared.Accumulation):
190+
pass

pandas/tests/extension/test_integer.py

+11-92
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@
1616
import numpy as np
1717
import pytest
1818

19-
from pandas.compat import (
20-
IS64,
21-
is_platform_windows,
22-
)
23-
2419
import pandas as pd
2520
import pandas._testing as tm
2621
from pandas.api.types import (
@@ -37,7 +32,10 @@
3732
UInt32Dtype,
3833
UInt64Dtype,
3934
)
40-
from pandas.tests.extension import base
35+
from pandas.tests.extension import (
36+
base,
37+
masked_shared,
38+
)
4139

4240

4341
def make_data():
@@ -109,11 +107,7 @@ class TestDtype(base.BaseDtypeTests):
109107
pass
110108

111109

112-
class TestArithmeticOps(base.BaseArithmeticOpsTests):
113-
def check_opname(self, s, op_name, other, exc=None):
114-
# overwriting to indicate ops don't raise an error
115-
super().check_opname(s, op_name, other, exc=None)
116-
110+
class TestArithmeticOps(masked_shared.Arithmetic):
117111
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
118112
if exc is None:
119113
sdtype = tm.get_dtype(s)
@@ -145,27 +139,9 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
145139
with pytest.raises(exc):
146140
op(s, other)
147141

148-
def _check_divmod_op(self, s, op, other, exc=None):
149-
super()._check_divmod_op(s, op, other, None)
150-
151-
152-
class TestComparisonOps(base.BaseComparisonOpsTests):
153-
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
154-
if exc is None:
155-
result = op(s, other)
156-
# Override to do the astype to boolean
157-
expected = s.combine(other, op).astype("boolean")
158-
self.assert_series_equal(result, expected)
159-
else:
160-
with pytest.raises(exc):
161-
op(s, other)
162-
163-
def check_opname(self, s, op_name, other, exc=None):
164-
super().check_opname(s, op_name, other, exc=None)
165142

166-
def _compare_other(self, s, data, op, other):
167-
op_name = f"__{op.__name__}__"
168-
self.check_opname(s, op_name, other)
143+
class TestComparisonOps(masked_shared.Comparison):
144+
pass
169145

170146

171147
class TestInterface(base.BaseInterfaceTests):
@@ -212,74 +188,17 @@ class TestGroupby(base.BaseGroupbyTests):
212188
pass
213189

214190

215-
class TestNumericReduce(base.BaseNumericReduceTests):
216-
def check_reduce(self, s, op_name, skipna):
217-
# overwrite to ensure pd.NA is tested instead of np.nan
218-
# https://github.com/pandas-dev/pandas/issues/30958
219-
if op_name == "count":
220-
result = getattr(s, op_name)()
221-
expected = getattr(s.dropna().astype("int64"), op_name)()
222-
else:
223-
result = getattr(s, op_name)(skipna=skipna)
224-
expected = getattr(s.dropna().astype("int64"), op_name)(skipna=skipna)
225-
if not skipna and s.isna().any():
226-
expected = pd.NA
227-
tm.assert_almost_equal(result, expected)
191+
class TestNumericReduce(masked_shared.NumericReduce):
192+
pass
228193

229194

230195
@pytest.mark.skip(reason="Tested in tests/reductions/test_reductions.py")
231196
class TestBooleanReduce(base.BaseBooleanReduceTests):
232197
pass
233198

234199

235-
class TestAccumulation(base.BaseAccumulateTests):
236-
def check_accumulate(self, s, op_name, skipna):
237-
# overwrite to ensure pd.NA is tested instead of np.nan
238-
# https://github.com/pandas-dev/pandas/issues/30958
239-
length = 64
240-
if not IS64 or is_platform_windows():
241-
if not s.dtype.itemsize == 8:
242-
length = 32
243-
244-
if s.dtype.name.startswith("U"):
245-
expected_dtype = f"UInt{length}"
246-
else:
247-
expected_dtype = f"Int{length}"
248-
249-
if op_name == "cumsum":
250-
result = getattr(s, op_name)(skipna=skipna)
251-
expected = pd.Series(
252-
pd.array(
253-
getattr(s.astype("float64"), op_name)(skipna=skipna),
254-
dtype=expected_dtype,
255-
)
256-
)
257-
tm.assert_series_equal(result, expected)
258-
elif op_name in ["cummax", "cummin"]:
259-
result = getattr(s, op_name)(skipna=skipna)
260-
expected = pd.Series(
261-
pd.array(
262-
getattr(s.astype("float64"), op_name)(skipna=skipna),
263-
dtype=s.dtype,
264-
)
265-
)
266-
tm.assert_series_equal(result, expected)
267-
elif op_name == "cumprod":
268-
result = getattr(s[:12], op_name)(skipna=skipna)
269-
expected = pd.Series(
270-
pd.array(
271-
getattr(s[:12].astype("float64"), op_name)(skipna=skipna),
272-
dtype=expected_dtype,
273-
)
274-
)
275-
tm.assert_series_equal(result, expected)
276-
277-
else:
278-
raise NotImplementedError(f"{op_name} not supported")
279-
280-
@pytest.mark.parametrize("skipna", [True, False])
281-
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
282-
pass
200+
class TestAccumulation(masked_shared.Accumulation):
201+
pass
283202

284203

285204
class TestPrinting(base.BasePrintingTests):

0 commit comments

Comments
 (0)