Skip to content

Commit 597bede

Browse files
committed
BUG: SparseArray doesn't recalc indices. (pandas-dev#44956, pandas-dev#45110)
1 parent e892d46 commit 597bede

File tree

4 files changed

+53
-25
lines changed

4 files changed

+53
-25
lines changed

pandas/core/arrays/sparse/array.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1680,13 +1680,14 @@ def _cmp_method(self, other, op) -> SparseArray:
16801680
op_name = op.__name__.strip("_")
16811681
return _sparse_array_op(self, other, op, op_name)
16821682
else:
1683+
# scalar
16831684
with np.errstate(all="ignore"):
16841685
fill_value = op(self.fill_value, other)
1685-
result = op(self.sp_values, other)
1686+
mask = np.full(len(self), fill_value, dtype=np.bool_)
1687+
mask[self.sp_index.indices] = op(self.sp_values, other)
16861688

16871689
return type(self)(
1688-
result,
1689-
sparse_index=self.sp_index,
1690+
mask,
16901691
fill_value=fill_value,
16911692
dtype=np.bool_,
16921693
)

pandas/tests/arrays/sparse/test_arithmetics.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def mix(request):
2626
return request.param
2727

2828

29+
# FIXME: There are not SparseArray tests. There are numpy array tests.
30+
# We don't check indices. See GH #45110, #44956, XXX
2931
class TestSparseArrayArithmetics:
3032

3133
_base = np.array

pandas/tests/arrays/sparse/test_array.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ def test_scalar_with_index_infer_dtype(self, scalar, dtype):
248248
assert arr.dtype == dtype
249249
assert exp.dtype == dtype
250250

251-
# GH 23122
252251
def test_getitem_bool_sparse_array(self):
252+
# GH 23122
253253
spar_bool = SparseArray([False, True] * 5, dtype=np.bool8, fill_value=True)
254254
exp = SparseArray([np.nan, 2, np.nan, 5, 6])
255255
tm.assert_sp_array_equal(self.arr[spar_bool], exp)
@@ -266,6 +266,13 @@ def test_getitem_bool_sparse_array(self):
266266
exp = SparseArray([np.nan, 3, 5])
267267
tm.assert_sp_array_equal(res, exp)
268268

269+
def test_getitem_bool_sparse_array_as_comparison(self):
270+
# GH 45110
271+
arr = SparseArray([1, 2, 3, 4, np.nan, np.nan], fill_value=np.nan)
272+
res = arr[arr > 2]
273+
exp = SparseArray([3.0, 4.0], fill_value=np.nan)
274+
tm.assert_sp_array_equal(res, exp)
275+
269276
def test_get_item(self):
270277

271278
assert np.isnan(self.arr[1])

pandas/tests/extension/test_sparse.py

+39-21
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def data_for_grouping(request):
100100
return SparseArray([1, 1, np.nan, np.nan, 2, 2, 1, 3], fill_value=request.param)
101101

102102

103+
@pytest.fixture(params=[0, np.nan])
104+
def data_for_compare(request):
105+
return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param)
106+
107+
103108
class BaseSparseTests:
104109
def _check_unsupported(self, data):
105110
if data.dtype == SparseDtype(int, 0):
@@ -432,32 +437,45 @@ def _check_divmod_op(self, ser, op, other, exc=NotImplementedError):
432437
super()._check_divmod_op(ser, op, other, exc=None)
433438

434439

435-
class TestComparisonOps(BaseSparseTests, base.BaseComparisonOpsTests):
436-
def _compare_other(self, s, data, comparison_op, other):
440+
class TestComparisonOps(BaseSparseTests):
441+
def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
437442
op = comparison_op
438443

439-
# array
440-
result = pd.Series(op(data, other))
441-
# hard to test the fill value, since we don't know what expected
442-
# is in general.
443-
# Rely on tests in `tests/sparse` to validate that.
444-
assert isinstance(result.dtype, SparseDtype)
445-
assert result.dtype.subtype == np.dtype("bool")
446-
447-
with np.errstate(all="ignore"):
448-
expected = pd.Series(
449-
SparseArray(
450-
op(np.asarray(data), np.asarray(other)),
451-
fill_value=result.values.fill_value,
452-
)
444+
result = op(data_for_compare, other)
445+
assert isinstance(result, SparseArray)
446+
assert result.dtype.subtype == np.bool_
447+
448+
if isinstance(other, SparseArray):
449+
fill_value = op(data_for_compare.fill_value, other.fill_value)
450+
else:
451+
fill_value = np.all(
452+
op(np.asarray(data_for_compare.fill_value), np.asarray(other))
453453
)
454454

455-
tm.assert_series_equal(result, expected)
455+
expected = SparseArray(
456+
op(data_for_compare.to_dense(), np.asarray(other)),
457+
fill_value=fill_value,
458+
dtype=np.bool_,
459+
)
460+
tm.assert_sp_array_equal(result, expected)
456461

457-
# series
458-
ser = pd.Series(data)
459-
result = op(ser, other)
460-
tm.assert_series_equal(result, expected)
462+
def test_scalar(self, data_for_compare: SparseArray, comparison_op):
463+
self._compare_other(data_for_compare, comparison_op, 0)
464+
self._compare_other(data_for_compare, comparison_op, 1)
465+
self._compare_other(data_for_compare, comparison_op, -1)
466+
self._compare_other(data_for_compare, comparison_op, np.nan)
467+
468+
@pytest.mark.xfail(reason="Wrong indices")
469+
def test_array(self, data_for_compare: SparseArray, comparison_op):
470+
arr = np.linspace(-4, 5, 10)
471+
self._compare_other(data_for_compare, comparison_op, arr)
472+
473+
@pytest.mark.xfail(reason="Wrong indices")
474+
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op):
475+
arr = data_for_compare + 1
476+
self._compare_other(data_for_compare, comparison_op, arr)
477+
arr = data_for_compare * 2
478+
self._compare_other(data_for_compare, comparison_op, arr)
461479

462480

463481
class TestPrinting(BaseSparseTests, base.BasePrintingTests):

0 commit comments

Comments
 (0)