diff --git a/pandas/tests/extension/base/accumulate.py b/pandas/tests/extension/base/accumulate.py index 4648f66112e80..776ff80cd6e17 100644 --- a/pandas/tests/extension/base/accumulate.py +++ b/pandas/tests/extension/base/accumulate.py @@ -16,7 +16,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: return False def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool): - alt = ser.astype("float64") + try: + alt = ser.astype("float64") + except TypeError: + # e.g. Period can't be cast to float64 + alt = ser.astype(object) + result = getattr(ser, op_name)(skipna=skipna) if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna: @@ -37,5 +42,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna): if self._supports_accumulation(ser, op_name): self.check_accumulate(ser, op_name, skipna) else: - with pytest.raises(NotImplementedError): + with pytest.raises((NotImplementedError, TypeError)): + # TODO: require TypeError for things that will _never_ work? getattr(ser, op_name)(skipna=skipna) diff --git a/pandas/tests/extension/test_datetime.py b/pandas/tests/extension/test_datetime.py index 97773d0d40a57..5a7b15ddb01ce 100644 --- a/pandas/tests/extension/test_datetime.py +++ b/pandas/tests/extension/test_datetime.py @@ -82,78 +82,63 @@ def cmp(a, b): # ---------------------------------------------------------------------------- -class BaseDatetimeTests: - pass +class TestDatetimeArray(base.ExtensionTests): + def _get_expected_exception(self, op_name, obj, other): + if op_name in ["__sub__", "__rsub__"]: + return None + return super()._get_expected_exception(op_name, obj, other) + def _supports_accumulation(self, ser, op_name: str) -> bool: + return op_name in ["cummin", "cummax"] -# ---------------------------------------------------------------------------- -# Tests -class TestDatetimeDtype(BaseDatetimeTests, base.BaseDtypeTests): - pass + def _supports_reduction(self, obj, op_name: str) -> bool: + return op_name in ["min", "max", "median", "mean", "std", "any", "all"] + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna): + meth = all_boolean_reductions + msg = f"'{meth}' with datetime64 dtypes is deprecated and will raise in" + with tm.assert_produces_warning( + FutureWarning, match=msg, check_stacklevel=False + ): + super().test_reduce_series_boolean(data, all_boolean_reductions, skipna) -class TestConstructors(BaseDatetimeTests, base.BaseConstructorsTests): def test_series_constructor(self, data): # Series construction drops any .freq attr data = data._with_freq(None) super().test_series_constructor(data) - -class TestGetitem(BaseDatetimeTests, base.BaseGetitemTests): - pass - - -class TestIndex(base.BaseIndexTests): - pass - - -class TestMethods(BaseDatetimeTests, base.BaseMethodsTests): @pytest.mark.parametrize("na_action", [None, "ignore"]) def test_map(self, data, na_action): result = data.map(lambda x: x, na_action=na_action) tm.assert_extension_array_equal(result, data) - -class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests): - pass - - -class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests): - implements = {"__sub__", "__rsub__"} - - def _get_expected_exception(self, op_name, obj, other): - if op_name in self.implements: - return None - return super()._get_expected_exception(op_name, obj, other) - - -class TestCasting(BaseDatetimeTests, base.BaseCastingTests): - pass - - -class TestComparisonOps(BaseDatetimeTests, base.BaseComparisonOpsTests): - pass - - -class TestMissing(BaseDatetimeTests, base.BaseMissingTests): - pass - - -class TestReshaping(BaseDatetimeTests, base.BaseReshapingTests): - pass - - -class TestSetitem(BaseDatetimeTests, base.BaseSetitemTests): - pass - - -class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests): - pass - - -class TestPrinting(BaseDatetimeTests, base.BasePrintingTests): - pass - - -class Test2DCompat(BaseDatetimeTests, base.NDArrayBacked2DTests): + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data): + expected_msg = r".*must implement _from_sequence_of_strings.*" + with pytest.raises(NotImplementedError, match=expected_msg): + super().test_EA_types(engine, data) + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + if op_name in ["median", "mean", "std"]: + alt = ser.astype("int64") + + res_op = getattr(ser, op_name) + exp_op = getattr(alt, op_name) + result = res_op(skipna=skipna) + expected = exp_op(skipna=skipna) + if op_name in ["mean", "median"]: + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" + # has no attribute "tz" + tz = ser.dtype.tz # type: ignore[union-attr] + expected = pd.Timestamp(expected, tz=tz) + else: + expected = pd.Timedelta(expected) + tm.assert_almost_equal(result, expected) + + else: + return super().check_reduce(ser, op_name, skipna) + + +class Test2DCompat(base.NDArrayBacked2DTests): pass diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py index 63297c20daa97..2d1d213322bac 100644 --- a/pandas/tests/extension/test_period.py +++ b/pandas/tests/extension/test_period.py @@ -13,10 +13,17 @@ be added to the array-specific tests in `pandas/tests/arrays/`. """ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np import pytest -from pandas._libs import iNaT +from pandas._libs import ( + Period, + iNaT, +) from pandas.compat import is_platform_windows from pandas.compat.numpy import np_version_gte1p24 @@ -26,6 +33,9 @@ from pandas.core.arrays import PeriodArray from pandas.tests.extension import base +if TYPE_CHECKING: + import pandas as pd + @pytest.fixture(params=["D", "2D"]) def dtype(request): @@ -61,27 +71,36 @@ def data_for_grouping(dtype): return PeriodArray([B, B, NA, NA, A, A, B, C], dtype=dtype) -class BasePeriodTests: - pass - - -class TestPeriodDtype(BasePeriodTests, base.BaseDtypeTests): - pass +class TestPeriodArray(base.ExtensionTests): + def _get_expected_exception(self, op_name, obj, other): + if op_name in ("__sub__", "__rsub__"): + return None + return super()._get_expected_exception(op_name, obj, other) + def _supports_accumulation(self, ser, op_name: str) -> bool: + return op_name in ["cummin", "cummax"] -class TestConstructors(BasePeriodTests, base.BaseConstructorsTests): - pass + def _supports_reduction(self, obj, op_name: str) -> bool: + return op_name in ["min", "max", "median"] + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + if op_name == "median": + res_op = getattr(ser, op_name) -class TestGetitem(BasePeriodTests, base.BaseGetitemTests): - pass + alt = ser.astype("int64") + exp_op = getattr(alt, op_name) + result = res_op(skipna=skipna) + expected = exp_op(skipna=skipna) + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "freq" + freq = ser.dtype.freq # type: ignore[union-attr] + expected = Period._from_ordinal(int(expected), freq=freq) + tm.assert_almost_equal(result, expected) -class TestIndex(base.BaseIndexTests): - pass - + else: + return super().check_reduce(ser, op_name, skipna) -class TestMethods(BasePeriodTests, base.BaseMethodsTests): @pytest.mark.parametrize("periods", [1, -2]) def test_diff(self, data, periods): if is_platform_windows() and np_version_gte1p24: @@ -96,48 +115,5 @@ def test_map(self, data, na_action): tm.assert_extension_array_equal(result, data) -class TestInterface(BasePeriodTests, base.BaseInterfaceTests): - pass - - -class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests): - def _get_expected_exception(self, op_name, obj, other): - if op_name in ("__sub__", "__rsub__"): - return None - return super()._get_expected_exception(op_name, obj, other) - - -class TestCasting(BasePeriodTests, base.BaseCastingTests): - pass - - -class TestComparisonOps(BasePeriodTests, base.BaseComparisonOpsTests): - pass - - -class TestMissing(BasePeriodTests, base.BaseMissingTests): - pass - - -class TestReshaping(BasePeriodTests, base.BaseReshapingTests): - pass - - -class TestSetitem(BasePeriodTests, base.BaseSetitemTests): - pass - - -class TestGroupby(BasePeriodTests, base.BaseGroupbyTests): - pass - - -class TestPrinting(BasePeriodTests, base.BasePrintingTests): - pass - - -class TestParsing(BasePeriodTests, base.BaseParsingTests): - pass - - -class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests): +class Test2DCompat(base.NDArrayBacked2DTests): pass