1
1
""" define the IntervalIndex """
2
2
from __future__ import annotations
3
3
4
- from functools import wraps
5
4
from operator import (
6
5
le ,
7
6
lt ,
63
62
)
64
63
from pandas .core .dtypes .dtypes import IntervalDtype
65
64
66
- from pandas .core .algorithms import (
67
- take_nd ,
68
- unique ,
69
- )
65
+ from pandas .core .algorithms import take_nd
70
66
from pandas .core .arrays .interval import (
71
67
IntervalArray ,
72
68
_interval_shared_docs ,
93
89
TimedeltaIndex ,
94
90
timedelta_range ,
95
91
)
96
- from pandas .core .ops import get_op_result_name
97
92
98
93
if TYPE_CHECKING :
99
94
from pandas import CategoricalIndex
@@ -151,59 +146,6 @@ def _new_IntervalIndex(cls, d):
151
146
return cls .from_arrays (** d )
152
147
153
148
154
- def setop_check (method ):
155
- """
156
- This is called to decorate the set operations of IntervalIndex
157
- to perform the type check in advance.
158
- """
159
- op_name = method .__name__
160
-
161
- @wraps (method )
162
- def wrapped (self , other , sort = False ):
163
- self ._validate_sort_keyword (sort )
164
- self ._assert_can_do_setop (other )
165
- other , result_name = self ._convert_can_do_setop (other )
166
-
167
- if op_name == "difference" :
168
- if not isinstance (other , IntervalIndex ):
169
- result = getattr (self .astype (object ), op_name )(other , sort = sort )
170
- return result .astype (self .dtype )
171
-
172
- elif not self ._should_compare (other ):
173
- # GH#19016: ensure set op will not return a prohibited dtype
174
- result = getattr (self .astype (object ), op_name )(other , sort = sort )
175
- return result .astype (self .dtype )
176
-
177
- return method (self , other , sort )
178
-
179
- return wrapped
180
-
181
-
182
- def _setop (op_name : str ):
183
- """
184
- Implement set operation.
185
- """
186
-
187
- def func (self , other , sort = None ):
188
- # At this point we are assured
189
- # isinstance(other, IntervalIndex)
190
- # other.closed == self.closed
191
-
192
- result = getattr (self ._multiindex , op_name )(other ._multiindex , sort = sort )
193
- result_name = get_op_result_name (self , other )
194
-
195
- # GH 19101: ensure empty results have correct dtype
196
- if result .empty :
197
- result = result ._values .astype (self .dtype .subtype )
198
- else :
199
- result = result ._values
200
-
201
- return type (self ).from_tuples (result , closed = self .closed , name = result_name )
202
-
203
- func .__name__ = op_name
204
- return setop_check (func )
205
-
206
-
207
149
@Appender (
208
150
_interval_shared_docs ["class" ]
209
151
% {
@@ -859,11 +801,11 @@ def _intersection(self, other, sort):
859
801
"""
860
802
# For IntervalIndex we also know other.closed == self.closed
861
803
if self .left .is_unique and self .right .is_unique :
862
- taken = self . _intersection_unique (other )
804
+ return super (). _intersection (other , sort = sort )
863
805
elif other .left .is_unique and other .right .is_unique and self .isna ().sum () <= 1 :
864
806
# Swap other/self if other is unique and self does not have
865
807
# multiple NaNs
866
- taken = other . _intersection_unique ( self )
808
+ return super (). _intersection ( other , sort = sort )
867
809
else :
868
810
# duplicates
869
811
taken = self ._intersection_non_unique (other )
@@ -873,29 +815,6 @@ def _intersection(self, other, sort):
873
815
874
816
return taken
875
817
876
- def _intersection_unique (self , other : IntervalIndex ) -> IntervalIndex :
877
- """
878
- Used when the IntervalIndex does not have any common endpoint,
879
- no matter left or right.
880
- Return the intersection with another IntervalIndex.
881
-
882
- Parameters
883
- ----------
884
- other : IntervalIndex
885
-
886
- Returns
887
- -------
888
- IntervalIndex
889
- """
890
- lindexer = self .left .get_indexer (other .left )
891
- rindexer = self .right .get_indexer (other .right )
892
-
893
- match = (lindexer == rindexer ) & (lindexer != - 1 )
894
- indexer = lindexer .take (match .nonzero ()[0 ])
895
- indexer = unique (indexer )
896
-
897
- return self .take (indexer )
898
-
899
818
def _intersection_non_unique (self , other : IntervalIndex ) -> IntervalIndex :
900
819
"""
901
820
Used when the IntervalIndex does have some common endpoints,
@@ -923,9 +842,6 @@ def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:
923
842
924
843
return self [mask ]
925
844
926
- _union = _setop ("union" )
927
- _difference = _setop ("difference" )
928
-
929
845
# --------------------------------------------------------------------
930
846
931
847
@property
0 commit comments