Skip to content

REF: use single-test-class for datetimetz, period tests #54676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pandas/tests/extension/base/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
return False

def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
alt = ser.astype("float64")
try:
alt = ser.astype("float64")
except TypeError:
# e.g. Period can't be cast to float64
alt = ser.astype(object)

result = getattr(ser, op_name)(skipna=skipna)

if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna:
Expand All @@ -37,5 +42,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
if self._supports_accumulation(ser, op_name):
self.check_accumulate(ser, op_name, skipna)
else:
with pytest.raises(NotImplementedError):
with pytest.raises((NotImplementedError, TypeError)):
# TODO: require TypeError for things that will _never_ work?
getattr(ser, op_name)(skipna=skipna)
105 changes: 45 additions & 60 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,78 +82,63 @@ def cmp(a, b):


# ----------------------------------------------------------------------------
class BaseDatetimeTests:
pass
class TestDatetimeArray(base.ExtensionTests):
def _get_expected_exception(self, op_name, obj, other):
if op_name in ["__sub__", "__rsub__"]:
return None
return super()._get_expected_exception(op_name, obj, other)

def _supports_accumulation(self, ser, op_name: str) -> bool:
return op_name in ["cummin", "cummax"]

# ----------------------------------------------------------------------------
# Tests
class TestDatetimeDtype(BaseDatetimeTests, base.BaseDtypeTests):
pass
def _supports_reduction(self, obj, op_name: str) -> bool:
return op_name in ["min", "max", "median", "mean", "std", "any", "all"]

@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna):
meth = all_boolean_reductions
msg = f"'{meth}' with datetime64 dtypes is deprecated and will raise in"
with tm.assert_produces_warning(
FutureWarning, match=msg, check_stacklevel=False
):
super().test_reduce_series_boolean(data, all_boolean_reductions, skipna)

class TestConstructors(BaseDatetimeTests, base.BaseConstructorsTests):
def test_series_constructor(self, data):
# Series construction drops any .freq attr
data = data._with_freq(None)
super().test_series_constructor(data)


class TestGetitem(BaseDatetimeTests, base.BaseGetitemTests):
pass


class TestIndex(base.BaseIndexTests):
pass


class TestMethods(BaseDatetimeTests, base.BaseMethodsTests):
@pytest.mark.parametrize("na_action", [None, "ignore"])
def test_map(self, data, na_action):
result = data.map(lambda x: x, na_action=na_action)
tm.assert_extension_array_equal(result, data)


class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
pass


class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def _get_expected_exception(self, op_name, obj, other):
if op_name in self.implements:
return None
return super()._get_expected_exception(op_name, obj, other)


class TestCasting(BaseDatetimeTests, base.BaseCastingTests):
pass


class TestComparisonOps(BaseDatetimeTests, base.BaseComparisonOpsTests):
pass


class TestMissing(BaseDatetimeTests, base.BaseMissingTests):
pass


class TestReshaping(BaseDatetimeTests, base.BaseReshapingTests):
pass


class TestSetitem(BaseDatetimeTests, base.BaseSetitemTests):
pass


class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests):
pass


class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
pass


class Test2DCompat(BaseDatetimeTests, base.NDArrayBacked2DTests):
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data):
expected_msg = r".*must implement _from_sequence_of_strings.*"
with pytest.raises(NotImplementedError, match=expected_msg):
super().test_EA_types(engine, data)

def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
if op_name in ["median", "mean", "std"]:
alt = ser.astype("int64")

res_op = getattr(ser, op_name)
exp_op = getattr(alt, op_name)
result = res_op(skipna=skipna)
expected = exp_op(skipna=skipna)
if op_name in ["mean", "median"]:
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype"
# has no attribute "tz"
tz = ser.dtype.tz # type: ignore[union-attr]
expected = pd.Timestamp(expected, tz=tz)
else:
expected = pd.Timedelta(expected)
tm.assert_almost_equal(result, expected)

else:
return super().check_reduce(ser, op_name, skipna)


class Test2DCompat(base.NDArrayBacked2DTests):
pass
94 changes: 35 additions & 59 deletions pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
be added to the array-specific tests in `pandas/tests/arrays/`.

"""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

from pandas._libs import iNaT
from pandas._libs import (
Period,
iNaT,
)
from pandas.compat import is_platform_windows
from pandas.compat.numpy import np_version_gte1p24

Expand All @@ -26,6 +33,9 @@
from pandas.core.arrays import PeriodArray
from pandas.tests.extension import base

if TYPE_CHECKING:
import pandas as pd


@pytest.fixture(params=["D", "2D"])
def dtype(request):
Expand Down Expand Up @@ -61,27 +71,36 @@ def data_for_grouping(dtype):
return PeriodArray([B, B, NA, NA, A, A, B, C], dtype=dtype)


class BasePeriodTests:
pass


class TestPeriodDtype(BasePeriodTests, base.BaseDtypeTests):
pass
class TestPeriodArray(base.ExtensionTests):
def _get_expected_exception(self, op_name, obj, other):
if op_name in ("__sub__", "__rsub__"):
return None
return super()._get_expected_exception(op_name, obj, other)

def _supports_accumulation(self, ser, op_name: str) -> bool:
return op_name in ["cummin", "cummax"]

class TestConstructors(BasePeriodTests, base.BaseConstructorsTests):
pass
def _supports_reduction(self, obj, op_name: str) -> bool:
return op_name in ["min", "max", "median"]

def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
if op_name == "median":
res_op = getattr(ser, op_name)

class TestGetitem(BasePeriodTests, base.BaseGetitemTests):
pass
alt = ser.astype("int64")

exp_op = getattr(alt, op_name)
result = res_op(skipna=skipna)
expected = exp_op(skipna=skipna)
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
# attribute "freq"
freq = ser.dtype.freq # type: ignore[union-attr]
expected = Period._from_ordinal(int(expected), freq=freq)
tm.assert_almost_equal(result, expected)

class TestIndex(base.BaseIndexTests):
pass

else:
return super().check_reduce(ser, op_name, skipna)

class TestMethods(BasePeriodTests, base.BaseMethodsTests):
@pytest.mark.parametrize("periods", [1, -2])
def test_diff(self, data, periods):
if is_platform_windows() and np_version_gte1p24:
Expand All @@ -96,48 +115,5 @@ def test_map(self, data, na_action):
tm.assert_extension_array_equal(result, data)


class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
pass


class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests):
def _get_expected_exception(self, op_name, obj, other):
if op_name in ("__sub__", "__rsub__"):
return None
return super()._get_expected_exception(op_name, obj, other)


class TestCasting(BasePeriodTests, base.BaseCastingTests):
pass


class TestComparisonOps(BasePeriodTests, base.BaseComparisonOpsTests):
pass


class TestMissing(BasePeriodTests, base.BaseMissingTests):
pass


class TestReshaping(BasePeriodTests, base.BaseReshapingTests):
pass


class TestSetitem(BasePeriodTests, base.BaseSetitemTests):
pass


class TestGroupby(BasePeriodTests, base.BaseGroupbyTests):
pass


class TestPrinting(BasePeriodTests, base.BasePrintingTests):
pass


class TestParsing(BasePeriodTests, base.BaseParsingTests):
pass


class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests):
class Test2DCompat(base.NDArrayBacked2DTests):
pass