@@ -265,7 +265,7 @@ def data_for_twos(data):
265
265
# TODO: skip otherwise?
266
266
267
267
268
- class TestBaseCasting (base .BaseCastingTests ):
268
+ class TestArrowArray (base .ExtensionTests ):
269
269
def test_astype_str (self , data , request ):
270
270
pa_dtype = data .dtype .pyarrow_dtype
271
271
if pa .types .is_binary (pa_dtype ):
@@ -276,8 +276,6 @@ def test_astype_str(self, data, request):
276
276
)
277
277
super ().test_astype_str (data )
278
278
279
-
280
- class TestConstructors (base .BaseConstructorsTests ):
281
279
def test_from_dtype (self , data , request ):
282
280
pa_dtype = data .dtype .pyarrow_dtype
283
281
if pa .types .is_string (pa_dtype ) or pa .types .is_decimal (pa_dtype ):
@@ -338,12 +336,6 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
338
336
result = type (data )._from_sequence_of_strings (pa_array , dtype = data .dtype )
339
337
tm .assert_extension_array_equal (result , data )
340
338
341
-
342
- class TestGetitemTests (base .BaseGetitemTests ):
343
- pass
344
-
345
-
346
- class TestBaseAccumulateTests (base .BaseAccumulateTests ):
347
339
def check_accumulate (self , ser , op_name , skipna ):
348
340
result = getattr (ser , op_name )(skipna = skipna )
349
341
@@ -409,8 +401,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
409
401
410
402
self .check_accumulate (ser , op_name , skipna )
411
403
412
-
413
- class TestReduce (base .BaseReduceTests ):
414
404
def _supports_reduction (self , obj , op_name : str ) -> bool :
415
405
dtype = tm .get_dtype (obj )
416
406
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has
@@ -561,8 +551,6 @@ def test_median_not_approximate(self, typ):
561
551
result = pd .Series ([1 , 2 ], dtype = f"{ typ } [pyarrow]" ).median ()
562
552
assert result == 1.5
563
553
564
-
565
- class TestBaseGroupby (base .BaseGroupbyTests ):
566
554
def test_in_numeric_groupby (self , data_for_grouping ):
567
555
dtype = data_for_grouping .dtype
568
556
if is_string_dtype (dtype ):
@@ -583,8 +571,6 @@ def test_in_numeric_groupby(self, data_for_grouping):
583
571
else :
584
572
super ().test_in_numeric_groupby (data_for_grouping )
585
573
586
-
587
- class TestBaseDtype (base .BaseDtypeTests ):
588
574
def test_construct_from_string_own_name (self , dtype , request ):
589
575
pa_dtype = dtype .pyarrow_dtype
590
576
if pa .types .is_decimal (pa_dtype ):
@@ -651,20 +637,12 @@ def test_is_not_string_type(self, dtype):
651
637
else :
652
638
super ().test_is_not_string_type (dtype )
653
639
654
-
655
- class TestBaseIndex (base .BaseIndexTests ):
656
- pass
657
-
658
-
659
- class TestBaseInterface (base .BaseInterfaceTests ):
660
640
@pytest .mark .xfail (
661
641
reason = "GH 45419: pyarrow.ChunkedArray does not support views." , run = False
662
642
)
663
643
def test_view (self , data ):
664
644
super ().test_view (data )
665
645
666
-
667
- class TestBaseMissing (base .BaseMissingTests ):
668
646
def test_fillna_no_op_returns_copy (self , data ):
669
647
data = data [~ data .isna ()]
670
648
@@ -677,28 +655,18 @@ def test_fillna_no_op_returns_copy(self, data):
677
655
assert result is not data
678
656
tm .assert_extension_array_equal (result , data )
679
657
680
-
681
- class TestBasePrinting (base .BasePrintingTests ):
682
- pass
683
-
684
-
685
- class TestBaseReshaping (base .BaseReshapingTests ):
686
658
@pytest .mark .xfail (
687
659
reason = "GH 45419: pyarrow.ChunkedArray does not support views" , run = False
688
660
)
689
661
def test_transpose (self , data ):
690
662
super ().test_transpose (data )
691
663
692
-
693
- class TestBaseSetitem (base .BaseSetitemTests ):
694
664
@pytest .mark .xfail (
695
665
reason = "GH 45419: pyarrow.ChunkedArray does not support views" , run = False
696
666
)
697
667
def test_setitem_preserves_views (self , data ):
698
668
super ().test_setitem_preserves_views (data )
699
669
700
-
701
- class TestBaseParsing (base .BaseParsingTests ):
702
670
@pytest .mark .parametrize ("dtype_backend" , ["pyarrow" , no_default ])
703
671
@pytest .mark .parametrize ("engine" , ["c" , "python" ])
704
672
def test_EA_types (self , engine , data , dtype_backend , request ):
@@ -736,8 +704,6 @@ def test_EA_types(self, engine, data, dtype_backend, request):
736
704
expected = df
737
705
tm .assert_frame_equal (result , expected )
738
706
739
-
740
- class TestBaseUnaryOps (base .BaseUnaryOpsTests ):
741
707
def test_invert (self , data , request ):
742
708
pa_dtype = data .dtype .pyarrow_dtype
743
709
if not pa .types .is_boolean (pa_dtype ):
@@ -749,8 +715,6 @@ def test_invert(self, data, request):
749
715
)
750
716
super ().test_invert (data )
751
717
752
-
753
- class TestBaseMethods (base .BaseMethodsTests ):
754
718
@pytest .mark .parametrize ("periods" , [1 , - 2 ])
755
719
def test_diff (self , data , periods , request ):
756
720
pa_dtype = data .dtype .pyarrow_dtype
@@ -814,8 +778,6 @@ def test_argreduce_series(
814
778
815
779
_combine_le_expected_dtype = "bool[pyarrow]"
816
780
817
-
818
- class TestBaseArithmeticOps (base .BaseArithmeticOpsTests ):
819
781
divmod_exc = NotImplementedError
820
782
821
783
def get_op_from_name (self , op_name ):
@@ -838,6 +800,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
838
800
# while ArrowExtensionArray maintains original type
839
801
expected = pointwise_result
840
802
803
+ if op_name in ["eq" , "ne" , "lt" , "le" , "gt" , "ge" ]:
804
+ return pointwise_result .astype ("boolean[pyarrow]" )
805
+
841
806
was_frame = False
842
807
if isinstance (expected , pd .DataFrame ):
843
808
was_frame = True
@@ -1121,28 +1086,6 @@ def test_add_series_with_extension_array(self, data, request):
1121
1086
)
1122
1087
super ().test_add_series_with_extension_array (data )
1123
1088
1124
-
1125
- class TestBaseComparisonOps (base .BaseComparisonOpsTests ):
1126
- def test_compare_array (self , data , comparison_op , na_value ):
1127
- ser = pd .Series (data )
1128
- # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
1129
- # since ser.iloc[0] is a python scalar
1130
- other = pd .Series (pd .array ([ser .iloc [0 ]] * len (ser ), dtype = data .dtype ))
1131
- if comparison_op .__name__ in ["eq" , "ne" ]:
1132
- # comparison should match point-wise comparisons
1133
- result = comparison_op (ser , other )
1134
- # Series.combine does not calculate the NA mask correctly
1135
- # when comparing over an array
1136
- assert result [8 ] is na_value
1137
- assert result [97 ] is na_value
1138
- expected = ser .combine (other , comparison_op )
1139
- expected [8 ] = na_value
1140
- expected [97 ] = na_value
1141
- tm .assert_series_equal (result , expected )
1142
-
1143
- else :
1144
- return super ().test_compare_array (data , comparison_op )
1145
-
1146
1089
def test_invalid_other_comp (self , data , comparison_op ):
1147
1090
# GH 48833
1148
1091
with pytest .raises (
0 commit comments