Skip to content

Commit e0fbcaf

Browse files
committed
Revert "[PERF] Get rid of MultiIndex conversion in IntervalIndex.intersection (pandas-dev#26225)"
This reverts commit 5d2d6b4.
1 parent 3ff4f38 commit e0fbcaf

File tree

6 files changed

+158
-305
lines changed

6 files changed

+158
-305
lines changed

asv_bench/benchmarks/index_object.py

-9
Original file line numberDiff line numberDiff line change
@@ -196,20 +196,11 @@ def setup(self, N):
196196
self.intv = IntervalIndex.from_arrays(left, right)
197197
self.intv._engine
198198

199-
self.left = IntervalIndex.from_breaks(np.arange(N))
200-
self.right = IntervalIndex.from_breaks(np.arange(N - 3, 2 * N - 3))
201-
202199
def time_monotonic_inc(self, N):
203200
self.intv.is_monotonic_increasing
204201

205202
def time_is_unique(self, N):
206203
self.intv.is_unique
207204

208-
def time_intersection(self, N):
209-
self.left.intersection(self.right)
210-
211-
def time_intersection_duplicate(self, N):
212-
self.intv.intersection(self.right)
213-
214205

215206
from .pandas_vb_common import setup # noqa: F401

doc/source/whatsnew/v0.25.0.rst

-1
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,6 @@ Performance Improvements
510510
- Improved performance of :meth:`read_csv` by much faster parsing of ``MM/YYYY`` and ``DD/MM/YYYY`` datetime formats (:issue:`25922`)
511511
- Improved performance of nanops for dtypes that cannot store NaNs. Speedup is particularly prominent for :meth:`Series.all` and :meth:`Series.any` (:issue:`25070`)
512512
- Improved performance of :meth:`Series.map` for dictionary mappers on categorical series by mapping the categories instead of mapping all values (:issue:`23785`)
513-
- Improved performance of :meth:`IntervalIndex.intersection` (:issue:`24813`)
514513
- Improved performance of :meth:`read_csv` by faster concatenating date columns without extra conversion to string for integer/float zero and float ``NaN``; by faster checking the string for the possibility of being a date (:issue:`25754`)
515514
- Improved performance of :attr:`IntervalIndex.is_unique` by removing conversion to ``MultiIndex`` (:issue:`24813`)
516515

pandas/core/indexes/base.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -2440,7 +2440,9 @@ def _union(self, other, sort):
24402440
def _wrap_setop_result(self, other, result):
24412441
return self._constructor(result, name=get_op_result_name(self, other))
24422442

2443-
_index_shared_docs['intersection'] = """
2443+
# TODO: standardize return type of non-union setops type(self vs other)
2444+
def intersection(self, other, sort=False):
2445+
"""
24442446
Form the intersection of two Index objects.
24452447
24462448
This returns a new Index with elements common to the index and `other`.
@@ -2474,10 +2476,6 @@ def _wrap_setop_result(self, other, result):
24742476
>>> idx1.intersection(idx2)
24752477
Int64Index([3, 4], dtype='int64')
24762478
"""
2477-
2478-
# TODO: standardize return type of non-union setops type(self vs other)
2479-
@Appender(_index_shared_docs['intersection'])
2480-
def intersection(self, other, sort=False):
24812479
self._validate_sort_keyword(sort)
24822480
self._assert_can_do_setop(other)
24832481
other = ensure_index(other)

pandas/core/indexes/interval.py

+21-106
Original file line numberDiff line numberDiff line change
@@ -97,42 +97,6 @@ def _new_IntervalIndex(cls, d):
9797
return cls.from_arrays(**d)
9898

9999

100-
class SetopCheck:
101-
"""
102-
This is called to decorate the set operations of IntervalIndex
103-
to perform the type check in advance.
104-
"""
105-
def __init__(self, op_name):
106-
self.op_name = op_name
107-
108-
def __call__(self, setop):
109-
def func(intvidx_self, other, sort=False):
110-
intvidx_self._assert_can_do_setop(other)
111-
other = ensure_index(other)
112-
113-
if not isinstance(other, IntervalIndex):
114-
result = getattr(intvidx_self.astype(object),
115-
self.op_name)(other)
116-
if self.op_name in ('difference',):
117-
result = result.astype(intvidx_self.dtype)
118-
return result
119-
elif intvidx_self.closed != other.closed:
120-
msg = ('can only do set operations between two IntervalIndex '
121-
'objects that are closed on the same side')
122-
raise ValueError(msg)
123-
124-
# GH 19016: ensure set op will not return a prohibited dtype
125-
subtypes = [intvidx_self.dtype.subtype, other.dtype.subtype]
126-
common_subtype = find_common_type(subtypes)
127-
if is_object_dtype(common_subtype):
128-
msg = ('can only do {op} between two IntervalIndex '
129-
'objects that have compatible dtypes')
130-
raise TypeError(msg.format(op=self.op_name))
131-
132-
return setop(intvidx_self, other, sort)
133-
return func
134-
135-
136100
@Appender(_interval_shared_docs['class'] % dict(
137101
klass="IntervalIndex",
138102
summary="Immutable index of intervals that are closed on the same side.",
@@ -1138,78 +1102,28 @@ def equals(self, other):
11381102
def overlaps(self, other):
11391103
return self._data.overlaps(other)
11401104

1141-
@Appender(_index_shared_docs['intersection'])
1142-
@SetopCheck(op_name='intersection')
1143-
def intersection(self, other, sort=False):
1144-
if self.left.is_unique and self.right.is_unique:
1145-
taken = self._intersection_unique(other)
1146-
else:
1147-
# duplicates
1148-
taken = self._intersection_non_unique(other)
1149-
1150-
if sort is None:
1151-
taken = taken.sort_values()
1152-
1153-
return taken
1154-
1155-
def _intersection_unique(self, other):
1156-
"""
1157-
Used when the IntervalIndex does not have any common endpoint,
1158-
no mater left or right.
1159-
Return the intersection with another IntervalIndex.
1160-
1161-
Parameters
1162-
----------
1163-
other : IntervalIndex
1164-
1165-
Returns
1166-
-------
1167-
taken : IntervalIndex
1168-
"""
1169-
lindexer = self.left.get_indexer(other.left)
1170-
rindexer = self.right.get_indexer(other.right)
1171-
1172-
match = (lindexer == rindexer) & (lindexer != -1)
1173-
indexer = lindexer.take(match.nonzero()[0])
1174-
1175-
return self.take(indexer)
1176-
1177-
def _intersection_non_unique(self, other):
1178-
"""
1179-
Used when the IntervalIndex does have some common endpoints,
1180-
on either sides.
1181-
Return the intersection with another IntervalIndex.
1182-
1183-
Parameters
1184-
----------
1185-
other : IntervalIndex
1186-
1187-
Returns
1188-
-------
1189-
taken : IntervalIndex
1190-
"""
1191-
mask = np.zeros(len(self), dtype=bool)
1192-
1193-
if self.hasnans and other.hasnans:
1194-
first_nan_loc = np.arange(len(self))[self.isna()][0]
1195-
mask[first_nan_loc] = True
1196-
1197-
lmiss = other.left.get_indexer_non_unique(self.left)[1]
1198-
lmatch = np.setdiff1d(np.arange(len(self)), lmiss)
1199-
1200-
for i in lmatch:
1201-
potential = other.left.get_loc(self.left[i])
1202-
if is_scalar(potential):
1203-
if self.right[i] == other.right[potential]:
1204-
mask[i] = True
1205-
elif self.right[i] in other.right[potential]:
1206-
mask[i] = True
1207-
1208-
return self[mask]
1209-
12101105
def _setop(op_name, sort=None):
1211-
@SetopCheck(op_name=op_name)
12121106
def func(self, other, sort=sort):
1107+
self._assert_can_do_setop(other)
1108+
other = ensure_index(other)
1109+
if not isinstance(other, IntervalIndex):
1110+
result = getattr(self.astype(object), op_name)(other)
1111+
if op_name in ('difference',):
1112+
result = result.astype(self.dtype)
1113+
return result
1114+
elif self.closed != other.closed:
1115+
msg = ('can only do set operations between two IntervalIndex '
1116+
'objects that are closed on the same side')
1117+
raise ValueError(msg)
1118+
1119+
# GH 19016: ensure set op will not return a prohibited dtype
1120+
subtypes = [self.dtype.subtype, other.dtype.subtype]
1121+
common_subtype = find_common_type(subtypes)
1122+
if is_object_dtype(common_subtype):
1123+
msg = ('can only do {op} between two IntervalIndex '
1124+
'objects that have compatible dtypes')
1125+
raise TypeError(msg.format(op=op_name))
1126+
12131127
result = getattr(self._multiindex, op_name)(other._multiindex,
12141128
sort=sort)
12151129
result_name = get_op_result_name(self, other)
@@ -1234,6 +1148,7 @@ def is_all_dates(self):
12341148
return False
12351149

12361150
union = _setop('union')
1151+
intersection = _setop('intersection', sort=False)
12371152
difference = _setop('difference')
12381153
symmetric_difference = _setop('symmetric_difference')
12391154

pandas/tests/indexes/interval/test_interval.py

+134
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,140 @@ def test_non_contiguous(self, closed):
795795

796796
assert 1.5 not in index
797797

798+
@pytest.mark.parametrize("sort", [None, False])
799+
def test_union(self, closed, sort):
800+
index = self.create_index(closed=closed)
801+
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
802+
803+
expected = IntervalIndex.from_breaks(range(13), closed=closed)
804+
result = index[::-1].union(other, sort=sort)
805+
if sort is None:
806+
tm.assert_index_equal(result, expected)
807+
assert tm.equalContents(result, expected)
808+
809+
result = other[::-1].union(index, sort=sort)
810+
if sort is None:
811+
tm.assert_index_equal(result, expected)
812+
assert tm.equalContents(result, expected)
813+
814+
tm.assert_index_equal(index.union(index, sort=sort), index)
815+
tm.assert_index_equal(index.union(index[:1], sort=sort), index)
816+
817+
# GH 19101: empty result, same dtype
818+
index = IntervalIndex(np.array([], dtype='int64'), closed=closed)
819+
result = index.union(index, sort=sort)
820+
tm.assert_index_equal(result, index)
821+
822+
# GH 19101: empty result, different dtypes
823+
other = IntervalIndex(np.array([], dtype='float64'), closed=closed)
824+
result = index.union(other, sort=sort)
825+
tm.assert_index_equal(result, index)
826+
827+
@pytest.mark.parametrize("sort", [None, False])
828+
def test_intersection(self, closed, sort):
829+
index = self.create_index(closed=closed)
830+
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
831+
832+
expected = IntervalIndex.from_breaks(range(5, 11), closed=closed)
833+
result = index[::-1].intersection(other, sort=sort)
834+
if sort is None:
835+
tm.assert_index_equal(result, expected)
836+
assert tm.equalContents(result, expected)
837+
838+
result = other[::-1].intersection(index, sort=sort)
839+
if sort is None:
840+
tm.assert_index_equal(result, expected)
841+
assert tm.equalContents(result, expected)
842+
843+
tm.assert_index_equal(index.intersection(index, sort=sort), index)
844+
845+
# GH 19101: empty result, same dtype
846+
other = IntervalIndex.from_breaks(range(300, 314), closed=closed)
847+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
848+
result = index.intersection(other, sort=sort)
849+
tm.assert_index_equal(result, expected)
850+
851+
# GH 19101: empty result, different dtypes
852+
breaks = np.arange(300, 314, dtype='float64')
853+
other = IntervalIndex.from_breaks(breaks, closed=closed)
854+
result = index.intersection(other, sort=sort)
855+
tm.assert_index_equal(result, expected)
856+
857+
@pytest.mark.parametrize("sort", [None, False])
858+
def test_difference(self, closed, sort):
859+
index = IntervalIndex.from_arrays([1, 0, 3, 2],
860+
[1, 2, 3, 4],
861+
closed=closed)
862+
result = index.difference(index[:1], sort=sort)
863+
expected = index[1:]
864+
if sort is None:
865+
expected = expected.sort_values()
866+
tm.assert_index_equal(result, expected)
867+
868+
# GH 19101: empty result, same dtype
869+
result = index.difference(index, sort=sort)
870+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
871+
tm.assert_index_equal(result, expected)
872+
873+
# GH 19101: empty result, different dtypes
874+
other = IntervalIndex.from_arrays(index.left.astype('float64'),
875+
index.right, closed=closed)
876+
result = index.difference(other, sort=sort)
877+
tm.assert_index_equal(result, expected)
878+
879+
@pytest.mark.parametrize("sort", [None, False])
880+
def test_symmetric_difference(self, closed, sort):
881+
index = self.create_index(closed=closed)
882+
result = index[1:].symmetric_difference(index[:-1], sort=sort)
883+
expected = IntervalIndex([index[0], index[-1]])
884+
if sort is None:
885+
tm.assert_index_equal(result, expected)
886+
assert tm.equalContents(result, expected)
887+
888+
# GH 19101: empty result, same dtype
889+
result = index.symmetric_difference(index, sort=sort)
890+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
891+
if sort is None:
892+
tm.assert_index_equal(result, expected)
893+
assert tm.equalContents(result, expected)
894+
895+
# GH 19101: empty result, different dtypes
896+
other = IntervalIndex.from_arrays(index.left.astype('float64'),
897+
index.right, closed=closed)
898+
result = index.symmetric_difference(other, sort=sort)
899+
tm.assert_index_equal(result, expected)
900+
901+
@pytest.mark.parametrize('op_name', [
902+
'union', 'intersection', 'difference', 'symmetric_difference'])
903+
@pytest.mark.parametrize("sort", [None, False])
904+
def test_set_incompatible_types(self, closed, op_name, sort):
905+
index = self.create_index(closed=closed)
906+
set_op = getattr(index, op_name)
907+
908+
# TODO: standardize return type of non-union setops type(self vs other)
909+
# non-IntervalIndex
910+
if op_name == 'difference':
911+
expected = index
912+
else:
913+
expected = getattr(index.astype('O'), op_name)(Index([1, 2, 3]))
914+
result = set_op(Index([1, 2, 3]), sort=sort)
915+
tm.assert_index_equal(result, expected)
916+
917+
# mixed closed
918+
msg = ('can only do set operations between two IntervalIndex objects '
919+
'that are closed on the same side')
920+
for other_closed in {'right', 'left', 'both', 'neither'} - {closed}:
921+
other = self.create_index(closed=other_closed)
922+
with pytest.raises(ValueError, match=msg):
923+
set_op(other, sort=sort)
924+
925+
# GH 19016: incompatible dtypes
926+
other = interval_range(Timestamp('20180101'), periods=9, closed=closed)
927+
msg = ('can only do {op} between two IntervalIndex objects that have '
928+
'compatible dtypes').format(op=op_name)
929+
with pytest.raises(TypeError, match=msg):
930+
set_op(other, sort=sort)
931+
798932
def test_isin(self, closed):
799933
index = self.create_index(closed=closed)
800934

0 commit comments

Comments
 (0)