Skip to content

Commit 881dc01

Browse files
committed
BUG: SparseArray doesn't recalc indices. (pandas-dev#44956, pandas-dev#45110)
1 parent e892d46 commit 881dc01

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

pandas/core/arrays/sparse/array.py

+4-3
Original file line number | Diff line number | Diff line change
@@ -1680,13 +1680,14 @@ def _cmp_method(self, other, op) -> SparseArray:
16801680
op_name = op.__name__.strip("_")
16811681
return _sparse_array_op(self, other, op, op_name)
16821682
else:
1683+
# scalar
16831684
with np.errstate(all="ignore"):
16841685
fill_value = op(self.fill_value, other)
1685-
result = op(self.sp_values, other)
1686+
mask = np.full(len(self), fill_value, dtype=np.bool_)
1687+
mask[self.sp_index.indices] = op(self.sp_values, other)
16861688

16871689
return type(self)(
1688-
result,
1689-
sparse_index=self.sp_index,
1690+
mask,
16901691
fill_value=fill_value,
16911692
dtype=np.bool_,
16921693
)

pandas/tests/arrays/sparse/test_arithmetics.py

+2
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ def mix(request):
2626
return request.param
2727

2828

29+
# FIXME: These are not SparseArray tests; they are numpy array tests.
30+
# We don't check indices. See GH #45110, #44956, XXX
2931
class TestSparseArrayArithmetics:
3032

3133
_base = np.array

pandas/tests/extension/test_sparse.py

+42-21
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,11 @@ def data_for_grouping(request):
100100
return SparseArray([1, 1, np.nan, np.nan, 2, 2, 1, 3], fill_value=request.param)
101101

102102

103+
@pytest.fixture(params=[0, np.nan])
104+
def data_for_compare(request):
105+
return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param)
106+
107+
103108
class BaseSparseTests:
104109
def _check_unsupported(self, data):
105110
if data.dtype == SparseDtype(int, 0):
@@ -432,32 +437,48 @@ def _check_divmod_op(self, ser, op, other, exc=NotImplementedError):
432437
super()._check_divmod_op(ser, op, other, exc=None)
433438

434439

435-
class TestComparisonOps(BaseSparseTests, base.BaseComparisonOpsTests):
436-
def _compare_other(self, s, data, comparison_op, other):
440+
class TestComparisonOps(BaseSparseTests):
441+
def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
437442
op = comparison_op
438443

439-
# array
440-
result = pd.Series(op(data, other))
441-
# hard to test the fill value, since we don't know what expected
442-
# is in general.
443-
# Rely on tests in `tests/sparse` to validate that.
444-
assert isinstance(result.dtype, SparseDtype)
445-
assert result.dtype.subtype == np.dtype("bool")
446-
447-
with np.errstate(all="ignore"):
448-
expected = pd.Series(
449-
SparseArray(
450-
op(np.asarray(data), np.asarray(other)),
451-
fill_value=result.values.fill_value,
452-
)
444+
result = op(data_for_compare, other)
445+
assert isinstance(result, SparseArray)
446+
assert result.dtype.subtype == np.bool_
447+
448+
if isinstance(other, SparseArray):
449+
expected = SparseArray(
450+
op(data_for_compare.to_dense(), np.asarray(other)),
451+
fill_value=op(data_for_compare.fill_value, other.fill_value),
452+
dtype=np.bool_,
453+
)
454+
else:
455+
expected = SparseArray(
456+
op(data_for_compare.to_dense(), np.asarray(other)),
457+
fill_value=np.all(
458+
op(np.asarray(data_for_compare.fill_value), np.asarray(other))
459+
),
460+
dtype=np.bool_,
453461
)
454462

455-
tm.assert_series_equal(result, expected)
463+
tm.assert_sp_array_equal(result, expected)
456464

457-
# series
458-
ser = pd.Series(data)
459-
result = op(ser, other)
460-
tm.assert_series_equal(result, expected)
465+
def test_scalar(self, data_for_compare: SparseArray, comparison_op):
466+
self._compare_other(data_for_compare, comparison_op, 0)
467+
self._compare_other(data_for_compare, comparison_op, 1)
468+
self._compare_other(data_for_compare, comparison_op, -1)
469+
self._compare_other(data_for_compare, comparison_op, np.nan)
470+
471+
@pytest.mark.xfail(reason="Wrong indices")
472+
def test_array(self, data_for_compare: SparseArray, comparison_op):
473+
arr = np.linspace(-4, 5, 10)
474+
self._compare_other(data_for_compare, comparison_op, arr)
475+
476+
@pytest.mark.xfail(reason="Wrong indices")
477+
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op):
478+
arr = data_for_compare + 1
479+
self._compare_other(data_for_compare, comparison_op, arr)
480+
arr = data_for_compare * 2
481+
self._compare_other(data_for_compare, comparison_op, arr)
461482

462483

463484
class TestPrinting(BaseSparseTests, base.BasePrintingTests):

0 commit comments

Comments
 (0)