@@ -98,26 +98,64 @@ def data_for_compare(request):
98
98
return SparseArray ([0 , 0 , np .nan , - 2 , - 1 , 4 , 2 , 3 , 0 , 0 ], fill_value = request .param )
99
99
100
100
101
- class BaseSparseTests :
101
+ class TestSparseArray (base .ExtensionTests ):
102
+ def _supports_reduction (self , obj , op_name : str ) -> bool :
103
+ return True
104
+
105
+ @pytest .mark .parametrize ("skipna" , [True , False ])
106
+ def test_reduce_series_numeric (self , data , all_numeric_reductions , skipna , request ):
107
+ if all_numeric_reductions in [
108
+ "prod" ,
109
+ "median" ,
110
+ "var" ,
111
+ "std" ,
112
+ "sem" ,
113
+ "skew" ,
114
+ "kurt" ,
115
+ ]:
116
+ mark = pytest .mark .xfail (
117
+ reason = "This should be viable but is not implemented"
118
+ )
119
+ request .node .add_marker (mark )
120
+ elif (
121
+ all_numeric_reductions in ["sum" , "max" , "min" , "mean" ]
122
+ and data .dtype .kind == "f"
123
+ and not skipna
124
+ ):
125
+ mark = pytest .mark .xfail (reason = "getting a non-nan float" )
126
+ request .node .add_marker (mark )
127
+
128
+ super ().test_reduce_series_numeric (data , all_numeric_reductions , skipna )
129
+
130
+ @pytest .mark .parametrize ("skipna" , [True , False ])
131
+ def test_reduce_frame (self , data , all_numeric_reductions , skipna , request ):
132
+ if all_numeric_reductions in [
133
+ "prod" ,
134
+ "median" ,
135
+ "var" ,
136
+ "std" ,
137
+ "sem" ,
138
+ "skew" ,
139
+ "kurt" ,
140
+ ]:
141
+ mark = pytest .mark .xfail (
142
+ reason = "This should be viable but is not implemented"
143
+ )
144
+ request .node .add_marker (mark )
145
+ elif (
146
+ all_numeric_reductions in ["sum" , "max" , "min" , "mean" ]
147
+ and data .dtype .kind == "f"
148
+ and not skipna
149
+ ):
150
+ mark = pytest .mark .xfail (reason = "ExtensionArray NA mask are different" )
151
+ request .node .add_marker (mark )
152
+
153
+ super ().test_reduce_frame (data , all_numeric_reductions , skipna )
154
+
102
155
def _check_unsupported (self , data ):
103
156
if data .dtype == SparseDtype (int , 0 ):
104
157
pytest .skip ("Can't store nan in int array." )
105
158
106
-
107
- class TestDtype (BaseSparseTests , base .BaseDtypeTests ):
108
- def test_array_type_with_arg (self , data , dtype ):
109
- assert dtype .construct_array_type () is SparseArray
110
-
111
-
112
- class TestInterface (BaseSparseTests , base .BaseInterfaceTests ):
113
- pass
114
-
115
-
116
- class TestConstructors (BaseSparseTests , base .BaseConstructorsTests ):
117
- pass
118
-
119
-
120
- class TestReshaping (BaseSparseTests , base .BaseReshapingTests ):
121
159
def test_concat_mixed_dtypes (self , data ):
122
160
# https://github.com/pandas-dev/pandas/issues/20762
123
161
# This should be the same, aside from concat([sparse, float])
@@ -173,8 +211,6 @@ def test_merge(self, data, na_value):
173
211
self ._check_unsupported (data )
174
212
super ().test_merge (data , na_value )
175
213
176
-
177
- class TestGetitem (BaseSparseTests , base .BaseGetitemTests ):
178
214
def test_get (self , data ):
179
215
ser = pd .Series (data , index = [2 * i for i in range (len (data ))])
180
216
if np .isnan (ser .values .fill_value ):
@@ -187,16 +223,6 @@ def test_reindex(self, data, na_value):
187
223
self ._check_unsupported (data )
188
224
super ().test_reindex (data , na_value )
189
225
190
-
191
- class TestSetitem (BaseSparseTests , base .BaseSetitemTests ):
192
- pass
193
-
194
-
195
- class TestIndex (base .BaseIndexTests ):
196
- pass
197
-
198
-
199
- class TestMissing (BaseSparseTests , base .BaseMissingTests ):
200
226
def test_isna (self , data_missing ):
201
227
sarr = SparseArray (data_missing )
202
228
expected_dtype = SparseDtype (bool , pd .isna (data_missing .dtype .fill_value ))
@@ -249,8 +275,6 @@ def test_fillna_frame(self, data_missing):
249
275
250
276
tm .assert_frame_equal (result , expected )
251
277
252
-
253
- class TestMethods (BaseSparseTests , base .BaseMethodsTests ):
254
278
_combine_le_expected_dtype = "Sparse[bool]"
255
279
256
280
def test_fillna_copy_frame (self , data_missing , using_copy_on_write ):
@@ -351,16 +375,12 @@ def test_map_raises(self, data, na_action):
351
375
with pytest .raises (ValueError , match = msg ):
352
376
data .map (lambda x : np .nan , na_action = na_action )
353
377
354
-
355
- class TestCasting (BaseSparseTests , base .BaseCastingTests ):
356
378
@pytest .mark .xfail (raises = TypeError , reason = "no sparse StringDtype" )
357
379
def test_astype_string (self , data , nullable_string_dtype ):
358
380
# TODO: this fails bc we do not pass through nullable_string_dtype;
359
381
# If we did, the 0-cases would xpass
360
382
super ().test_astype_string (data )
361
383
362
-
363
- class TestArithmeticOps (BaseSparseTests , base .BaseArithmeticOpsTests ):
364
384
series_scalar_exc = None
365
385
frame_scalar_exc = None
366
386
divmod_exc = None
@@ -397,17 +417,27 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
397
417
request .applymarker (mark )
398
418
super ().test_arith_frame_with_scalar (data , all_arithmetic_operators )
399
419
400
-
401
- class TestComparisonOps ( BaseSparseTests ):
402
- def _compare_other ( self , data_for_compare : SparseArray , comparison_op , other ):
420
+ def _compare_other (
421
+ self , ser : pd . Series , data_for_compare : SparseArray , comparison_op , other
422
+ ):
403
423
op = comparison_op
404
424
405
425
result = op (data_for_compare , other )
406
- assert isinstance (result , SparseArray )
426
+ if isinstance (other , pd .Series ):
427
+ assert isinstance (result , pd .Series )
428
+ assert isinstance (result .dtype , SparseDtype )
429
+ else :
430
+ assert isinstance (result , SparseArray )
407
431
assert result .dtype .subtype == np .bool_
408
432
409
- if isinstance (other , SparseArray ):
410
- fill_value = op (data_for_compare .fill_value , other .fill_value )
433
+ if isinstance (other , pd .Series ):
434
+ fill_value = op (data_for_compare .fill_value , other ._values .fill_value )
435
+ expected = SparseArray (
436
+ op (data_for_compare .to_dense (), np .asarray (other )),
437
+ fill_value = fill_value ,
438
+ dtype = np .bool_ ,
439
+ )
440
+
411
441
else :
412
442
fill_value = np .all (
413
443
op (np .asarray (data_for_compare .fill_value ), np .asarray (other ))
@@ -418,36 +448,51 @@ def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
418
448
fill_value = fill_value ,
419
449
dtype = np .bool_ ,
420
450
)
421
- tm .assert_sp_array_equal (result , expected )
451
+ if isinstance (other , pd .Series ):
452
+ # error: Incompatible types in assignment
453
+ expected = pd .Series (expected ) # type: ignore[assignment]
454
+ tm .assert_equal (result , expected )
422
455
423
456
def test_scalar (self , data_for_compare : SparseArray , comparison_op ):
424
- self ._compare_other (data_for_compare , comparison_op , 0 )
425
- self ._compare_other (data_for_compare , comparison_op , 1 )
426
- self ._compare_other (data_for_compare , comparison_op , - 1 )
427
- self ._compare_other (data_for_compare , comparison_op , np .nan )
457
+ ser = pd .Series (data_for_compare )
458
+ self ._compare_other (ser , data_for_compare , comparison_op , 0 )
459
+ self ._compare_other (ser , data_for_compare , comparison_op , 1 )
460
+ self ._compare_other (ser , data_for_compare , comparison_op , - 1 )
461
+ self ._compare_other (ser , data_for_compare , comparison_op , np .nan )
462
+
463
+ def test_array (self , data_for_compare : SparseArray , comparison_op , request ):
464
+ if data_for_compare .dtype .fill_value == 0 and comparison_op .__name__ in [
465
+ "eq" ,
466
+ "ge" ,
467
+ "le" ,
468
+ ]:
469
+ mark = pytest .mark .xfail (reason = "Wrong fill_value" )
470
+ request .applymarker (mark )
428
471
429
- @pytest .mark .xfail (reason = "Wrong indices" )
430
- def test_array (self , data_for_compare : SparseArray , comparison_op ):
431
472
arr = np .linspace (- 4 , 5 , 10 )
432
- self ._compare_other (data_for_compare , comparison_op , arr )
473
+ ser = pd .Series (data_for_compare )
474
+ self ._compare_other (ser , data_for_compare , comparison_op , arr )
433
475
434
- @pytest .mark .xfail (reason = "Wrong indices" )
435
- def test_sparse_array (self , data_for_compare : SparseArray , comparison_op ):
476
+ def test_sparse_array (self , data_for_compare : SparseArray , comparison_op , request ):
477
+ if data_for_compare .dtype .fill_value == 0 and comparison_op .__name__ != "gt" :
478
+ mark = pytest .mark .xfail (reason = "Wrong fill_value" )
479
+ request .applymarker (mark )
480
+
481
+ ser = pd .Series (data_for_compare )
436
482
arr = data_for_compare + 1
437
- self ._compare_other (data_for_compare , comparison_op , arr )
483
+ self ._compare_other (ser , data_for_compare , comparison_op , arr )
438
484
arr = data_for_compare * 2
439
- self ._compare_other (data_for_compare , comparison_op , arr )
485
+ self ._compare_other (ser , data_for_compare , comparison_op , arr )
440
486
441
-
442
- class TestPrinting (BaseSparseTests , base .BasePrintingTests ):
443
487
@pytest .mark .xfail (reason = "Different repr" )
444
488
def test_array_repr (self , data , size ):
445
489
super ().test_array_repr (data , size )
446
490
447
-
448
- class TestParsing (BaseSparseTests , base .BaseParsingTests ):
449
- pass
491
+ @pytest .mark .xfail (reason = "result does not match expected" )
492
+ @pytest .mark .parametrize ("as_index" , [True , False ])
493
+ def test_groupby_extension_agg (self , as_index , data_for_grouping ):
494
+ super ().test_groupby_extension_agg (as_index , data_for_grouping )
450
495
451
496
452
- class TestNoNumericAccumulations ( base . BaseAccumulateTests ):
453
- pass
497
+ def test_array_type_with_arg ( dtype ):
498
+ assert dtype . construct_array_type () is SparseArray
0 commit comments