Skip to content

Commit e9ac7af

Browse files
authored
TYP: IntervalIndex.SetopCheck (#36995)
1 parent fcbdb7d commit e9ac7af

File tree

1 file changed

+31
-32
lines changed

1 file changed

+31
-32
lines changed

pandas/core/indexes/interval.py

+31-32
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
""" define the IntervalIndex """
2+
from functools import wraps
23
from operator import le, lt
34
import textwrap
45
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast
@@ -112,43 +113,41 @@ def _new_IntervalIndex(cls, d):
112113
return cls.from_arrays(**d)
113114

114115

115-
class SetopCheck:
116+
def setop_check(method):
116117
"""
117118
This is called to decorate the set operations of IntervalIndex
118119
to perform the type check in advance.
119120
"""
121+
op_name = method.__name__
120122

121-
def __init__(self, op_name):
122-
self.op_name = op_name
123-
124-
def __call__(self, setop):
125-
def func(intvidx_self, other, sort=False):
126-
intvidx_self._assert_can_do_setop(other)
127-
other = ensure_index(other)
128-
129-
if not isinstance(other, IntervalIndex):
130-
result = getattr(intvidx_self.astype(object), self.op_name)(other)
131-
if self.op_name in ("difference",):
132-
result = result.astype(intvidx_self.dtype)
133-
return result
134-
elif intvidx_self.closed != other.closed:
135-
raise ValueError(
136-
"can only do set operations between two IntervalIndex "
137-
"objects that are closed on the same side"
138-
)
123+
@wraps(method)
124+
def wrapped(self, other, sort=False):
125+
self._assert_can_do_setop(other)
126+
other = ensure_index(other)
139127

140-
# GH 19016: ensure set op will not return a prohibited dtype
141-
subtypes = [intvidx_self.dtype.subtype, other.dtype.subtype]
142-
common_subtype = find_common_type(subtypes)
143-
if is_object_dtype(common_subtype):
144-
raise TypeError(
145-
f"can only do {self.op_name} between two IntervalIndex "
146-
"objects that have compatible dtypes"
147-
)
128+
if not isinstance(other, IntervalIndex):
129+
result = getattr(self.astype(object), op_name)(other)
130+
if op_name in ("difference",):
131+
result = result.astype(self.dtype)
132+
return result
133+
elif self.closed != other.closed:
134+
raise ValueError(
135+
"can only do set operations between two IntervalIndex "
136+
"objects that are closed on the same side"
137+
)
138+
139+
# GH 19016: ensure set op will not return a prohibited dtype
140+
subtypes = [self.dtype.subtype, other.dtype.subtype]
141+
common_subtype = find_common_type(subtypes)
142+
if is_object_dtype(common_subtype):
143+
raise TypeError(
144+
f"can only do {op_name} between two IntervalIndex "
145+
"objects that have compatible dtypes"
146+
)
148147

149-
return setop(intvidx_self, other, sort)
148+
return method(self, other, sort)
150149

151-
return func
150+
return wrapped
152151

153152

154153
@Appender(
@@ -1006,7 +1005,7 @@ def equals(self, other: object) -> bool:
10061005
# Set Operations
10071006

10081007
@Appender(Index.intersection.__doc__)
1009-
@SetopCheck(op_name="intersection")
1008+
@setop_check
10101009
def intersection(
10111010
self, other: "IntervalIndex", sort: bool = False
10121011
) -> "IntervalIndex":
@@ -1075,7 +1074,6 @@ def _intersection_non_unique(self, other: "IntervalIndex") -> "IntervalIndex":
10751074
return self[mask]
10761075

10771076
def _setop(op_name: str, sort=None):
1078-
@SetopCheck(op_name=op_name)
10791077
def func(self, other, sort=sort):
10801078
result = getattr(self._multiindex, op_name)(other._multiindex, sort=sort)
10811079
result_name = get_op_result_name(self, other)
@@ -1088,7 +1086,8 @@ def func(self, other, sort=sort):
10881086

10891087
return type(self).from_tuples(result, closed=self.closed, name=result_name)
10901088

1091-
return func
1089+
func.__name__ = op_name
1090+
return setop_check(func)
10921091

10931092
union = _setop("union")
10941093
difference = _setop("difference")

0 commit comments

Comments
 (0)