|
61 | 61 | from pandas.core.dtypes.dtypes import IntervalDtype
|
62 | 62 | from pandas.core.dtypes.missing import is_valid_na_for_dtype
|
63 | 63 |
|
| 64 | +from pandas.core.algorithms import unique |
64 | 65 | from pandas.core.arrays.interval import (
|
65 | 66 | IntervalArray,
|
66 | 67 | _interval_shared_docs,
|
@@ -792,6 +793,50 @@ def _format_data(self, name=None) -> str:
|
792 | 793 | # name argument is unused here; just for compat with base / categorical
|
793 | 794 | return self._data._format_data() + "," + self._format_space()
|
794 | 795 |
|
| 796 | + # -------------------------------------------------------------------- |
| 797 | + # Set Operations |
| 798 | + |
| 799 | + def _intersection(self, other, sort): |
| 800 | + """ |
| 801 | + intersection specialized to the case with matching dtypes. |
| 802 | + """ |
| 803 | + # For IntervalIndex we also know other.closed == self.closed |
| 804 | + if self.left.is_unique and self.right.is_unique: |
| 805 | + taken = self._intersection_unique(other) |
| 806 | + elif other.left.is_unique and other.right.is_unique and self.isna().sum() <= 1: |
| 807 | + # Swap other/self if other is unique and self does not have |
| 808 | + # multiple NaNs |
| 809 | + taken = other._intersection_unique(self) |
| 810 | + else: |
| 811 | + return super()._intersection(other, sort) |
| 812 | + |
| 813 | + if sort is None: |
| 814 | + taken = taken.sort_values() |
| 815 | + |
| 816 | + return taken |
| 817 | + |
| 818 | + def _intersection_unique(self, other: IntervalIndex) -> IntervalIndex: |
| 819 | + """ |
| 820 | + Used when the IntervalIndex does not have any common endpoint, |
| 821 | + no matter left or right. |
| 822 | + Return the intersection with another IntervalIndex. |
| 823 | + Parameters |
| 824 | + ---------- |
| 825 | + other : IntervalIndex |
| 826 | + Returns |
| 827 | + ------- |
| 828 | + IntervalIndex |
| 829 | + """ |
| 830 | + # Note: this is much more performant than super()._intersection(other) |
| 831 | + lindexer = self.left.get_indexer(other.left) |
| 832 | + rindexer = self.right.get_indexer(other.right) |
| 833 | + |
| 834 | + match = (lindexer == rindexer) & (lindexer != -1) |
| 835 | + indexer = lindexer.take(match.nonzero()[0]) |
| 836 | + indexer = unique(indexer) |
| 837 | + |
| 838 | + return self.take(indexer) |
| 839 | + |
795 | 840 | # --------------------------------------------------------------------
|
796 | 841 |
|
797 | 842 | @property
|
|
0 commit comments