Skip to content

REF: de-duplicate IntervalIndex setops #41832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,7 +3024,7 @@ def _union(self, other: Index, sort):

# Self may have duplicates
# find indexes of things in "other" that are not in "self"
if self.is_unique:
if self._index_as_unique:
indexer = self.get_indexer(other)
missing = (indexer == -1).nonzero()[0]
else:
Expand Down Expand Up @@ -3196,14 +3196,18 @@ def difference(self, other, sort=None):
# Note: we do not (yet) sort even if sort=None GH#24959
return self.rename(result_name)

if not self._should_compare(other):
# Nothing matches -> difference is everything
return self.rename(result_name)

result = self._difference(other, sort=sort)
return self._wrap_setop_result(other, result)

def _difference(self, other, sort):

this = self._get_unique_index()

indexer = this.get_indexer(other)
indexer = this.get_indexer_for(other)
indexer = indexer.take((indexer != -1).nonzero()[0])

label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)
Expand Down
90 changes: 3 additions & 87 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
""" define the IntervalIndex """
from __future__ import annotations

from functools import wraps
from operator import (
le,
lt,
Expand Down Expand Up @@ -63,10 +62,7 @@
)
from pandas.core.dtypes.dtypes import IntervalDtype

from pandas.core.algorithms import (
take_nd,
unique,
)
from pandas.core.algorithms import take_nd
from pandas.core.arrays.interval import (
IntervalArray,
_interval_shared_docs,
Expand All @@ -93,7 +89,6 @@
TimedeltaIndex,
timedelta_range,
)
from pandas.core.ops import get_op_result_name

if TYPE_CHECKING:
from pandas import CategoricalIndex
Expand Down Expand Up @@ -151,59 +146,6 @@ def _new_IntervalIndex(cls, d):
return cls.from_arrays(**d)


def setop_check(method):
    """
    Decorate an IntervalIndex set operation with up-front validation.

    The returned wrapper validates ``sort`` and ``other``, converts ``other``
    via ``_convert_can_do_setop``, and then either delegates to ``method`` or
    falls back to running the same-named operation on an object-dtype copy
    of ``self`` and casting the result back to ``self.dtype``.

    The fallback is taken when:

    - the operation is ``"difference"`` and ``other`` is not an IntervalIndex
    - the operation is anything else and ``self._should_compare(other)``
      is False

    Parameters
    ----------
    method : callable
        Set operation taking ``(self, other, sort)``; its ``__name__``
        selects which of the rules above applies.

    Returns
    -------
    callable
        Wrapper with signature ``(self, other, sort=False)``.
    """
    op_name = method.__name__

    def _via_object_dtype(index, other, sort):
        # GH#19016: ensure set op will not return a prohibited dtype
        result = getattr(index.astype(object), op_name)(other, sort=sort)
        return result.astype(index.dtype)

    @wraps(method)
    def wrapped(self, other, sort=False):
        self._validate_sort_keyword(sort)
        self._assert_can_do_setop(other)
        # Only the converted ``other`` is needed here; the name is discarded.
        other, _ = self._convert_can_do_setop(other)

        if op_name == "difference":
            # For difference, only the type of ``other`` is checked.
            use_fallback = not isinstance(other, IntervalIndex)
        else:
            use_fallback = not self._should_compare(other)

        if use_fallback:
            return _via_object_dtype(self, other, sort)
        return method(self, other, sort)

    return wrapped


def _setop(op_name: str):
    """
    Build a set-operation method that delegates to ``_multiindex``.

    Parameters
    ----------
    op_name : str
        Name of the set operation to look up on ``self._multiindex``;
        also assigned as the ``__name__`` of the generated function.

    Returns
    -------
    callable
        The generated function wrapped by ``setop_check``.
    """

    def func(self, other, sort=None):
        # At this point we are assured
        # isinstance(other, IntervalIndex)
        # other.closed == self.closed
        operation = getattr(self._multiindex, op_name)
        combined = operation(other._multiindex, sort=sort)
        name = get_op_result_name(self, other)

        # GH 19101: ensure empty results have correct dtype
        values = (
            combined._values.astype(self.dtype.subtype)
            if combined.empty
            else combined._values
        )

        return type(self).from_tuples(values, closed=self.closed, name=name)

    func.__name__ = op_name
    return setop_check(func)


@Appender(
_interval_shared_docs["class"]
% {
Expand Down Expand Up @@ -861,11 +803,11 @@ def _intersection(self, other, sort):
"""
# For IntervalIndex we also know other.closed == self.closed
if self.left.is_unique and self.right.is_unique:
taken = self._intersection_unique(other)
return super()._intersection(other, sort=sort)
elif other.left.is_unique and other.right.is_unique and self.isna().sum() <= 1:
# Swap other/self if other is unique and self does not have
# multiple NaNs
taken = other._intersection_unique(self)
return super()._intersection(other, sort=sort)
else:
# duplicates
taken = self._intersection_non_unique(other)
Expand All @@ -875,29 +817,6 @@ def _intersection(self, other, sort):

return taken

def _intersection_unique(self, other: IntervalIndex) -> IntervalIndex:
"""
Used when the IntervalIndex does not have any common endpoint,
no matter left or right.
Return the intersection with another IntervalIndex.

Parameters
----------
other : IntervalIndex

Returns
-------
IntervalIndex
"""
lindexer = self.left.get_indexer(other.left)
rindexer = self.right.get_indexer(other.right)

match = (lindexer == rindexer) & (lindexer != -1)
indexer = lindexer.take(match.nonzero()[0])
indexer = unique(indexer)

return self.take(indexer)

def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:
"""
Used when the IntervalIndex does have some common endpoints,
Expand Down Expand Up @@ -925,9 +844,6 @@ def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:

return self[mask]

# Both implementations are generated by the _setop factory, keyed by the
# name of the set operation.
_union = _setop("union")
_difference = _setop("difference")

# --------------------------------------------------------------------

@property
Expand Down