Skip to content

Commit b6333e6

Browse files
authored
TST: use single-class pattern for Categorical tests (#54570)
1 parent 203f483 commit b6333e6

File tree

2 files changed

+19
-46
lines changed

2 files changed

+19
-46
lines changed

pandas/tests/extension/base/groupby.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import pandas._testing as tm
1414

1515

16+
@pytest.mark.filterwarnings(
17+
"ignore:The default of observed=False is deprecated:FutureWarning"
18+
)
1619
class BaseGroupbyTests:
1720
"""Groupby-specific tests."""
1821

pandas/tests/extension/test_categorical.py

+16-46
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@ def data_for_grouping():
7272
return Categorical(["a", "a", None, None, "b", "b", "a", "c"])
7373

7474

75-
class TestDtype(base.BaseDtypeTests):
76-
pass
77-
78-
79-
class TestInterface(base.BaseInterfaceTests):
75+
class TestCategorical(base.ExtensionTests):
8076
@pytest.mark.xfail(reason="Memory usage doesn't match")
8177
def test_memory_usage(self, data):
8278
# TODO: Is this deliberate?
@@ -106,8 +102,6 @@ def test_contains(self, data, data_missing):
106102
assert na_value_obj not in data
107103
assert na_value_obj in data_missing # this line differs from super method
108104

109-
110-
class TestConstructors(base.BaseConstructorsTests):
111105
def test_empty(self, dtype):
112106
cls = dtype.construct_array_type()
113107
result = cls._empty((4,), dtype=dtype)
@@ -117,41 +111,13 @@ def test_empty(self, dtype):
117111
# dtype on our result.
118112
assert result.dtype == CategoricalDtype([])
119113

120-
121-
class TestReshaping(base.BaseReshapingTests):
122-
pass
123-
124-
125-
class TestGetitem(base.BaseGetitemTests):
126114
@pytest.mark.skip(reason="Backwards compatibility")
127115
def test_getitem_scalar(self, data):
128116
# CategoricalDtype.type isn't "correct" since it should
129117
# be a parent of the elements (object). But don't want
130118
# to break things by changing.
131119
super().test_getitem_scalar(data)
132120

133-
134-
class TestSetitem(base.BaseSetitemTests):
135-
pass
136-
137-
138-
class TestIndex(base.BaseIndexTests):
139-
pass
140-
141-
142-
class TestMissing(base.BaseMissingTests):
143-
pass
144-
145-
146-
class TestReduce(base.BaseReduceTests):
147-
pass
148-
149-
150-
class TestAccumulate(base.BaseAccumulateTests):
151-
pass
152-
153-
154-
class TestMethods(base.BaseMethodsTests):
155121
@pytest.mark.xfail(reason="Unobserved categories included")
156122
def test_value_counts(self, all_data, dropna):
157123
return super().test_value_counts(all_data, dropna)
@@ -178,12 +144,6 @@ def test_map(self, data, na_action):
178144
result = data.map(lambda x: x, na_action=na_action)
179145
tm.assert_extension_array_equal(result, data)
180146

181-
182-
class TestCasting(base.BaseCastingTests):
183-
pass
184-
185-
186-
class TestArithmeticOps(base.BaseArithmeticOpsTests):
187147
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
188148
# frame & scalar
189149
op_name = all_arithmetic_operators
@@ -205,8 +165,6 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
205165
)
206166
super().test_arith_series_with_scalar(data, op_name)
207167

208-
209-
class TestComparisonOps(base.BaseComparisonOpsTests):
210168
def _compare_other(self, s, data, op, other):
211169
op_name = f"__{op.__name__}__"
212170
if op_name not in ["__eq__", "__ne__"]:
@@ -216,9 +174,21 @@ def _compare_other(self, s, data, op, other):
216174
else:
217175
return super()._compare_other(s, data, op, other)
218176

219-
220-
class TestParsing(base.BaseParsingTests):
221-
pass
177+
@pytest.mark.xfail(reason="Categorical overrides __repr__")
178+
@pytest.mark.parametrize("size", ["big", "small"])
179+
def test_array_repr(self, data, size):
180+
super().test_array_repr(data, size)
181+
182+
@pytest.mark.xfail(
183+
reason="Looks like the test (incorrectly) implicitly assumes int/bool dtype"
184+
)
185+
def test_invert(self, data):
186+
super().test_invert(data)
187+
188+
@pytest.mark.xfail(reason="TBD")
189+
@pytest.mark.parametrize("as_index", [True, False])
190+
def test_groupby_extension_agg(self, as_index, data_for_grouping):
191+
super().test_groupby_extension_agg(as_index, data_for_grouping)
222192

223193

224194
class Test2DCompat(base.NDArrayBacked2DTests):

0 commit comments

Comments
 (0)