|
61 | 61 | from pandas.core.dtypes.dtypes import IntervalDtype
|
62 | 62 | from pandas.core.dtypes.missing import is_valid_na_for_dtype
|
63 | 63 |
|
| 64 | +from pandas.core.algorithms import unique |
64 | 65 | from pandas.core.arrays.interval import (
|
65 | 66 | IntervalArray,
|
66 | 67 | _interval_shared_docs,
|
@@ -787,6 +788,50 @@ def _format_data(self, name=None) -> str:
|
787 | 788 | # name argument is unused here; just for compat with base / categorical
|
788 | 789 | return self._data._format_data() + "," + self._format_space()
|
789 | 790 |
|
| 791 | + # -------------------------------------------------------------------- |
| 792 | + # Set Operations |
| 793 | + |
def _intersection(self, other, sort):
    """
    Intersection specialized to the case with matching dtypes.

    Parameters
    ----------
    other : IntervalIndex
        Index to intersect with; it has the same ``closed`` as ``self``.
    sort : None or False
        The result is sorted only when ``sort`` is None.

    Returns
    -------
    IntervalIndex
    """
    # For IntervalIndex we also know other.closed == self.closed
    self_is_unique = self.left.is_unique and self.right.is_unique
    if self_is_unique:
        result = self._intersection_unique(other)
    else:
        other_is_unique = other.left.is_unique and other.right.is_unique
        if other_is_unique and self.isna().sum() <= 1:
            # The endpoint-matching path needs unique endpoints on the
            # calling side, so swap roles when only ``other`` qualifies
            # and ``self`` does not hold multiple NaNs.
            result = other._intersection_unique(self)
        else:
            # Neither side qualifies: defer to the generic implementation,
            # which handles ``sort`` itself.
            return super()._intersection(other, sort)

    return result.sort_values() if sort is None else result
| 812 | + |
def _intersection_unique(self, other: IntervalIndex) -> IntervalIndex:
    """
    Return the intersection with another IntervalIndex.

    Used when this IntervalIndex does not have any common endpoint,
    no matter left or right.

    Parameters
    ----------
    other : IntervalIndex

    Returns
    -------
    IntervalIndex
    """
    # Note: this is much more performant than super()._intersection(other)
    left_pos = self.left.get_indexer(other.left)
    right_pos = self.right.get_indexer(other.right)

    # An interval of ``other`` is present in ``self`` only when both of its
    # endpoints were found (-1 marks a miss) at the same position.
    found = (left_pos != -1) & (left_pos == right_pos)
    positions = left_pos.take(found.nonzero()[0])

    # ``other`` may repeat an interval; keep each matched position once.
    return self.take(unique(positions))
| 834 | + |
790 | 835 | # --------------------------------------------------------------------
|
791 | 836 |
|
792 | 837 | @property
|
|
0 commit comments