@@ -697,12 +697,28 @@ def _assert_can_do_setop(self, other):
697
697
if isinstance (other , PeriodIndex ) and self .freq != other .freq :
698
698
raise raise_on_incompatible (self , other )
699
699
700
- def intersection (self , other , sort = False ):
700
+ def _setop (self , other , sort , opname : str ):
701
+ """
702
+ Perform a set operation by dispatching to the Int64Index implementation.
703
+ """
701
704
self ._validate_sort_keyword (sort )
702
705
self ._assert_can_do_setop (other )
703
706
res_name = get_op_result_name (self , other )
704
707
other = ensure_index (other )
705
708
709
+ i8self = Int64Index ._simple_new (self .asi8 )
710
+ i8other = Int64Index ._simple_new (other .asi8 )
711
+ i8result = getattr (i8self , opname )(i8other , sort = sort )
712
+
713
+ parr = type (self ._data )(np .asarray (i8result , dtype = np .int64 ), dtype = self .dtype )
714
+ result = type (self )._simple_new (parr , name = res_name )
715
+ return result
716
+
717
+ def intersection (self , other , sort = False ):
718
+ self ._validate_sort_keyword (sort )
719
+ self ._assert_can_do_setop (other )
720
+ other = ensure_index (other )
721
+
706
722
if self .equals (other ):
707
723
return self ._get_reconciled_name_object (other )
708
724
@@ -712,35 +728,24 @@ def intersection(self, other, sort=False):
712
728
other = other .astype ("O" )
713
729
return this .intersection (other , sort = sort )
714
730
715
- i8self = Int64Index ._simple_new (self .asi8 )
716
- i8other = Int64Index ._simple_new (other .asi8 )
717
- i8result = i8self .intersection (i8other , sort = sort )
718
-
719
- result = self ._shallow_copy (np .asarray (i8result , dtype = np .int64 ), name = res_name )
720
- return result
731
+ return self ._setop (other , sort , opname = "intersection" )
721
732
722
733
def difference (self , other , sort = None ):
723
734
self ._validate_sort_keyword (sort )
724
735
self ._assert_can_do_setop (other )
725
- res_name = get_op_result_name (self , other )
726
736
other = ensure_index (other )
727
737
728
738
if self .equals (other ):
729
739
# pass an empty PeriodArray with the appropriate dtype
730
- return self . _shallow_copy (self ._data [:0 ])
740
+ return type ( self ). _simple_new (self ._data [:0 ], name = self . name )
731
741
732
742
if is_object_dtype (other ):
733
743
return self .astype (object ).difference (other ).astype (self .dtype )
734
744
735
745
elif not is_dtype_equal (self .dtype , other .dtype ):
736
746
return self
737
747
738
- i8self = Int64Index ._simple_new (self .asi8 )
739
- i8other = Int64Index ._simple_new (other .asi8 )
740
- i8result = i8self .difference (i8other , sort = sort )
741
-
742
- result = self ._shallow_copy (np .asarray (i8result , dtype = np .int64 ), name = res_name )
743
- return result
748
+ return self ._setop (other , sort , opname = "difference" )
744
749
745
750
def _union (self , other , sort ):
746
751
if not len (other ) or self .equals (other ) or not len (self ):
@@ -754,13 +759,7 @@ def _union(self, other, sort):
754
759
other = other .astype ("O" )
755
760
return this ._union (other , sort = sort )
756
761
757
- i8self = Int64Index ._simple_new (self .asi8 )
758
- i8other = Int64Index ._simple_new (other .asi8 )
759
- i8result = i8self ._union (i8other , sort = sort )
760
-
761
- res_name = get_op_result_name (self , other )
762
- result = self ._shallow_copy (np .asarray (i8result , dtype = np .int64 ), name = res_name )
763
- return result
762
+ return self ._setop (other , sort , opname = "_union" )
764
763
765
764
# ------------------------------------------------------------------------
766
765
0 commit comments