@@ -795,6 +795,140 @@ def test_non_contiguous(self, closed):
795
795
796
796
assert 1.5 not in index
797
797
798
+ @pytest .mark .parametrize ("sort" , [None , False ])
799
+ def test_union (self , closed , sort ):
800
+ index = self .create_index (closed = closed )
801
+ other = IntervalIndex .from_breaks (range (5 , 13 ), closed = closed )
802
+
803
+ expected = IntervalIndex .from_breaks (range (13 ), closed = closed )
804
+ result = index [::- 1 ].union (other , sort = sort )
805
+ if sort is None :
806
+ tm .assert_index_equal (result , expected )
807
+ assert tm .equalContents (result , expected )
808
+
809
+ result = other [::- 1 ].union (index , sort = sort )
810
+ if sort is None :
811
+ tm .assert_index_equal (result , expected )
812
+ assert tm .equalContents (result , expected )
813
+
814
+ tm .assert_index_equal (index .union (index , sort = sort ), index )
815
+ tm .assert_index_equal (index .union (index [:1 ], sort = sort ), index )
816
+
817
+ # GH 19101: empty result, same dtype
818
+ index = IntervalIndex (np .array ([], dtype = 'int64' ), closed = closed )
819
+ result = index .union (index , sort = sort )
820
+ tm .assert_index_equal (result , index )
821
+
822
+ # GH 19101: empty result, different dtypes
823
+ other = IntervalIndex (np .array ([], dtype = 'float64' ), closed = closed )
824
+ result = index .union (other , sort = sort )
825
+ tm .assert_index_equal (result , index )
826
+
827
+ @pytest .mark .parametrize ("sort" , [None , False ])
828
+ def test_intersection (self , closed , sort ):
829
+ index = self .create_index (closed = closed )
830
+ other = IntervalIndex .from_breaks (range (5 , 13 ), closed = closed )
831
+
832
+ expected = IntervalIndex .from_breaks (range (5 , 11 ), closed = closed )
833
+ result = index [::- 1 ].intersection (other , sort = sort )
834
+ if sort is None :
835
+ tm .assert_index_equal (result , expected )
836
+ assert tm .equalContents (result , expected )
837
+
838
+ result = other [::- 1 ].intersection (index , sort = sort )
839
+ if sort is None :
840
+ tm .assert_index_equal (result , expected )
841
+ assert tm .equalContents (result , expected )
842
+
843
+ tm .assert_index_equal (index .intersection (index , sort = sort ), index )
844
+
845
+ # GH 19101: empty result, same dtype
846
+ other = IntervalIndex .from_breaks (range (300 , 314 ), closed = closed )
847
+ expected = IntervalIndex (np .array ([], dtype = 'int64' ), closed = closed )
848
+ result = index .intersection (other , sort = sort )
849
+ tm .assert_index_equal (result , expected )
850
+
851
+ # GH 19101: empty result, different dtypes
852
+ breaks = np .arange (300 , 314 , dtype = 'float64' )
853
+ other = IntervalIndex .from_breaks (breaks , closed = closed )
854
+ result = index .intersection (other , sort = sort )
855
+ tm .assert_index_equal (result , expected )
856
+
857
+ @pytest .mark .parametrize ("sort" , [None , False ])
858
+ def test_difference (self , closed , sort ):
859
+ index = IntervalIndex .from_arrays ([1 , 0 , 3 , 2 ],
860
+ [1 , 2 , 3 , 4 ],
861
+ closed = closed )
862
+ result = index .difference (index [:1 ], sort = sort )
863
+ expected = index [1 :]
864
+ if sort is None :
865
+ expected = expected .sort_values ()
866
+ tm .assert_index_equal (result , expected )
867
+
868
+ # GH 19101: empty result, same dtype
869
+ result = index .difference (index , sort = sort )
870
+ expected = IntervalIndex (np .array ([], dtype = 'int64' ), closed = closed )
871
+ tm .assert_index_equal (result , expected )
872
+
873
+ # GH 19101: empty result, different dtypes
874
+ other = IntervalIndex .from_arrays (index .left .astype ('float64' ),
875
+ index .right , closed = closed )
876
+ result = index .difference (other , sort = sort )
877
+ tm .assert_index_equal (result , expected )
878
+
879
+ @pytest .mark .parametrize ("sort" , [None , False ])
880
+ def test_symmetric_difference (self , closed , sort ):
881
+ index = self .create_index (closed = closed )
882
+ result = index [1 :].symmetric_difference (index [:- 1 ], sort = sort )
883
+ expected = IntervalIndex ([index [0 ], index [- 1 ]])
884
+ if sort is None :
885
+ tm .assert_index_equal (result , expected )
886
+ assert tm .equalContents (result , expected )
887
+
888
+ # GH 19101: empty result, same dtype
889
+ result = index .symmetric_difference (index , sort = sort )
890
+ expected = IntervalIndex (np .array ([], dtype = 'int64' ), closed = closed )
891
+ if sort is None :
892
+ tm .assert_index_equal (result , expected )
893
+ assert tm .equalContents (result , expected )
894
+
895
+ # GH 19101: empty result, different dtypes
896
+ other = IntervalIndex .from_arrays (index .left .astype ('float64' ),
897
+ index .right , closed = closed )
898
+ result = index .symmetric_difference (other , sort = sort )
899
+ tm .assert_index_equal (result , expected )
900
+
901
+ @pytest .mark .parametrize ('op_name' , [
902
+ 'union' , 'intersection' , 'difference' , 'symmetric_difference' ])
903
+ @pytest .mark .parametrize ("sort" , [None , False ])
904
+ def test_set_incompatible_types (self , closed , op_name , sort ):
905
+ index = self .create_index (closed = closed )
906
+ set_op = getattr (index , op_name )
907
+
908
+ # TODO: standardize return type of non-union setops type(self vs other)
909
+ # non-IntervalIndex
910
+ if op_name == 'difference' :
911
+ expected = index
912
+ else :
913
+ expected = getattr (index .astype ('O' ), op_name )(Index ([1 , 2 , 3 ]))
914
+ result = set_op (Index ([1 , 2 , 3 ]), sort = sort )
915
+ tm .assert_index_equal (result , expected )
916
+
917
+ # mixed closed
918
+ msg = ('can only do set operations between two IntervalIndex objects '
919
+ 'that are closed on the same side' )
920
+ for other_closed in {'right' , 'left' , 'both' , 'neither' } - {closed }:
921
+ other = self .create_index (closed = other_closed )
922
+ with pytest .raises (ValueError , match = msg ):
923
+ set_op (other , sort = sort )
924
+
925
+ # GH 19016: incompatible dtypes
926
+ other = interval_range (Timestamp ('20180101' ), periods = 9 , closed = closed )
927
+ msg = ('can only do {op} between two IntervalIndex objects that have '
928
+ 'compatible dtypes' ).format (op = op_name )
929
+ with pytest .raises (TypeError , match = msg ):
930
+ set_op (other , sort = sort )
931
+
798
932
def test_isin (self , closed ):
799
933
index = self .create_index (closed = closed )
800
934
0 commit comments