Skip to content

TST: use one-class pattern in test_numpy #56512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas/tests/extension/base/interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
from pandas.core.dtypes.common import is_extension_array_dtype
from pandas.core.dtypes.dtypes import ExtensionDtype

Expand Down Expand Up @@ -65,6 +66,9 @@ def test_array_interface(self, data):

result = np.array(data, dtype=object)
expected = np.array(list(data), dtype=object)
if expected.ndim > 1:
# nested data, explicitly construct as 1D
expected = construct_1d_object_array_from_listlike(list(data))
tm.assert_numpy_array_equal(result, expected)

def test_is_extension_array_dtype(self, data):
Expand Down
125 changes: 60 additions & 65 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,14 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.dtypes import NumpyEADtype

import pandas as pd
import pandas._testing as tm
from pandas.api.types import is_object_dtype
from pandas.core.arrays.numpy_ import NumpyExtensionArray
from pandas.core.internals import blocks
from pandas.tests.extension import base


def _can_hold_element_patched(obj, element) -> bool:
if isinstance(element, NumpyExtensionArray):
element = element.to_numpy()
return can_hold_element(obj, element)


orig_assert_attr_equal = tm.assert_attr_equal


Expand Down Expand Up @@ -78,7 +69,6 @@ def allow_in_pandas(monkeypatch):
"""
with monkeypatch.context() as m:
m.setattr(NumpyExtensionArray, "_typ", "extension")
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
m.setattr(tm.asserters, "assert_attr_equal", _assert_attr_equal)
yield

Expand Down Expand Up @@ -175,15 +165,7 @@ def skip_numpy_object(dtype, request):
skip_nested = pytest.mark.usefixtures("skip_numpy_object")


class BaseNumPyTests:
pass


class TestCasting(BaseNumPyTests, base.BaseCastingTests):
pass


class TestConstructors(BaseNumPyTests, base.BaseConstructorsTests):
class TestNumpyExtensionArray(base.ExtensionTests):
@pytest.mark.skip(reason="We don't register our dtype")
# We don't want to register. This test should probably be split in two.
def test_from_dtype(self, data):
Expand All @@ -194,8 +176,6 @@ def test_series_constructor_scalar_with_index(self, data, dtype):
# ValueError: Length of passed values is 1, index implies 3.
super().test_series_constructor_scalar_with_index(data, dtype)


class TestDtype(BaseNumPyTests, base.BaseDtypeTests):
def test_check_dtype(self, data, request, using_infer_string):
if data.dtype.numpy_dtype == "object":
request.applymarker(
Expand All @@ -214,26 +194,11 @@ def test_is_not_object_type(self, dtype, request):
else:
super().test_is_not_object_type(dtype)


class TestGetitem(BaseNumPyTests, base.BaseGetitemTests):
@skip_nested
def test_getitem_scalar(self, data):
# AssertionError
super().test_getitem_scalar(data)


class TestGroupby(BaseNumPyTests, base.BaseGroupbyTests):
pass


class TestInterface(BaseNumPyTests, base.BaseInterfaceTests):
@skip_nested
def test_array_interface(self, data):
# NumPy array shape inference
super().test_array_interface(data)


class TestMethods(BaseNumPyTests, base.BaseMethodsTests):
@skip_nested
def test_shift_fill_value(self, data):
# np.array shape inference. Shift implementation fails.
Expand All @@ -251,7 +216,9 @@ def test_fillna_copy_series(self, data_missing):

@skip_nested
def test_searchsorted(self, data_for_sorting, as_series):
# Test setup fails.
# TODO: NumpyExtensionArray.searchsorted calls ndarray.searchsorted which
# isn't quite what we want in nested data cases. Instead we need to
# adapt something like libindex._bin_search.
super().test_searchsorted(data_for_sorting, as_series)

@pytest.mark.xfail(reason="NumpyExtensionArray.diff may fail on dtype")
Expand All @@ -270,38 +237,60 @@ def test_insert_invalid(self, data, invalid_scalar):
# NumpyExtensionArray[object] can hold anything, so skip
super().test_insert_invalid(data, invalid_scalar)


class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):
divmod_exc = None
series_scalar_exc = None
frame_scalar_exc = None
series_array_exc = None

@skip_nested
def test_divmod(self, data):
divmod_exc = None
if data.dtype.kind == "O":
divmod_exc = TypeError
self.divmod_exc = divmod_exc
super().test_divmod(data)

@skip_nested
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
def test_divmod_series_array(self, data):
ser = pd.Series(data)
exc = None
if data.dtype.kind == "O":
exc = TypeError
self.divmod_exc = exc
self._check_divmod_op(ser, divmod, data)

def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
opname = all_arithmetic_operators
series_scalar_exc = None
if data.dtype.numpy_dtype == object:
if opname in ["__mul__", "__rmul__"]:
mark = pytest.mark.xfail(
reason="the Series.combine step raises but not the Series method."
)
request.node.add_marker(mark)
series_scalar_exc = TypeError
self.series_scalar_exc = series_scalar_exc
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
def test_arith_series_with_array(self, data, all_arithmetic_operators):
opname = all_arithmetic_operators
series_array_exc = None
if data.dtype.numpy_dtype == object and opname not in ["__add__", "__radd__"]:
mark = pytest.mark.xfail(reason="Fails for object dtype")
request.applymarker(mark)
series_array_exc = TypeError
self.series_array_exc = series_array_exc
super().test_arith_series_with_array(data, all_arithmetic_operators)

@skip_nested
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
opname = all_arithmetic_operators
frame_scalar_exc = None
if data.dtype.numpy_dtype == object:
if opname in ["__mul__", "__rmul__"]:
mark = pytest.mark.xfail(
reason="the Series.combine step raises but not the Series method."
)
request.node.add_marker(mark)
frame_scalar_exc = TypeError
self.frame_scalar_exc = frame_scalar_exc
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)


class TestPrinting(BaseNumPyTests, base.BasePrintingTests):
pass


class TestReduce(BaseNumPyTests, base.BaseReduceTests):
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
if ser.dtype.kind == "O":
return op_name in ["sum", "min", "max", "any", "all"]
Expand All @@ -328,8 +317,6 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
pass


class TestMissing(BaseNumPyTests, base.BaseMissingTests):
@skip_nested
def test_fillna_series(self, data_missing):
# Non-scalar "scalar" values.
Expand All @@ -340,12 +327,6 @@ def test_fillna_frame(self, data_missing):
# Non-scalar "scalar" values.
super().test_fillna_frame(data_missing)


class TestReshaping(BaseNumPyTests, base.BaseReshapingTests):
pass


class TestSetitem(BaseNumPyTests, base.BaseSetitemTests):
@skip_nested
def test_setitem_invalid(self, data, invalid_scalar):
# object dtype can hold anything, so doesn't raise
Expand Down Expand Up @@ -431,11 +412,25 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):
expected = pd.DataFrame({"data": data.to_numpy()})
tm.assert_frame_equal(result, expected, check_column_type=False)

@pytest.mark.xfail(reason="NumpyEADtype is unpacked")
def test_index_from_listlike_with_dtype(self, data):
super().test_index_from_listlike_with_dtype(data)

@skip_nested
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
pass
@skip_nested
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data, request):
super().test_EA_types(engine, data, request)

@pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
def test_compare_array(self, data, comparison_op):
super().test_compare_array(data, comparison_op)

def test_compare_scalar(self, data, comparison_op, request):
if data.dtype.kind == "f" or comparison_op.__name__ in ["eq", "ne"]:
mark = pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
request.applymarker(mark)
super().test_compare_scalar(data, comparison_op)


class Test2DCompat(BaseNumPyTests, base.NDArrayBacked2DTests):
class Test2DCompat(base.NDArrayBacked2DTests):
pass