Skip to content

REF: de-duplicate some test code #52228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions pandas/tests/extension/masked_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Shared test code for IntegerArray/FloatingArray/BooleanArray.
"""
import pytest

from pandas.compat import (
IS64,
is_platform_windows,
)

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension import base


class Arithmetic(base.BaseArithmeticOpsTests):
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(ser, op_name, other, exc=None)

def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
super()._check_divmod_op(ser, op, other, None)


class Comparison(base.BaseComparisonOpsTests):
def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
if exc is None:
result = op(ser, other)
# Override to do the astype to boolean
expected = ser.combine(other, op).astype("boolean")
self.assert_series_equal(result, expected)
else:
with pytest.raises(exc):
op(ser, other)

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
super().check_opname(ser, op_name, other, exc=None)

def _compare_other(self, ser: pd.Series, data, op, other):
op_name = f"__{op.__name__}__"
self.check_opname(ser, op_name, other)


class NumericReduce(base.BaseNumericReduceTests):
def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
# overwrite to ensure pd.NA is tested instead of np.nan
# https://github.com/pandas-dev/pandas/issues/30958

cmp_dtype = "int64"
if ser.dtype.kind == "f":
# Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" has
# no attribute "numpy_dtype"
cmp_dtype = ser.dtype.numpy_dtype # type: ignore[union-attr]

if op_name == "count":
result = getattr(ser, op_name)()
expected = getattr(ser.dropna().astype(cmp_dtype), op_name)()
else:
result = getattr(ser, op_name)(skipna=skipna)
expected = getattr(ser.dropna().astype(cmp_dtype), op_name)(skipna=skipna)
if not skipna and ser.isna().any():
expected = pd.NA
tm.assert_almost_equal(result, expected)


class Accumulation(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
pass

def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
# overwrite to ensure pd.NA is tested instead of np.nan
# https://github.com/pandas-dev/pandas/issues/30958
length = 64
if not IS64 or is_platform_windows():
# Item "ExtensionDtype" of "Union[dtype[Any], ExtensionDtype]" has
# no attribute "itemsize"
if not ser.dtype.itemsize == 8: # type: ignore[union-attr]
length = 32

if ser.dtype.name.startswith("U"):
expected_dtype = f"UInt{length}"
elif ser.dtype.name.startswith("I"):
expected_dtype = f"Int{length}"
elif ser.dtype.name.startswith("F"):
# Incompatible types in assignment (expression has type
# "Union[dtype[Any], ExtensionDtype]", variable has type "str")
expected_dtype = ser.dtype # type: ignore[assignment]

if op_name == "cumsum":
result = getattr(ser, op_name)(skipna=skipna)
expected = pd.Series(
pd.array(
getattr(ser.astype("float64"), op_name)(skipna=skipna),
dtype=expected_dtype,
)
)
tm.assert_series_equal(result, expected)
elif op_name in ["cummax", "cummin"]:
result = getattr(ser, op_name)(skipna=skipna)
expected = pd.Series(
pd.array(
getattr(ser.astype("float64"), op_name)(skipna=skipna),
dtype=ser.dtype,
)
)
tm.assert_series_equal(result, expected)
elif op_name == "cumprod":
result = getattr(ser[:12], op_name)(skipna=skipna)
expected = pd.Series(
pd.array(
getattr(ser[:12].astype("float64"), op_name)(skipna=skipna),
dtype=expected_dtype,
)
)
tm.assert_series_equal(result, expected)

else:
raise NotImplementedError(f"{op_name} not supported")
57 changes: 11 additions & 46 deletions pandas/tests/extension/test_floating.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
Float32Dtype,
Float64Dtype,
)
from pandas.tests.extension import base
from pandas.tests.extension import (
base,
masked_shared,
)


def make_data():
Expand Down Expand Up @@ -92,11 +95,7 @@ class TestDtype(base.BaseDtypeTests):
pass


class TestArithmeticOps(base.BaseArithmeticOpsTests):
def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(s, op_name, other, exc=None)

class TestArithmeticOps(masked_shared.Arithmetic):
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
sdtype = tm.get_dtype(s)
Expand All @@ -120,28 +119,9 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
with pytest.raises(exc):
op(s, other)

def _check_divmod_op(self, s, op, other, exc=None):
super()._check_divmod_op(s, op, other, None)


class TestComparisonOps(base.BaseComparisonOpsTests):
# TODO: share with IntegerArray?
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
result = op(s, other)
# Override to do the astype to boolean
expected = s.combine(other, op).astype("boolean")
self.assert_series_equal(result, expected)
else:
with pytest.raises(exc):
op(s, other)

def check_opname(self, s, op_name, other, exc=None):
super().check_opname(s, op_name, other, exc=None)

def _compare_other(self, s, data, op, other):
op_name = f"__{op.__name__}__"
self.check_opname(s, op_name, other)
class TestComparisonOps(masked_shared.Comparison):
pass


class TestInterface(base.BaseInterfaceTests):
Expand Down Expand Up @@ -184,21 +164,8 @@ class TestGroupby(base.BaseGroupbyTests):
pass


class TestNumericReduce(base.BaseNumericReduceTests):
def check_reduce(self, s, op_name, skipna):
# overwrite to ensure pd.NA is tested instead of np.nan
# https://github.com/pandas-dev/pandas/issues/30958
if op_name == "count":
result = getattr(s, op_name)()
expected = getattr(s.dropna().astype(s.dtype.numpy_dtype), op_name)()
else:
result = getattr(s, op_name)(skipna=skipna)
expected = getattr(s.dropna().astype(s.dtype.numpy_dtype), op_name)(
skipna=skipna
)
if not skipna and s.isna().any():
expected = pd.NA
tm.assert_almost_equal(result, expected)
class TestNumericReduce(masked_shared.NumericReduce):
pass


@pytest.mark.skip(reason="Tested in tests/reductions/test_reductions.py")
Expand All @@ -219,7 +186,5 @@ class Test2DCompat(base.Dim2CompatTests):
pass


class TestAccumulation(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
pass
class TestAccumulation(masked_shared.Accumulation):
pass
103 changes: 11 additions & 92 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
import numpy as np
import pytest

from pandas.compat import (
IS64,
is_platform_windows,
)

import pandas as pd
import pandas._testing as tm
from pandas.api.types import (
Expand All @@ -37,7 +32,10 @@
UInt32Dtype,
UInt64Dtype,
)
from pandas.tests.extension import base
from pandas.tests.extension import (
base,
masked_shared,
)


def make_data():
Expand Down Expand Up @@ -109,11 +107,7 @@ class TestDtype(base.BaseDtypeTests):
pass


class TestArithmeticOps(base.BaseArithmeticOpsTests):
def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(s, op_name, other, exc=None)

class TestArithmeticOps(masked_shared.Arithmetic):
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
sdtype = tm.get_dtype(s)
Expand Down Expand Up @@ -145,27 +139,9 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
with pytest.raises(exc):
op(s, other)

def _check_divmod_op(self, s, op, other, exc=None):
super()._check_divmod_op(s, op, other, None)


class TestComparisonOps(base.BaseComparisonOpsTests):
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
result = op(s, other)
# Override to do the astype to boolean
expected = s.combine(other, op).astype("boolean")
self.assert_series_equal(result, expected)
else:
with pytest.raises(exc):
op(s, other)

def check_opname(self, s, op_name, other, exc=None):
super().check_opname(s, op_name, other, exc=None)

def _compare_other(self, s, data, op, other):
op_name = f"__{op.__name__}__"
self.check_opname(s, op_name, other)
class TestComparisonOps(masked_shared.Comparison):
pass


class TestInterface(base.BaseInterfaceTests):
Expand Down Expand Up @@ -212,74 +188,17 @@ class TestGroupby(base.BaseGroupbyTests):
pass


class TestNumericReduce(base.BaseNumericReduceTests):
def check_reduce(self, s, op_name, skipna):
# overwrite to ensure pd.NA is tested instead of np.nan
# https://github.com/pandas-dev/pandas/issues/30958
if op_name == "count":
result = getattr(s, op_name)()
expected = getattr(s.dropna().astype("int64"), op_name)()
else:
result = getattr(s, op_name)(skipna=skipna)
expected = getattr(s.dropna().astype("int64"), op_name)(skipna=skipna)
if not skipna and s.isna().any():
expected = pd.NA
tm.assert_almost_equal(result, expected)
class TestNumericReduce(masked_shared.NumericReduce):
pass


@pytest.mark.skip(reason="Tested in tests/reductions/test_reductions.py")
class TestBooleanReduce(base.BaseBooleanReduceTests):
pass


class TestAccumulation(base.BaseAccumulateTests):
def check_accumulate(self, s, op_name, skipna):
# overwrite to ensure pd.NA is tested instead of np.nan
# https://github.com/pandas-dev/pandas/issues/30958
length = 64
if not IS64 or is_platform_windows():
if not s.dtype.itemsize == 8:
length = 32

if s.dtype.name.startswith("U"):
expected_dtype = f"UInt{length}"
else:
expected_dtype = f"Int{length}"

if op_name == "cumsum":
result = getattr(s, op_name)(skipna=skipna)
expected = pd.Series(
pd.array(
getattr(s.astype("float64"), op_name)(skipna=skipna),
dtype=expected_dtype,
)
)
tm.assert_series_equal(result, expected)
elif op_name in ["cummax", "cummin"]:
result = getattr(s, op_name)(skipna=skipna)
expected = pd.Series(
pd.array(
getattr(s.astype("float64"), op_name)(skipna=skipna),
dtype=s.dtype,
)
)
tm.assert_series_equal(result, expected)
elif op_name == "cumprod":
result = getattr(s[:12], op_name)(skipna=skipna)
expected = pd.Series(
pd.array(
getattr(s[:12].astype("float64"), op_name)(skipna=skipna),
dtype=expected_dtype,
)
)
tm.assert_series_equal(result, expected)

else:
raise NotImplementedError(f"{op_name} not supported")

@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
pass
class TestAccumulation(masked_shared.Accumulation):
pass


class TestPrinting(base.BasePrintingTests):
Expand Down
Loading