Skip to content

TST: use single-class pattern for Arrow, Masked tests #54573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 4 additions & 61 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def data_for_twos(data):
# TODO: skip otherwise?


class TestBaseCasting(base.BaseCastingTests):
class TestArrowArray(base.ExtensionTests):
def test_astype_str(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_binary(pa_dtype):
Expand All @@ -276,8 +276,6 @@ def test_astype_str(self, data, request):
)
super().test_astype_str(data)


class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
Expand Down Expand Up @@ -338,12 +336,6 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
tm.assert_extension_array_equal(result, data)


class TestGetitemTests(base.BaseGetitemTests):
pass


class TestBaseAccumulateTests(base.BaseAccumulateTests):
def check_accumulate(self, ser, op_name, skipna):
result = getattr(ser, op_name)(skipna=skipna)

Expand Down Expand Up @@ -409,8 +401,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques

self.check_accumulate(ser, op_name, skipna)


class TestReduce(base.BaseReduceTests):
def _supports_reduction(self, obj, op_name: str) -> bool:
dtype = tm.get_dtype(obj)
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has
Expand Down Expand Up @@ -561,8 +551,6 @@ def test_median_not_approximate(self, typ):
result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
assert result == 1.5


class TestBaseGroupby(base.BaseGroupbyTests):
def test_in_numeric_groupby(self, data_for_grouping):
dtype = data_for_grouping.dtype
if is_string_dtype(dtype):
Expand All @@ -583,8 +571,6 @@ def test_in_numeric_groupby(self, data_for_grouping):
else:
super().test_in_numeric_groupby(data_for_grouping)


class TestBaseDtype(base.BaseDtypeTests):
def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype):
Expand Down Expand Up @@ -651,20 +637,12 @@ def test_is_not_string_type(self, dtype):
else:
super().test_is_not_string_type(dtype)


class TestBaseIndex(base.BaseIndexTests):
pass


class TestBaseInterface(base.BaseInterfaceTests):
@pytest.mark.xfail(
reason="GH 45419: pyarrow.ChunkedArray does not support views.", run=False
)
def test_view(self, data):
super().test_view(data)


class TestBaseMissing(base.BaseMissingTests):
def test_fillna_no_op_returns_copy(self, data):
data = data[~data.isna()]

Expand All @@ -677,28 +655,18 @@ def test_fillna_no_op_returns_copy(self, data):
assert result is not data
tm.assert_extension_array_equal(result, data)


class TestBasePrinting(base.BasePrintingTests):
pass


class TestBaseReshaping(base.BaseReshapingTests):
@pytest.mark.xfail(
reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False
)
def test_transpose(self, data):
super().test_transpose(data)


class TestBaseSetitem(base.BaseSetitemTests):
@pytest.mark.xfail(
reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False
)
def test_setitem_preserves_views(self, data):
super().test_setitem_preserves_views(data)


class TestBaseParsing(base.BaseParsingTests):
@pytest.mark.parametrize("dtype_backend", ["pyarrow", no_default])
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data, dtype_backend, request):
Expand Down Expand Up @@ -736,8 +704,6 @@ def test_EA_types(self, engine, data, dtype_backend, request):
expected = df
tm.assert_frame_equal(result, expected)


class TestBaseUnaryOps(base.BaseUnaryOpsTests):
def test_invert(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if not pa.types.is_boolean(pa_dtype):
Expand All @@ -749,8 +715,6 @@ def test_invert(self, data, request):
)
super().test_invert(data)


class TestBaseMethods(base.BaseMethodsTests):
@pytest.mark.parametrize("periods", [1, -2])
def test_diff(self, data, periods, request):
pa_dtype = data.dtype.pyarrow_dtype
Expand Down Expand Up @@ -814,8 +778,6 @@ def test_argreduce_series(

_combine_le_expected_dtype = "bool[pyarrow]"


class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
divmod_exc = NotImplementedError

def get_op_from_name(self, op_name):
Expand All @@ -838,6 +800,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
# while ArrowExtensionArray maintains original type
expected = pointwise_result

if op_name in ["eq", "ne", "lt", "le", "gt", "ge"]:
return pointwise_result.astype("boolean[pyarrow]")

was_frame = False
if isinstance(expected, pd.DataFrame):
was_frame = True
Expand Down Expand Up @@ -1121,28 +1086,6 @@ def test_add_series_with_extension_array(self, data, request):
)
super().test_add_series_with_extension_array(data)


class TestBaseComparisonOps(base.BaseComparisonOpsTests):
def test_compare_array(self, data, comparison_op, na_value):
ser = pd.Series(data)
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
# since ser.iloc[0] is a python scalar
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
if comparison_op.__name__ in ["eq", "ne"]:
# comparison should match point-wise comparisons
result = comparison_op(ser, other)
# Series.combine does not calculate the NA mask correctly
# when comparing over an array
assert result[8] is na_value
assert result[97] is na_value
expected = ser.combine(other, comparison_op)
expected[8] = na_value
expected[97] = na_value
tm.assert_series_equal(result, expected)

else:
return super().test_compare_array(data, comparison_op)

def test_invalid_other_comp(self, data, comparison_op):
# GH 48833
with pytest.raises(
Expand Down
81 changes: 5 additions & 76 deletions pandas/tests/extension/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ def data_for_grouping(dtype):
return pd.array([b, b, na, na, a, a, b, c], dtype=dtype)


class TestDtype(base.BaseDtypeTests):
pass


class TestArithmeticOps(base.BaseArithmeticOpsTests):
class TestMaskedArrays(base.ExtensionTests):
def _get_expected_exception(self, op_name, obj, other):
try:
dtype = tm.get_dtype(obj)
Expand All @@ -179,12 +175,15 @@ def _get_expected_exception(self, op_name, obj, other):
# exception message would include "numpy boolean subtract"
return TypeError
return None
return super()._get_expected_exception(op_name, obj, other)
return None

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
sdtype = tm.get_dtype(obj)
expected = pointwise_result

if op_name in ("eq", "ne", "le", "ge", "lt", "gt"):
return expected.astype("boolean")

if sdtype.kind in "iu":
if op_name in ("__rtruediv__", "__truediv__", "__div__"):
expected = expected.fillna(np.nan).astype("Float64")
Expand Down Expand Up @@ -219,11 +218,6 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
expected = expected.astype(sdtype)
return expected

series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None
divmod_exc = None

def test_divmod_series_array(self, data, data_for_twos, request):
if data.dtype.kind == "b":
mark = pytest.mark.xfail(
Expand All @@ -234,49 +228,6 @@ def test_divmod_series_array(self, data, data_for_twos, request):
request.node.add_marker(mark)
super().test_divmod_series_array(data, data_for_twos)


class TestComparisonOps(base.BaseComparisonOpsTests):
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
return pointwise_result.astype("boolean")


class TestInterface(base.BaseInterfaceTests):
pass


class TestConstructors(base.BaseConstructorsTests):
pass


class TestReshaping(base.BaseReshapingTests):
pass

# for test_concat_mixed_dtypes test
# concat of an Integer and Int coerces to object dtype
# TODO(jreback) once integrated this would


class TestGetitem(base.BaseGetitemTests):
pass


class TestSetitem(base.BaseSetitemTests):
pass


class TestIndex(base.BaseIndexTests):
pass


class TestMissing(base.BaseMissingTests):
pass


class TestMethods(base.BaseMethodsTests):
def test_combine_le(self, data_repeated):
# TODO: patching self is a bad pattern here
orig_data1, orig_data2 = data_repeated(2)
Expand All @@ -287,16 +238,6 @@ def test_combine_le(self, data_repeated):
self._combine_le_expected_dtype = object
super().test_combine_le(data_repeated)


class TestCasting(base.BaseCastingTests):
pass


class TestGroupby(base.BaseGroupbyTests):
pass


class TestReduce(base.BaseReduceTests):
def _supports_reduction(self, obj, op_name: str) -> bool:
if op_name in ["any", "all"] and tm.get_dtype(obj).kind != "b":
pytest.skip(reason="Tested in tests/reductions/test_reductions.py")
Expand Down Expand Up @@ -351,8 +292,6 @@ def _get_expected_reduction_dtype(self, arr, op_name: str):
raise TypeError("not supposed to reach this")
return cmp_dtype


class TestAccumulation(base.BaseAccumulateTests):
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
return True

Expand Down Expand Up @@ -411,8 +350,6 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
else:
raise NotImplementedError(f"{op_name} not supported")


class TestUnaryOps(base.BaseUnaryOpsTests):
def test_invert(self, data, request):
if data.dtype.kind == "f":
mark = pytest.mark.xfail(
Expand All @@ -423,13 +360,5 @@ def test_invert(self, data, request):
super().test_invert(data)


class TestPrinting(base.BasePrintingTests):
pass


class TestParsing(base.BaseParsingTests):
pass


class Test2DCompat(base.Dim2CompatTests):
pass