diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 93117fbc22752..cc47740dba5f2 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1,4 +1,5 @@ """ define the IntervalIndex """ +from functools import wraps from operator import le, lt import textwrap from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast @@ -112,43 +113,41 @@ def _new_IntervalIndex(cls, d): return cls.from_arrays(**d) -class SetopCheck: +def setop_check(method): """ This is called to decorate the set operations of IntervalIndex to perform the type check in advance. """ + op_name = method.__name__ - def __init__(self, op_name): - self.op_name = op_name - - def __call__(self, setop): - def func(intvidx_self, other, sort=False): - intvidx_self._assert_can_do_setop(other) - other = ensure_index(other) - - if not isinstance(other, IntervalIndex): - result = getattr(intvidx_self.astype(object), self.op_name)(other) - if self.op_name in ("difference",): - result = result.astype(intvidx_self.dtype) - return result - elif intvidx_self.closed != other.closed: - raise ValueError( - "can only do set operations between two IntervalIndex " - "objects that are closed on the same side" - ) + @wraps(method) + def wrapped(self, other, sort=False): + self._assert_can_do_setop(other) + other = ensure_index(other) - # GH 19016: ensure set op will not return a prohibited dtype - subtypes = [intvidx_self.dtype.subtype, other.dtype.subtype] - common_subtype = find_common_type(subtypes) - if is_object_dtype(common_subtype): - raise TypeError( - f"can only do {self.op_name} between two IntervalIndex " - "objects that have compatible dtypes" - ) + if not isinstance(other, IntervalIndex): + result = getattr(self.astype(object), op_name)(other) + if op_name in ("difference",): + result = result.astype(self.dtype) + return result + elif self.closed != other.closed: + raise ValueError( + "can only do set operations between two IntervalIndex " + "objects that are closed on the same side" + ) + + # GH 19016: ensure set op will not return a prohibited dtype + subtypes = [self.dtype.subtype, other.dtype.subtype] + common_subtype = find_common_type(subtypes) + if is_object_dtype(common_subtype): + raise TypeError( + f"can only do {op_name} between two IntervalIndex " + "objects that have compatible dtypes" + ) - return setop(intvidx_self, other, sort) + return method(self, other, sort) - return func + return wrapped @Appender( @@ -1006,7 +1005,7 @@ def equals(self, other: object) -> bool: # Set Operations @Appender(Index.intersection.__doc__) - @SetopCheck(op_name="intersection") + @setop_check def intersection( self, other: "IntervalIndex", sort: bool = False ) -> "IntervalIndex": @@ -1075,7 +1074,6 @@ def _intersection_non_unique(self, other: "IntervalIndex") -> "IntervalIndex": return self[mask] def _setop(op_name: str, sort=None): - @SetopCheck(op_name=op_name) def func(self, other, sort=sort): result = getattr(self._multiindex, op_name)(other._multiindex, sort=sort) result_name = get_op_result_name(self, other) @@ -1088,7 +1086,8 @@ def func(self, other, sort=sort): return type(self).from_tuples(result, closed=self.closed, name=result_name) - return func + func.__name__ = op_name + return setop_check(func) union = _setop("union") difference = _setop("difference")