Skip to content

Commit 94756fe

Browse files
authored
REF: de-duplicate IntervalIndex setops (#41832)
1 parent e34338a commit 94756fe

File tree

2 files changed

+9
-89
lines changed

2 files changed

+9
-89
lines changed

pandas/core/indexes/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -3013,7 +3013,7 @@ def _union(self, other: Index, sort):
30133013

30143014
# Self may have duplicates
30153015
# find indexes of things in "other" that are not in "self"
3016-
if self.is_unique:
3016+
if self._index_as_unique:
30173017
indexer = self.get_indexer(other)
30183018
missing = (indexer == -1).nonzero()[0]
30193019
else:
@@ -3171,14 +3171,18 @@ def difference(self, other, sort=None):
31713171
# Note: we do not (yet) sort even if sort=None GH#24959
31723172
return self.rename(result_name)
31733173

3174+
if not self._should_compare(other):
3175+
# Nothing matches -> difference is everything
3176+
return self.rename(result_name)
3177+
31743178
result = self._difference(other, sort=sort)
31753179
return self._wrap_setop_result(other, result)
31763180

31773181
def _difference(self, other, sort):
31783182

31793183
this = self._get_unique_index()
31803184

3181-
indexer = this.get_indexer(other)
3185+
indexer = this.get_indexer_for(other)
31823186
indexer = indexer.take((indexer != -1).nonzero()[0])
31833187

31843188
label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)

pandas/core/indexes/interval.py

+3-87
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
""" define the IntervalIndex """
22
from __future__ import annotations
33

4-
from functools import wraps
54
from operator import (
65
le,
76
lt,
@@ -63,10 +62,7 @@
6362
)
6463
from pandas.core.dtypes.dtypes import IntervalDtype
6564

66-
from pandas.core.algorithms import (
67-
take_nd,
68-
unique,
69-
)
65+
from pandas.core.algorithms import take_nd
7066
from pandas.core.arrays.interval import (
7167
IntervalArray,
7268
_interval_shared_docs,
@@ -93,7 +89,6 @@
9389
TimedeltaIndex,
9490
timedelta_range,
9591
)
96-
from pandas.core.ops import get_op_result_name
9792

9893
if TYPE_CHECKING:
9994
from pandas import CategoricalIndex
@@ -151,59 +146,6 @@ def _new_IntervalIndex(cls, d):
151146
return cls.from_arrays(**d)
152147

153148

154-
def setop_check(method):
155-
"""
156-
This is called to decorate the set operations of IntervalIndex
157-
to perform the type check in advance.
158-
"""
159-
op_name = method.__name__
160-
161-
@wraps(method)
162-
def wrapped(self, other, sort=False):
163-
self._validate_sort_keyword(sort)
164-
self._assert_can_do_setop(other)
165-
other, result_name = self._convert_can_do_setop(other)
166-
167-
if op_name == "difference":
168-
if not isinstance(other, IntervalIndex):
169-
result = getattr(self.astype(object), op_name)(other, sort=sort)
170-
return result.astype(self.dtype)
171-
172-
elif not self._should_compare(other):
173-
# GH#19016: ensure set op will not return a prohibited dtype
174-
result = getattr(self.astype(object), op_name)(other, sort=sort)
175-
return result.astype(self.dtype)
176-
177-
return method(self, other, sort)
178-
179-
return wrapped
180-
181-
182-
def _setop(op_name: str):
183-
"""
184-
Implement set operation.
185-
"""
186-
187-
def func(self, other, sort=None):
188-
# At this point we are assured
189-
# isinstance(other, IntervalIndex)
190-
# other.closed == self.closed
191-
192-
result = getattr(self._multiindex, op_name)(other._multiindex, sort=sort)
193-
result_name = get_op_result_name(self, other)
194-
195-
# GH 19101: ensure empty results have correct dtype
196-
if result.empty:
197-
result = result._values.astype(self.dtype.subtype)
198-
else:
199-
result = result._values
200-
201-
return type(self).from_tuples(result, closed=self.closed, name=result_name)
202-
203-
func.__name__ = op_name
204-
return setop_check(func)
205-
206-
207149
@Appender(
208150
_interval_shared_docs["class"]
209151
% {
@@ -859,11 +801,11 @@ def _intersection(self, other, sort):
859801
"""
860802
# For IntervalIndex we also know other.closed == self.closed
861803
if self.left.is_unique and self.right.is_unique:
862-
taken = self._intersection_unique(other)
804+
return super()._intersection(other, sort=sort)
863805
elif other.left.is_unique and other.right.is_unique and self.isna().sum() <= 1:
864806
# Swap other/self if other is unique and self does not have
865807
# multiple NaNs
866-
taken = other._intersection_unique(self)
808+
return super()._intersection(other, sort=sort)
867809
else:
868810
# duplicates
869811
taken = self._intersection_non_unique(other)
@@ -873,29 +815,6 @@ def _intersection(self, other, sort):
873815

874816
return taken
875817

876-
def _intersection_unique(self, other: IntervalIndex) -> IntervalIndex:
877-
"""
878-
Used when the IntervalIndex does not have any common endpoint,
879-
no matter left or right.
880-
Return the intersection with another IntervalIndex.
881-
882-
Parameters
883-
----------
884-
other : IntervalIndex
885-
886-
Returns
887-
-------
888-
IntervalIndex
889-
"""
890-
lindexer = self.left.get_indexer(other.left)
891-
rindexer = self.right.get_indexer(other.right)
892-
893-
match = (lindexer == rindexer) & (lindexer != -1)
894-
indexer = lindexer.take(match.nonzero()[0])
895-
indexer = unique(indexer)
896-
897-
return self.take(indexer)
898-
899818
def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:
900819
"""
901820
Used when the IntervalIndex does have some common endpoints,
@@ -923,9 +842,6 @@ def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:
923842

924843
return self[mask]
925844

926-
_union = _setop("union")
927-
_difference = _setop("difference")
928-
929845
# --------------------------------------------------------------------
930846

931847
@property

0 commit comments

Comments
 (0)