Skip to content

Commit 24c0c1a

Browse files
authored
REF: update to use extension test patterns (#54436)
1 parent 71a6d53 commit 24c0c1a

File tree

3 files changed

+24
-32
lines changed

3 files changed

+24
-32
lines changed

pandas/tests/extension/base/reduce.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@ def check_reduce(self, s, op_name, skipna):
2323
# that the results match. Override if you need to cast to something
2424
# other than float64.
2525
res_op = getattr(s, op_name)
26-
exp_op = getattr(s.astype("float64"), op_name)
26+
27+
try:
28+
alt = s.astype("float64")
29+
except TypeError:
30+
# e.g. Interval can't cast, so let's cast to object and do
31+
# the reduction pointwise
32+
alt = s.astype(object)
33+
34+
exp_op = getattr(alt, op_name)
2735
if op_name == "count":
2836
result = res_op()
2937
expected = exp_op()

pandas/tests/extension/decimal/test_decimal.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,8 @@ def _supports_reduction(self, obj, op_name: str) -> bool:
152152
return True
153153

154154
def check_reduce(self, s, op_name, skipna):
155-
if op_name in ["median", "skew", "kurt", "sem"]:
156-
msg = r"decimal does not support the .* operation"
157-
with pytest.raises(NotImplementedError, match=msg):
158-
getattr(s, op_name)(skipna=skipna)
159-
elif op_name == "count":
160-
result = getattr(s, op_name)()
161-
expected = len(s) - s.isna().sum()
162-
tm.assert_almost_equal(result, expected)
155+
if op_name == "count":
156+
return super().check_reduce(s, op_name, skipna)
163157
else:
164158
result = getattr(s, op_name)(skipna=skipna)
165159
expected = getattr(np.asarray(s), op_name)()
@@ -189,12 +183,17 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
189183

190184

191185
class TestReduce(Reduce, base.BaseReduceTests):
192-
@pytest.mark.parametrize("skipna", [True, False])
193-
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
186+
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
187+
if all_numeric_reductions in ["kurt", "skew", "sem", "median"]:
188+
mark = pytest.mark.xfail(raises=NotImplementedError)
189+
request.node.add_marker(mark)
190+
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
191+
192+
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
194193
op_name = all_numeric_reductions
195194
if op_name in ["skew", "median"]:
196-
assert not hasattr(data, op_name)
197-
pytest.skip(f"{op_name} not an array method")
195+
mark = pytest.mark.xfail(raises=NotImplementedError)
196+
request.node.add_marker(mark)
198197

199198
return super().test_reduce_frame(data, all_numeric_reductions, skipna)
200199

pandas/tests/extension/test_interval.py

+4-19
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818

1919
from pandas.core.dtypes.dtypes import IntervalDtype
2020

21-
from pandas import (
22-
Interval,
23-
Series,
24-
)
21+
from pandas import Interval
2522
from pandas.core.arrays import IntervalArray
2623
from pandas.tests.extension import base
2724

@@ -106,18 +103,8 @@ class TestInterface(BaseInterval, base.BaseInterfaceTests):
106103

107104

108105
class TestReduce(base.BaseReduceTests):
109-
@pytest.mark.parametrize("skipna", [True, False])
110-
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
111-
op_name = all_numeric_reductions
112-
ser = Series(data)
113-
114-
if op_name in ["min", "max"]:
115-
# IntervalArray *does* implement these
116-
assert getattr(ser, op_name)(skipna=skipna) in data
117-
assert getattr(data, op_name)(skipna=skipna) in data
118-
return
119-
120-
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
106+
def _supports_reduction(self, obj, op_name: str) -> bool:
107+
return op_name in ["min", "max"]
121108

122109

123110
class TestMethods(BaseInterval, base.BaseMethodsTests):
@@ -145,9 +132,7 @@ class TestSetitem(BaseInterval, base.BaseSetitemTests):
145132

146133

147134
class TestPrinting(BaseInterval, base.BasePrintingTests):
148-
@pytest.mark.xfail(reason="Interval has custom repr")
149-
def test_array_repr(self, data, size):
150-
super().test_array_repr()
135+
pass
151136

152137

153138
class TestParsing(BaseInterval, base.BaseParsingTests):

0 commit comments

Comments
 (0)