Skip to content

Commit 5459d10

Browse files
authored
REF: use single-test-class for datetimetz, period tests (#54676)
* REF: use single-test-class for datetimetz, period tests * mypy fixup
1 parent 9268055 commit 5459d10

File tree

3 files changed

+88
-121
lines changed

3 files changed

+88
-121
lines changed

pandas/tests/extension/base/accumulate.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
1616
return False
1717

1818
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
19-
alt = ser.astype("float64")
19+
try:
20+
alt = ser.astype("float64")
21+
except TypeError:
22+
# e.g. Period can't be cast to float64
23+
alt = ser.astype(object)
24+
2025
result = getattr(ser, op_name)(skipna=skipna)
2126

2227
if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna:
@@ -37,5 +42,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
3742
if self._supports_accumulation(ser, op_name):
3843
self.check_accumulate(ser, op_name, skipna)
3944
else:
40-
with pytest.raises(NotImplementedError):
45+
with pytest.raises((NotImplementedError, TypeError)):
46+
# TODO: require TypeError for things that will _never_ work?
4147
getattr(ser, op_name)(skipna=skipna)

pandas/tests/extension/test_datetime.py

+45-60
Original file line numberDiff line numberDiff line change
@@ -82,78 +82,63 @@ def cmp(a, b):
8282

8383

8484
# ----------------------------------------------------------------------------
85-
class BaseDatetimeTests:
86-
pass
85+
class TestDatetimeArray(base.ExtensionTests):
86+
def _get_expected_exception(self, op_name, obj, other):
87+
if op_name in ["__sub__", "__rsub__"]:
88+
return None
89+
return super()._get_expected_exception(op_name, obj, other)
8790

91+
def _supports_accumulation(self, ser, op_name: str) -> bool:
92+
return op_name in ["cummin", "cummax"]
8893

89-
# ----------------------------------------------------------------------------
90-
# Tests
91-
class TestDatetimeDtype(BaseDatetimeTests, base.BaseDtypeTests):
92-
pass
94+
def _supports_reduction(self, obj, op_name: str) -> bool:
95+
return op_name in ["min", "max", "median", "mean", "std", "any", "all"]
9396

97+
@pytest.mark.parametrize("skipna", [True, False])
98+
def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna):
99+
meth = all_boolean_reductions
100+
msg = f"'{meth}' with datetime64 dtypes is deprecated and will raise in"
101+
with tm.assert_produces_warning(
102+
FutureWarning, match=msg, check_stacklevel=False
103+
):
104+
super().test_reduce_series_boolean(data, all_boolean_reductions, skipna)
94105

95-
class TestConstructors(BaseDatetimeTests, base.BaseConstructorsTests):
96106
def test_series_constructor(self, data):
97107
# Series construction drops any .freq attr
98108
data = data._with_freq(None)
99109
super().test_series_constructor(data)
100110

101-
102-
class TestGetitem(BaseDatetimeTests, base.BaseGetitemTests):
103-
pass
104-
105-
106-
class TestIndex(base.BaseIndexTests):
107-
pass
108-
109-
110-
class TestMethods(BaseDatetimeTests, base.BaseMethodsTests):
111111
@pytest.mark.parametrize("na_action", [None, "ignore"])
112112
def test_map(self, data, na_action):
113113
result = data.map(lambda x: x, na_action=na_action)
114114
tm.assert_extension_array_equal(result, data)
115115

116-
117-
class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
118-
pass
119-
120-
121-
class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
122-
implements = {"__sub__", "__rsub__"}
123-
124-
def _get_expected_exception(self, op_name, obj, other):
125-
if op_name in self.implements:
126-
return None
127-
return super()._get_expected_exception(op_name, obj, other)
128-
129-
130-
class TestCasting(BaseDatetimeTests, base.BaseCastingTests):
131-
pass
132-
133-
134-
class TestComparisonOps(BaseDatetimeTests, base.BaseComparisonOpsTests):
135-
pass
136-
137-
138-
class TestMissing(BaseDatetimeTests, base.BaseMissingTests):
139-
pass
140-
141-
142-
class TestReshaping(BaseDatetimeTests, base.BaseReshapingTests):
143-
pass
144-
145-
146-
class TestSetitem(BaseDatetimeTests, base.BaseSetitemTests):
147-
pass
148-
149-
150-
class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests):
151-
pass
152-
153-
154-
class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
155-
pass
156-
157-
158-
class Test2DCompat(BaseDatetimeTests, base.NDArrayBacked2DTests):
116+
@pytest.mark.parametrize("engine", ["c", "python"])
117+
def test_EA_types(self, engine, data):
118+
expected_msg = r".*must implement _from_sequence_of_strings.*"
119+
with pytest.raises(NotImplementedError, match=expected_msg):
120+
super().test_EA_types(engine, data)
121+
122+
def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
123+
if op_name in ["median", "mean", "std"]:
124+
alt = ser.astype("int64")
125+
126+
res_op = getattr(ser, op_name)
127+
exp_op = getattr(alt, op_name)
128+
result = res_op(skipna=skipna)
129+
expected = exp_op(skipna=skipna)
130+
if op_name in ["mean", "median"]:
131+
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype"
132+
# has no attribute "tz"
133+
tz = ser.dtype.tz # type: ignore[union-attr]
134+
expected = pd.Timestamp(expected, tz=tz)
135+
else:
136+
expected = pd.Timedelta(expected)
137+
tm.assert_almost_equal(result, expected)
138+
139+
else:
140+
return super().check_reduce(ser, op_name, skipna)
141+
142+
143+
class Test2DCompat(base.NDArrayBacked2DTests):
159144
pass

pandas/tests/extension/test_period.py

+35-59
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@
1313
be added to the array-specific tests in `pandas/tests/arrays/`.
1414
1515
"""
16+
from __future__ import annotations
17+
18+
from typing import TYPE_CHECKING
19+
1620
import numpy as np
1721
import pytest
1822

19-
from pandas._libs import iNaT
23+
from pandas._libs import (
24+
Period,
25+
iNaT,
26+
)
2027
from pandas.compat import is_platform_windows
2128
from pandas.compat.numpy import np_version_gte1p24
2229

@@ -26,6 +33,9 @@
2633
from pandas.core.arrays import PeriodArray
2734
from pandas.tests.extension import base
2835

36+
if TYPE_CHECKING:
37+
import pandas as pd
38+
2939

3040
@pytest.fixture(params=["D", "2D"])
3141
def dtype(request):
@@ -61,27 +71,36 @@ def data_for_grouping(dtype):
6171
return PeriodArray([B, B, NA, NA, A, A, B, C], dtype=dtype)
6272

6373

64-
class BasePeriodTests:
65-
pass
66-
67-
68-
class TestPeriodDtype(BasePeriodTests, base.BaseDtypeTests):
69-
pass
74+
class TestPeriodArray(base.ExtensionTests):
75+
def _get_expected_exception(self, op_name, obj, other):
76+
if op_name in ("__sub__", "__rsub__"):
77+
return None
78+
return super()._get_expected_exception(op_name, obj, other)
7079

80+
def _supports_accumulation(self, ser, op_name: str) -> bool:
81+
return op_name in ["cummin", "cummax"]
7182

72-
class TestConstructors(BasePeriodTests, base.BaseConstructorsTests):
73-
pass
83+
def _supports_reduction(self, obj, op_name: str) -> bool:
84+
return op_name in ["min", "max", "median"]
7485

86+
def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
87+
if op_name == "median":
88+
res_op = getattr(ser, op_name)
7589

76-
class TestGetitem(BasePeriodTests, base.BaseGetitemTests):
77-
pass
90+
alt = ser.astype("int64")
7891

92+
exp_op = getattr(alt, op_name)
93+
result = res_op(skipna=skipna)
94+
expected = exp_op(skipna=skipna)
95+
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
96+
# attribute "freq"
97+
freq = ser.dtype.freq # type: ignore[union-attr]
98+
expected = Period._from_ordinal(int(expected), freq=freq)
99+
tm.assert_almost_equal(result, expected)
79100

80-
class TestIndex(base.BaseIndexTests):
81-
pass
82-
101+
else:
102+
return super().check_reduce(ser, op_name, skipna)
83103

84-
class TestMethods(BasePeriodTests, base.BaseMethodsTests):
85104
@pytest.mark.parametrize("periods", [1, -2])
86105
def test_diff(self, data, periods):
87106
if is_platform_windows() and np_version_gte1p24:
@@ -96,48 +115,5 @@ def test_map(self, data, na_action):
96115
tm.assert_extension_array_equal(result, data)
97116

98117

99-
class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
100-
pass
101-
102-
103-
class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests):
104-
def _get_expected_exception(self, op_name, obj, other):
105-
if op_name in ("__sub__", "__rsub__"):
106-
return None
107-
return super()._get_expected_exception(op_name, obj, other)
108-
109-
110-
class TestCasting(BasePeriodTests, base.BaseCastingTests):
111-
pass
112-
113-
114-
class TestComparisonOps(BasePeriodTests, base.BaseComparisonOpsTests):
115-
pass
116-
117-
118-
class TestMissing(BasePeriodTests, base.BaseMissingTests):
119-
pass
120-
121-
122-
class TestReshaping(BasePeriodTests, base.BaseReshapingTests):
123-
pass
124-
125-
126-
class TestSetitem(BasePeriodTests, base.BaseSetitemTests):
127-
pass
128-
129-
130-
class TestGroupby(BasePeriodTests, base.BaseGroupbyTests):
131-
pass
132-
133-
134-
class TestPrinting(BasePeriodTests, base.BasePrintingTests):
135-
pass
136-
137-
138-
class TestParsing(BasePeriodTests, base.BaseParsingTests):
139-
pass
140-
141-
142-
class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests):
118+
class Test2DCompat(base.NDArrayBacked2DTests):
143119
pass

0 commit comments

Comments
 (0)