Skip to content

Commit 9107f58

Browse files
authored
Auto backport of pr 42293 on 1.3.x (#42306)
1 parent 538d69a commit 9107f58

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

pandas/core/indexes/interval.py

+45
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from pandas.core.dtypes.dtypes import IntervalDtype
6262
from pandas.core.dtypes.missing import is_valid_na_for_dtype
6363

64+
from pandas.core.algorithms import unique
6465
from pandas.core.arrays.interval import (
6566
IntervalArray,
6667
_interval_shared_docs,
@@ -792,6 +793,50 @@ def _format_data(self, name=None) -> str:
792793
# name argument is unused here; just for compat with base / categorical
793794
return self._data._format_data() + "," + self._format_space()
794795

796+
# --------------------------------------------------------------------
797+
# Set Operations
798+
799+
def _intersection(self, other, sort):
    """
    Intersection specialized to the case with matching dtypes.

    Dispatches to the endpoint-indexer fast path (``_intersection_unique``)
    whenever one of the two operands has unique left and unique right
    endpoints, and otherwise defers to the parent implementation.

    Parameters
    ----------
    other : IntervalIndex
        Index to intersect with. For IntervalIndex it is also known that
        ``other.closed == self.closed``.
    sort : object
        When ``None``, the fast-path result is sorted via ``sort_values``.
        Any other value leaves the fast-path result in the order produced
        by ``_intersection_unique``. Passed through unchanged to the parent
        implementation on the fallback path.

    Returns
    -------
    IntervalIndex
        The fast-path result, or whatever the parent ``_intersection``
        returns on the fallback path.
    """
    # For IntervalIndex we also know other.closed == self.closed
    if self.left.is_unique and self.right.is_unique:
        taken = self._intersection_unique(other)
    elif other.left.is_unique and other.right.is_unique and self.isna().sum() <= 1:
        # Swap other/self if other is unique and self does not have
        # multiple NaNs
        taken = other._intersection_unique(self)
    else:
        # Neither operand qualifies for the fast path; the parent
        # implementation handles sorting itself, hence the early return.
        return super()._intersection(other, sort)

    if sort is None:
        taken = taken.sort_values()

    return taken
817+
818+
def _intersection_unique(self, other: IntervalIndex) -> IntervalIndex:
    """
    Return the intersection with another IntervalIndex.

    Used when this IntervalIndex has no repeated endpoints, neither among
    its left endpoints nor among its right endpoints.

    Parameters
    ----------
    other : IntervalIndex

    Returns
    -------
    IntervalIndex
        The elements of ``self`` whose left and right endpoints both match
        the same element of ``other``, each taken once, in the order their
        matches are first encountered while scanning ``other``.
    """
    # Note: this is much more performant than super()._intersection(other)
    # Position in self of each of other's left / right endpoints.
    lindexer = self.left.get_indexer(other.left)
    rindexer = self.right.get_indexer(other.right)

    # An interval of ``other`` is present in ``self`` only when both of its
    # endpoints resolve to the same position and that position is not the
    # -1 "not found" marker.
    match = (lindexer == rindexer) & (lindexer != -1)
    indexer = lindexer.take(match.nonzero()[0])
    # Duplicates in ``other`` map to the same position; keep each one once.
    indexer = unique(indexer)

    return self.take(indexer)
839+
795840
# --------------------------------------------------------------------
796841

797842
@property

pandas/core/indexes/period.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from pandas._libs.tslibs import (
1717
BaseOffset,
18+
NaT,
1819
Period,
1920
Resolution,
2021
Tick,
@@ -37,6 +38,7 @@
3738
pandas_dtype,
3839
)
3940
from pandas.core.dtypes.dtypes import PeriodDtype
41+
from pandas.core.dtypes.missing import is_valid_na_for_dtype
4042

4143
from pandas.core.arrays.period import (
4244
PeriodArray,
@@ -413,7 +415,10 @@ def get_loc(self, key, method=None, tolerance=None):
413415
if not is_scalar(key):
414416
raise InvalidIndexError(key)
415417

416-
if isinstance(key, str):
418+
if is_valid_na_for_dtype(key, self.dtype):
419+
key = NaT
420+
421+
elif isinstance(key, str):
417422

418423
try:
419424
loc = self._get_string_slice(key)
@@ -447,10 +452,25 @@ def get_loc(self, key, method=None, tolerance=None):
447452
else:
448453
key = asdt
449454

450-
elif is_integer(key):
451-
# Period constructor will cast to string, which we dont want
452-
raise KeyError(key)
453-
elif isinstance(key, Period) and key.freq != self.freq:
455+
elif isinstance(key, Period):
456+
sfreq = self.freq
457+
kfreq = key.freq
458+
if not (
459+
sfreq.n == kfreq.n
460+
and sfreq._period_dtype_code == kfreq._period_dtype_code
461+
):
462+
# GH#42247 For the subset of DateOffsets that can be Period freqs,
463+
# checking these two attributes is sufficient to check equality,
464+
# and much more performant than `self.freq == key.freq`
465+
raise KeyError(key)
466+
elif isinstance(key, datetime):
467+
try:
468+
key = Period(key, freq=self.freq)
469+
except ValueError as err:
470+
# we cannot construct the Period
471+
raise KeyError(orig_key) from err
472+
else:
473+
# in particular integer, which Period constructor would cast to string
454474
raise KeyError(key)
455475

456476
try:

0 commit comments

Comments
 (0)