Skip to content

Commit 31da913

Browse files
authored
TST: use single-class pattern for Arrow, Masked tests (#54573)
* TST: use single-class pattern in test_masked.py * TST: use one-class pattern in arrow extension tests
1 parent 1c5c4ef commit 31da913

File tree

2 files changed

+9
-137
lines changed

2 files changed

+9
-137
lines changed

pandas/tests/extension/test_arrow.py

+4-61
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def data_for_twos(data):
265265
# TODO: skip otherwise?
266266

267267

268-
class TestBaseCasting(base.BaseCastingTests):
268+
class TestArrowArray(base.ExtensionTests):
269269
def test_astype_str(self, data, request):
270270
pa_dtype = data.dtype.pyarrow_dtype
271271
if pa.types.is_binary(pa_dtype):
@@ -276,8 +276,6 @@ def test_astype_str(self, data, request):
276276
)
277277
super().test_astype_str(data)
278278

279-
280-
class TestConstructors(base.BaseConstructorsTests):
281279
def test_from_dtype(self, data, request):
282280
pa_dtype = data.dtype.pyarrow_dtype
283281
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
@@ -338,12 +336,6 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
338336
result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
339337
tm.assert_extension_array_equal(result, data)
340338

341-
342-
class TestGetitemTests(base.BaseGetitemTests):
343-
pass
344-
345-
346-
class TestBaseAccumulateTests(base.BaseAccumulateTests):
347339
def check_accumulate(self, ser, op_name, skipna):
348340
result = getattr(ser, op_name)(skipna=skipna)
349341

@@ -409,8 +401,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
409401

410402
self.check_accumulate(ser, op_name, skipna)
411403

412-
413-
class TestReduce(base.BaseReduceTests):
414404
def _supports_reduction(self, obj, op_name: str) -> bool:
415405
dtype = tm.get_dtype(obj)
416406
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has
@@ -561,8 +551,6 @@ def test_median_not_approximate(self, typ):
561551
result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
562552
assert result == 1.5
563553

564-
565-
class TestBaseGroupby(base.BaseGroupbyTests):
566554
def test_in_numeric_groupby(self, data_for_grouping):
567555
dtype = data_for_grouping.dtype
568556
if is_string_dtype(dtype):
@@ -583,8 +571,6 @@ def test_in_numeric_groupby(self, data_for_grouping):
583571
else:
584572
super().test_in_numeric_groupby(data_for_grouping)
585573

586-
587-
class TestBaseDtype(base.BaseDtypeTests):
588574
def test_construct_from_string_own_name(self, dtype, request):
589575
pa_dtype = dtype.pyarrow_dtype
590576
if pa.types.is_decimal(pa_dtype):
@@ -651,20 +637,12 @@ def test_is_not_string_type(self, dtype):
651637
else:
652638
super().test_is_not_string_type(dtype)
653639

654-
655-
class TestBaseIndex(base.BaseIndexTests):
656-
pass
657-
658-
659-
class TestBaseInterface(base.BaseInterfaceTests):
660640
@pytest.mark.xfail(
661641
reason="GH 45419: pyarrow.ChunkedArray does not support views.", run=False
662642
)
663643
def test_view(self, data):
664644
super().test_view(data)
665645

666-
667-
class TestBaseMissing(base.BaseMissingTests):
668646
def test_fillna_no_op_returns_copy(self, data):
669647
data = data[~data.isna()]
670648

@@ -677,28 +655,18 @@ def test_fillna_no_op_returns_copy(self, data):
677655
assert result is not data
678656
tm.assert_extension_array_equal(result, data)
679657

680-
681-
class TestBasePrinting(base.BasePrintingTests):
682-
pass
683-
684-
685-
class TestBaseReshaping(base.BaseReshapingTests):
686658
@pytest.mark.xfail(
687659
reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False
688660
)
689661
def test_transpose(self, data):
690662
super().test_transpose(data)
691663

692-
693-
class TestBaseSetitem(base.BaseSetitemTests):
694664
@pytest.mark.xfail(
695665
reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False
696666
)
697667
def test_setitem_preserves_views(self, data):
698668
super().test_setitem_preserves_views(data)
699669

700-
701-
class TestBaseParsing(base.BaseParsingTests):
702670
@pytest.mark.parametrize("dtype_backend", ["pyarrow", no_default])
703671
@pytest.mark.parametrize("engine", ["c", "python"])
704672
def test_EA_types(self, engine, data, dtype_backend, request):
@@ -736,8 +704,6 @@ def test_EA_types(self, engine, data, dtype_backend, request):
736704
expected = df
737705
tm.assert_frame_equal(result, expected)
738706

739-
740-
class TestBaseUnaryOps(base.BaseUnaryOpsTests):
741707
def test_invert(self, data, request):
742708
pa_dtype = data.dtype.pyarrow_dtype
743709
if not pa.types.is_boolean(pa_dtype):
@@ -749,8 +715,6 @@ def test_invert(self, data, request):
749715
)
750716
super().test_invert(data)
751717

752-
753-
class TestBaseMethods(base.BaseMethodsTests):
754718
@pytest.mark.parametrize("periods", [1, -2])
755719
def test_diff(self, data, periods, request):
756720
pa_dtype = data.dtype.pyarrow_dtype
@@ -814,8 +778,6 @@ def test_argreduce_series(
814778

815779
_combine_le_expected_dtype = "bool[pyarrow]"
816780

817-
818-
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
819781
divmod_exc = NotImplementedError
820782

821783
def get_op_from_name(self, op_name):
@@ -838,6 +800,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
838800
# while ArrowExtensionArray maintains original type
839801
expected = pointwise_result
840802

803+
if op_name in ["eq", "ne", "lt", "le", "gt", "ge"]:
804+
return pointwise_result.astype("boolean[pyarrow]")
805+
841806
was_frame = False
842807
if isinstance(expected, pd.DataFrame):
843808
was_frame = True
@@ -1121,28 +1086,6 @@ def test_add_series_with_extension_array(self, data, request):
11211086
)
11221087
super().test_add_series_with_extension_array(data)
11231088

1124-
1125-
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
1126-
def test_compare_array(self, data, comparison_op, na_value):
1127-
ser = pd.Series(data)
1128-
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
1129-
# since ser.iloc[0] is a python scalar
1130-
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
1131-
if comparison_op.__name__ in ["eq", "ne"]:
1132-
# comparison should match point-wise comparisons
1133-
result = comparison_op(ser, other)
1134-
# Series.combine does not calculate the NA mask correctly
1135-
# when comparing over an array
1136-
assert result[8] is na_value
1137-
assert result[97] is na_value
1138-
expected = ser.combine(other, comparison_op)
1139-
expected[8] = na_value
1140-
expected[97] = na_value
1141-
tm.assert_series_equal(result, expected)
1142-
1143-
else:
1144-
return super().test_compare_array(data, comparison_op)
1145-
11461089
def test_invalid_other_comp(self, data, comparison_op):
11471090
# GH 48833
11481091
with pytest.raises(

pandas/tests/extension/test_masked.py

+5-76
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,7 @@ def data_for_grouping(dtype):
159159
return pd.array([b, b, na, na, a, a, b, c], dtype=dtype)
160160

161161

162-
class TestDtype(base.BaseDtypeTests):
163-
pass
164-
165-
166-
class TestArithmeticOps(base.BaseArithmeticOpsTests):
162+
class TestMaskedArrays(base.ExtensionTests):
167163
def _get_expected_exception(self, op_name, obj, other):
168164
try:
169165
dtype = tm.get_dtype(obj)
@@ -179,12 +175,15 @@ def _get_expected_exception(self, op_name, obj, other):
179175
# exception message would include "numpy boolean subtract""
180176
return TypeError
181177
return None
182-
return super()._get_expected_exception(op_name, obj, other)
178+
return None
183179

184180
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
185181
sdtype = tm.get_dtype(obj)
186182
expected = pointwise_result
187183

184+
if op_name in ("eq", "ne", "le", "ge", "lt", "gt"):
185+
return expected.astype("boolean")
186+
188187
if sdtype.kind in "iu":
189188
if op_name in ("__rtruediv__", "__truediv__", "__div__"):
190189
expected = expected.fillna(np.nan).astype("Float64")
@@ -219,11 +218,6 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
219218
expected = expected.astype(sdtype)
220219
return expected
221220

222-
series_scalar_exc = None
223-
series_array_exc = None
224-
frame_scalar_exc = None
225-
divmod_exc = None
226-
227221
def test_divmod_series_array(self, data, data_for_twos, request):
228222
if data.dtype.kind == "b":
229223
mark = pytest.mark.xfail(
@@ -234,49 +228,6 @@ def test_divmod_series_array(self, data, data_for_twos, request):
234228
request.node.add_marker(mark)
235229
super().test_divmod_series_array(data, data_for_twos)
236230

237-
238-
class TestComparisonOps(base.BaseComparisonOpsTests):
239-
series_scalar_exc = None
240-
series_array_exc = None
241-
frame_scalar_exc = None
242-
243-
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
244-
return pointwise_result.astype("boolean")
245-
246-
247-
class TestInterface(base.BaseInterfaceTests):
248-
pass
249-
250-
251-
class TestConstructors(base.BaseConstructorsTests):
252-
pass
253-
254-
255-
class TestReshaping(base.BaseReshapingTests):
256-
pass
257-
258-
# for test_concat_mixed_dtypes test
259-
# concat of an Integer and Int coerces to object dtype
260-
# TODO(jreback) once integrated this would
261-
262-
263-
class TestGetitem(base.BaseGetitemTests):
264-
pass
265-
266-
267-
class TestSetitem(base.BaseSetitemTests):
268-
pass
269-
270-
271-
class TestIndex(base.BaseIndexTests):
272-
pass
273-
274-
275-
class TestMissing(base.BaseMissingTests):
276-
pass
277-
278-
279-
class TestMethods(base.BaseMethodsTests):
280231
def test_combine_le(self, data_repeated):
281232
# TODO: patching self is a bad pattern here
282233
orig_data1, orig_data2 = data_repeated(2)
@@ -287,16 +238,6 @@ def test_combine_le(self, data_repeated):
287238
self._combine_le_expected_dtype = object
288239
super().test_combine_le(data_repeated)
289240

290-
291-
class TestCasting(base.BaseCastingTests):
292-
pass
293-
294-
295-
class TestGroupby(base.BaseGroupbyTests):
296-
pass
297-
298-
299-
class TestReduce(base.BaseReduceTests):
300241
def _supports_reduction(self, obj, op_name: str) -> bool:
301242
if op_name in ["any", "all"] and tm.get_dtype(obj).kind != "b":
302243
pytest.skip(reason="Tested in tests/reductions/test_reductions.py")
@@ -351,8 +292,6 @@ def _get_expected_reduction_dtype(self, arr, op_name: str):
351292
raise TypeError("not supposed to reach this")
352293
return cmp_dtype
353294

354-
355-
class TestAccumulation(base.BaseAccumulateTests):
356295
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
357296
return True
358297

@@ -411,8 +350,6 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
411350
else:
412351
raise NotImplementedError(f"{op_name} not supported")
413352

414-
415-
class TestUnaryOps(base.BaseUnaryOpsTests):
416353
def test_invert(self, data, request):
417354
if data.dtype.kind == "f":
418355
mark = pytest.mark.xfail(
@@ -423,13 +360,5 @@ def test_invert(self, data, request):
423360
super().test_invert(data)
424361

425362

426-
class TestPrinting(base.BasePrintingTests):
427-
pass
428-
429-
430-
class TestParsing(base.BaseParsingTests):
431-
pass
432-
433-
434363
class Test2DCompat(base.Dim2CompatTests):
435364
pass

0 commit comments

Comments
 (0)