@@ -100,6 +100,11 @@ def data_for_grouping(request):
100
100
return SparseArray ([1 , 1 , np .nan , np .nan , 2 , 2 , 1 , 3 ], fill_value = request .param )
101
101
102
102
103
+ @pytest .fixture (params = [0 , np .nan ])
104
+ def data_for_compare (request ):
105
+ return SparseArray ([0 , 0 , np .nan , - 2 , - 1 , 4 , 2 , 3 , 0 , 0 ], fill_value = request .param )
106
+
107
+
103
108
class BaseSparseTests :
104
109
def _check_unsupported (self , data ):
105
110
if data .dtype == SparseDtype (int , 0 ):
@@ -461,32 +466,45 @@ def _check_divmod_op(self, ser, op, other, exc=NotImplementedError):
461
466
super ()._check_divmod_op (ser , op , other , exc = None )
462
467
463
468
464
- class TestComparisonOps (BaseSparseTests , base . BaseComparisonOpsTests ):
465
- def _compare_other (self , s , data , comparison_op , other ):
469
+ class TestComparisonOps (BaseSparseTests ):
470
+ def _compare_other (self , data_for_compare : SparseArray , comparison_op , other ):
466
471
op = comparison_op
467
472
468
- # array
469
- result = pd .Series (op (data , other ))
470
- # hard to test the fill value, since we don't know what expected
471
- # is in general.
472
- # Rely on tests in `tests/sparse` to validate that.
473
- assert isinstance (result .dtype , SparseDtype )
474
- assert result .dtype .subtype == np .dtype ("bool" )
475
-
476
- with np .errstate (all = "ignore" ):
477
- expected = pd .Series (
478
- SparseArray (
479
- op (np .asarray (data ), np .asarray (other )),
480
- fill_value = result .values .fill_value ,
481
- )
473
+ result = op (data_for_compare , other )
474
+ assert isinstance (result , SparseArray )
475
+ assert result .dtype .subtype == np .bool_
476
+
477
+ if isinstance (other , SparseArray ):
478
+ fill_value = op (data_for_compare .fill_value , other .fill_value )
479
+ else :
480
+ fill_value = np .all (
481
+ op (np .asarray (data_for_compare .fill_value ), np .asarray (other ))
482
482
)
483
483
484
- tm .assert_series_equal (result , expected )
484
+ expected = SparseArray (
485
+ op (data_for_compare .to_dense (), np .asarray (other )),
486
+ fill_value = fill_value ,
487
+ dtype = np .bool_ ,
488
+ )
489
+ tm .assert_sp_array_equal (result , expected )
485
490
486
- # series
487
- ser = pd .Series (data )
488
- result = op (ser , other )
489
- tm .assert_series_equal (result , expected )
491
+ def test_scalar (self , data_for_compare : SparseArray , comparison_op ):
492
+ self ._compare_other (data_for_compare , comparison_op , 0 )
493
+ self ._compare_other (data_for_compare , comparison_op , 1 )
494
+ self ._compare_other (data_for_compare , comparison_op , - 1 )
495
+ self ._compare_other (data_for_compare , comparison_op , np .nan )
496
+
497
+ @pytest .mark .xfail (reason = "Wrong indices" )
498
+ def test_array (self , data_for_compare : SparseArray , comparison_op ):
499
+ arr = np .linspace (- 4 , 5 , 10 )
500
+ self ._compare_other (data_for_compare , comparison_op , arr )
501
+
502
+ @pytest .mark .xfail (reason = "Wrong indices" )
503
+ def test_sparse_array (self , data_for_compare : SparseArray , comparison_op ):
504
+ arr = data_for_compare + 1
505
+ self ._compare_other (data_for_compare , comparison_op , arr )
506
+ arr = data_for_compare * 2
507
+ self ._compare_other (data_for_compare , comparison_op , arr )
490
508
491
509
492
510
class TestPrinting (BaseSparseTests , base .BasePrintingTests ):
0 commit comments