Skip to content

Commit 2d53cef

Browse files
jbrockmendelcbpygit
authored andcommitted
TST: use one-class pattern for SparseArray (pandas-dev#56513)
* TST: use one-class pattern for SparseArray * mypy fixup
1 parent 0c5e90e commit 2d53cef

File tree

2 files changed

+111
-60
lines changed

2 files changed

+111
-60
lines changed

pandas/tests/extension/base/setitem.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@ def skip_if_immutable(self, dtype, request):
4343
# This fixture is auto-used, but we want to not-skip
4444
# test_is_immutable.
4545
return
46-
pytest.skip(f"__setitem__ test not applicable with immutable dtype {dtype}")
46+
47+
# When BaseSetitemTests is mixed into ExtensionTests, we only
48+
# want this fixture to operate on the tests defined in this
49+
# class/file.
50+
defined_in = node.function.__qualname__.split(".")[0]
51+
if defined_in == "BaseSetitemTests":
52+
pytest.skip("__setitem__ test not applicable with immutable dtype")
4753

4854
def test_is_immutable(self, data):
4955
if data.dtype._is_immutable:

pandas/tests/extension/test_sparse.py

+104-59
Original file line numberDiff line numberDiff line change
@@ -98,26 +98,64 @@ def data_for_compare(request):
9898
return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param)
9999

100100

101-
class BaseSparseTests:
101+
class TestSparseArray(base.ExtensionTests):
102+
def _supports_reduction(self, obj, op_name: str) -> bool:
103+
return True
104+
105+
@pytest.mark.parametrize("skipna", [True, False])
106+
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
107+
if all_numeric_reductions in [
108+
"prod",
109+
"median",
110+
"var",
111+
"std",
112+
"sem",
113+
"skew",
114+
"kurt",
115+
]:
116+
mark = pytest.mark.xfail(
117+
reason="This should be viable but is not implemented"
118+
)
119+
request.node.add_marker(mark)
120+
elif (
121+
all_numeric_reductions in ["sum", "max", "min", "mean"]
122+
and data.dtype.kind == "f"
123+
and not skipna
124+
):
125+
mark = pytest.mark.xfail(reason="getting a non-nan float")
126+
request.node.add_marker(mark)
127+
128+
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
129+
130+
@pytest.mark.parametrize("skipna", [True, False])
131+
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
132+
if all_numeric_reductions in [
133+
"prod",
134+
"median",
135+
"var",
136+
"std",
137+
"sem",
138+
"skew",
139+
"kurt",
140+
]:
141+
mark = pytest.mark.xfail(
142+
reason="This should be viable but is not implemented"
143+
)
144+
request.node.add_marker(mark)
145+
elif (
146+
all_numeric_reductions in ["sum", "max", "min", "mean"]
147+
and data.dtype.kind == "f"
148+
and not skipna
149+
):
150+
mark = pytest.mark.xfail(reason="ExtensionArray NA mask are different")
151+
request.node.add_marker(mark)
152+
153+
super().test_reduce_frame(data, all_numeric_reductions, skipna)
154+
102155
def _check_unsupported(self, data):
103156
if data.dtype == SparseDtype(int, 0):
104157
pytest.skip("Can't store nan in int array.")
105158

106-
107-
class TestDtype(BaseSparseTests, base.BaseDtypeTests):
108-
def test_array_type_with_arg(self, data, dtype):
109-
assert dtype.construct_array_type() is SparseArray
110-
111-
112-
class TestInterface(BaseSparseTests, base.BaseInterfaceTests):
113-
pass
114-
115-
116-
class TestConstructors(BaseSparseTests, base.BaseConstructorsTests):
117-
pass
118-
119-
120-
class TestReshaping(BaseSparseTests, base.BaseReshapingTests):
121159
def test_concat_mixed_dtypes(self, data):
122160
# https://github.com/pandas-dev/pandas/issues/20762
123161
# This should be the same, aside from concat([sparse, float])
@@ -173,8 +211,6 @@ def test_merge(self, data, na_value):
173211
self._check_unsupported(data)
174212
super().test_merge(data, na_value)
175213

176-
177-
class TestGetitem(BaseSparseTests, base.BaseGetitemTests):
178214
def test_get(self, data):
179215
ser = pd.Series(data, index=[2 * i for i in range(len(data))])
180216
if np.isnan(ser.values.fill_value):
@@ -187,16 +223,6 @@ def test_reindex(self, data, na_value):
187223
self._check_unsupported(data)
188224
super().test_reindex(data, na_value)
189225

190-
191-
class TestSetitem(BaseSparseTests, base.BaseSetitemTests):
192-
pass
193-
194-
195-
class TestIndex(base.BaseIndexTests):
196-
pass
197-
198-
199-
class TestMissing(BaseSparseTests, base.BaseMissingTests):
200226
def test_isna(self, data_missing):
201227
sarr = SparseArray(data_missing)
202228
expected_dtype = SparseDtype(bool, pd.isna(data_missing.dtype.fill_value))
@@ -249,8 +275,6 @@ def test_fillna_frame(self, data_missing):
249275

250276
tm.assert_frame_equal(result, expected)
251277

252-
253-
class TestMethods(BaseSparseTests, base.BaseMethodsTests):
254278
_combine_le_expected_dtype = "Sparse[bool]"
255279

256280
def test_fillna_copy_frame(self, data_missing, using_copy_on_write):
@@ -351,16 +375,12 @@ def test_map_raises(self, data, na_action):
351375
with pytest.raises(ValueError, match=msg):
352376
data.map(lambda x: np.nan, na_action=na_action)
353377

354-
355-
class TestCasting(BaseSparseTests, base.BaseCastingTests):
356378
@pytest.mark.xfail(raises=TypeError, reason="no sparse StringDtype")
357379
def test_astype_string(self, data, nullable_string_dtype):
358380
# TODO: this fails bc we do not pass through nullable_string_dtype;
359381
# If we did, the 0-cases would xpass
360382
super().test_astype_string(data)
361383

362-
363-
class TestArithmeticOps(BaseSparseTests, base.BaseArithmeticOpsTests):
364384
series_scalar_exc = None
365385
frame_scalar_exc = None
366386
divmod_exc = None
@@ -397,17 +417,27 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
397417
request.applymarker(mark)
398418
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)
399419

400-
401-
class TestComparisonOps(BaseSparseTests):
402-
def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
420+
def _compare_other(
421+
self, ser: pd.Series, data_for_compare: SparseArray, comparison_op, other
422+
):
403423
op = comparison_op
404424

405425
result = op(data_for_compare, other)
406-
assert isinstance(result, SparseArray)
426+
if isinstance(other, pd.Series):
427+
assert isinstance(result, pd.Series)
428+
assert isinstance(result.dtype, SparseDtype)
429+
else:
430+
assert isinstance(result, SparseArray)
407431
assert result.dtype.subtype == np.bool_
408432

409-
if isinstance(other, SparseArray):
410-
fill_value = op(data_for_compare.fill_value, other.fill_value)
433+
if isinstance(other, pd.Series):
434+
fill_value = op(data_for_compare.fill_value, other._values.fill_value)
435+
expected = SparseArray(
436+
op(data_for_compare.to_dense(), np.asarray(other)),
437+
fill_value=fill_value,
438+
dtype=np.bool_,
439+
)
440+
411441
else:
412442
fill_value = np.all(
413443
op(np.asarray(data_for_compare.fill_value), np.asarray(other))
@@ -418,36 +448,51 @@ def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
418448
fill_value=fill_value,
419449
dtype=np.bool_,
420450
)
421-
tm.assert_sp_array_equal(result, expected)
451+
if isinstance(other, pd.Series):
452+
# error: Incompatible types in assignment
453+
expected = pd.Series(expected) # type: ignore[assignment]
454+
tm.assert_equal(result, expected)
422455

423456
def test_scalar(self, data_for_compare: SparseArray, comparison_op):
424-
self._compare_other(data_for_compare, comparison_op, 0)
425-
self._compare_other(data_for_compare, comparison_op, 1)
426-
self._compare_other(data_for_compare, comparison_op, -1)
427-
self._compare_other(data_for_compare, comparison_op, np.nan)
457+
ser = pd.Series(data_for_compare)
458+
self._compare_other(ser, data_for_compare, comparison_op, 0)
459+
self._compare_other(ser, data_for_compare, comparison_op, 1)
460+
self._compare_other(ser, data_for_compare, comparison_op, -1)
461+
self._compare_other(ser, data_for_compare, comparison_op, np.nan)
462+
463+
def test_array(self, data_for_compare: SparseArray, comparison_op, request):
464+
if data_for_compare.dtype.fill_value == 0 and comparison_op.__name__ in [
465+
"eq",
466+
"ge",
467+
"le",
468+
]:
469+
mark = pytest.mark.xfail(reason="Wrong fill_value")
470+
request.applymarker(mark)
428471

429-
@pytest.mark.xfail(reason="Wrong indices")
430-
def test_array(self, data_for_compare: SparseArray, comparison_op):
431472
arr = np.linspace(-4, 5, 10)
432-
self._compare_other(data_for_compare, comparison_op, arr)
473+
ser = pd.Series(data_for_compare)
474+
self._compare_other(ser, data_for_compare, comparison_op, arr)
433475

434-
@pytest.mark.xfail(reason="Wrong indices")
435-
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op):
476+
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op, request):
477+
if data_for_compare.dtype.fill_value == 0 and comparison_op.__name__ != "gt":
478+
mark = pytest.mark.xfail(reason="Wrong fill_value")
479+
request.applymarker(mark)
480+
481+
ser = pd.Series(data_for_compare)
436482
arr = data_for_compare + 1
437-
self._compare_other(data_for_compare, comparison_op, arr)
483+
self._compare_other(ser, data_for_compare, comparison_op, arr)
438484
arr = data_for_compare * 2
439-
self._compare_other(data_for_compare, comparison_op, arr)
485+
self._compare_other(ser, data_for_compare, comparison_op, arr)
440486

441-
442-
class TestPrinting(BaseSparseTests, base.BasePrintingTests):
443487
@pytest.mark.xfail(reason="Different repr")
444488
def test_array_repr(self, data, size):
445489
super().test_array_repr(data, size)
446490

447-
448-
class TestParsing(BaseSparseTests, base.BaseParsingTests):
449-
pass
491+
@pytest.mark.xfail(reason="result does not match expected")
492+
@pytest.mark.parametrize("as_index", [True, False])
493+
def test_groupby_extension_agg(self, as_index, data_for_grouping):
494+
super().test_groupby_extension_agg(as_index, data_for_grouping)
450495

451496

452-
class TestNoNumericAccumulations(base.BaseAccumulateTests):
453-
pass
497+
def test_array_type_with_arg(dtype):
498+
assert dtype.construct_array_type() is SparseArray

0 commit comments

Comments
 (0)