1
1
""" define the IntervalIndex """
2
+ from functools import wraps
2
3
from operator import le , lt
3
4
import textwrap
4
5
from typing import TYPE_CHECKING , Any , List , Optional , Tuple , Union , cast
@@ -112,43 +113,41 @@ def _new_IntervalIndex(cls, d):
112
113
return cls .from_arrays (** d )
113
114
114
115
115
- class SetopCheck :
116
+ def setop_check ( method ) :
116
117
"""
117
118
This is called to decorate the set operations of IntervalIndex
118
119
to perform the type check in advance.
119
120
"""
121
+ op_name = method .__name__
120
122
121
- def __init__ (self , op_name ):
122
- self .op_name = op_name
123
-
124
- def __call__ (self , setop ):
125
- def func (intvidx_self , other , sort = False ):
126
- intvidx_self ._assert_can_do_setop (other )
127
- other = ensure_index (other )
128
-
129
- if not isinstance (other , IntervalIndex ):
130
- result = getattr (intvidx_self .astype (object ), self .op_name )(other )
131
- if self .op_name in ("difference" ,):
132
- result = result .astype (intvidx_self .dtype )
133
- return result
134
- elif intvidx_self .closed != other .closed :
135
- raise ValueError (
136
- "can only do set operations between two IntervalIndex "
137
- "objects that are closed on the same side"
138
- )
123
+ @wraps (method )
124
+ def wrapped (self , other , sort = False ):
125
+ self ._assert_can_do_setop (other )
126
+ other = ensure_index (other )
139
127
140
- # GH 19016: ensure set op will not return a prohibited dtype
141
- subtypes = [intvidx_self .dtype .subtype , other .dtype .subtype ]
142
- common_subtype = find_common_type (subtypes )
143
- if is_object_dtype (common_subtype ):
144
- raise TypeError (
145
- f"can only do { self .op_name } between two IntervalIndex "
146
- "objects that have compatible dtypes"
147
- )
128
+ if not isinstance (other , IntervalIndex ):
129
+ result = getattr (self .astype (object ), op_name )(other )
130
+ if op_name in ("difference" ,):
131
+ result = result .astype (self .dtype )
132
+ return result
133
+ elif self .closed != other .closed :
134
+ raise ValueError (
135
+ "can only do set operations between two IntervalIndex "
136
+ "objects that are closed on the same side"
137
+ )
138
+
139
+ # GH 19016: ensure set op will not return a prohibited dtype
140
+ subtypes = [self .dtype .subtype , other .dtype .subtype ]
141
+ common_subtype = find_common_type (subtypes )
142
+ if is_object_dtype (common_subtype ):
143
+ raise TypeError (
144
+ f"can only do { op_name } between two IntervalIndex "
145
+ "objects that have compatible dtypes"
146
+ )
148
147
149
- return setop ( intvidx_self , other , sort )
148
+ return method ( self , other , sort )
150
149
151
- return func
150
+ return wrapped
152
151
153
152
154
153
@Appender (
@@ -1006,7 +1005,7 @@ def equals(self, other: object) -> bool:
1006
1005
# Set Operations
1007
1006
1008
1007
@Appender (Index .intersection .__doc__ )
1009
- @SetopCheck ( op_name = "intersection" )
1008
+ @setop_check
1010
1009
def intersection (
1011
1010
self , other : "IntervalIndex" , sort : bool = False
1012
1011
) -> "IntervalIndex" :
@@ -1075,7 +1074,6 @@ def _intersection_non_unique(self, other: "IntervalIndex") -> "IntervalIndex":
1075
1074
return self [mask ]
1076
1075
1077
1076
def _setop (op_name : str , sort = None ):
1078
- @SetopCheck (op_name = op_name )
1079
1077
def func (self , other , sort = sort ):
1080
1078
result = getattr (self ._multiindex , op_name )(other ._multiindex , sort = sort )
1081
1079
result_name = get_op_result_name (self , other )
@@ -1088,7 +1086,8 @@ def func(self, other, sort=sort):
1088
1086
1089
1087
return type (self ).from_tuples (result , closed = self .closed , name = result_name )
1090
1088
1091
- return func
1089
+ func .__name__ = op_name
1090
+ return setop_check (func )
1092
1091
1093
1092
union = _setop ("union" )
1094
1093
difference = _setop ("difference" )
0 commit comments