Skip to content

Commit 5d2d6b4

Browse files
makbigcjreback
authored andcommitted
[PERF] Get rid of MultiIndex conversion in IntervalIndex.intersection (pandas-dev#26225)
1 parent 5a724b5 commit 5d2d6b4

File tree

6 files changed

+305
-158
lines changed

6 files changed

+305
-158
lines changed

asv_bench/benchmarks/index_object.py

+9
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,20 @@ def setup(self, N):
196196
self.intv = IntervalIndex.from_arrays(left, right)
197197
self.intv._engine
198198

199+
self.left = IntervalIndex.from_breaks(np.arange(N))
200+
self.right = IntervalIndex.from_breaks(np.arange(N - 3, 2 * N - 3))
201+
199202
def time_monotonic_inc(self, N):
200203
self.intv.is_monotonic_increasing
201204

202205
def time_is_unique(self, N):
203206
self.intv.is_unique
204207

208+
def time_intersection(self, N):
209+
self.left.intersection(self.right)
210+
211+
def time_intersection_duplicate(self, N):
212+
self.intv.intersection(self.right)
213+
205214

206215
from .pandas_vb_common import setup # noqa: F401

doc/source/whatsnew/v0.25.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ Performance Improvements
510510
- Improved performance of :meth:`read_csv` by much faster parsing of ``MM/YYYY`` and ``DD/MM/YYYY`` datetime formats (:issue:`25922`)
511511
- Improved performance of nanops for dtypes that cannot store NaNs. Speedup is particularly prominent for :meth:`Series.all` and :meth:`Series.any` (:issue:`25070`)
512512
- Improved performance of :meth:`Series.map` for dictionary mappers on categorical series by mapping the categories instead of mapping all values (:issue:`23785`)
513+
- Improved performance of :meth:`IntervalIndex.intersection` (:issue:`24813`)
513514
- Improved performance of :meth:`read_csv` by faster concatenating date columns without extra conversion to string for integer/float zero and float ``NaN``; by faster checking the string for the possibility of being a date (:issue:`25754`)
514515
- Improved performance of :attr:`IntervalIndex.is_unique` by removing conversion to ``MultiIndex`` (:issue:`24813`)
515516

pandas/core/indexes/base.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -2440,9 +2440,7 @@ def _union(self, other, sort):
24402440
def _wrap_setop_result(self, other, result):
24412441
return self._constructor(result, name=get_op_result_name(self, other))
24422442

2443-
# TODO: standardize return type of non-union setops type(self vs other)
2444-
def intersection(self, other, sort=False):
2445-
"""
2443+
_index_shared_docs['intersection'] = """
24462444
Form the intersection of two Index objects.
24472445
24482446
This returns a new Index with elements common to the index and `other`.
@@ -2476,6 +2474,10 @@ def intersection(self, other, sort=False):
24762474
>>> idx1.intersection(idx2)
24772475
Int64Index([3, 4], dtype='int64')
24782476
"""
2477+
2478+
# TODO: standardize return type of non-union setops type(self vs other)
2479+
@Appender(_index_shared_docs['intersection'])
2480+
def intersection(self, other, sort=False):
24792481
self._validate_sort_keyword(sort)
24802482
self._assert_can_do_setop(other)
24812483
other = ensure_index(other)

pandas/core/indexes/interval.py

+106-21
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,42 @@ def _new_IntervalIndex(cls, d):
9797
return cls.from_arrays(**d)
9898

9999

100+
class SetopCheck:
101+
"""
102+
This is called to decorate the set operations of IntervalIndex
103+
to perform the type check in advance.
104+
"""
105+
def __init__(self, op_name):
106+
self.op_name = op_name
107+
108+
def __call__(self, setop):
109+
def func(intvidx_self, other, sort=False):
110+
intvidx_self._assert_can_do_setop(other)
111+
other = ensure_index(other)
112+
113+
if not isinstance(other, IntervalIndex):
114+
result = getattr(intvidx_self.astype(object),
115+
self.op_name)(other)
116+
if self.op_name in ('difference',):
117+
result = result.astype(intvidx_self.dtype)
118+
return result
119+
elif intvidx_self.closed != other.closed:
120+
msg = ('can only do set operations between two IntervalIndex '
121+
'objects that are closed on the same side')
122+
raise ValueError(msg)
123+
124+
# GH 19016: ensure set op will not return a prohibited dtype
125+
subtypes = [intvidx_self.dtype.subtype, other.dtype.subtype]
126+
common_subtype = find_common_type(subtypes)
127+
if is_object_dtype(common_subtype):
128+
msg = ('can only do {op} between two IntervalIndex '
129+
'objects that have compatible dtypes')
130+
raise TypeError(msg.format(op=self.op_name))
131+
132+
return setop(intvidx_self, other, sort)
133+
return func
134+
135+
100136
@Appender(_interval_shared_docs['class'] % dict(
101137
klass="IntervalIndex",
102138
summary="Immutable index of intervals that are closed on the same side.",
@@ -1102,28 +1138,78 @@ def equals(self, other):
11021138
def overlaps(self, other):
11031139
return self._data.overlaps(other)
11041140

1105-
def _setop(op_name, sort=None):
1106-
def func(self, other, sort=sort):
1107-
self._assert_can_do_setop(other)
1108-
other = ensure_index(other)
1109-
if not isinstance(other, IntervalIndex):
1110-
result = getattr(self.astype(object), op_name)(other)
1111-
if op_name in ('difference',):
1112-
result = result.astype(self.dtype)
1113-
return result
1114-
elif self.closed != other.closed:
1115-
msg = ('can only do set operations between two IntervalIndex '
1116-
'objects that are closed on the same side')
1117-
raise ValueError(msg)
1141+
@Appender(_index_shared_docs['intersection'])
1142+
@SetopCheck(op_name='intersection')
1143+
def intersection(self, other, sort=False):
1144+
if self.left.is_unique and self.right.is_unique:
1145+
taken = self._intersection_unique(other)
1146+
else:
1147+
# duplicates
1148+
taken = self._intersection_non_unique(other)
11181149

1119-
# GH 19016: ensure set op will not return a prohibited dtype
1120-
subtypes = [self.dtype.subtype, other.dtype.subtype]
1121-
common_subtype = find_common_type(subtypes)
1122-
if is_object_dtype(common_subtype):
1123-
msg = ('can only do {op} between two IntervalIndex '
1124-
'objects that have compatible dtypes')
1125-
raise TypeError(msg.format(op=op_name))
1150+
if sort is None:
1151+
taken = taken.sort_values()
1152+
1153+
return taken
1154+
1155+
def _intersection_unique(self, other):
1156+
"""
1157+
Used when the IntervalIndex does not have any common endpoint,
1158+
no mater left or right.
1159+
Return the intersection with another IntervalIndex.
1160+
1161+
Parameters
1162+
----------
1163+
other : IntervalIndex
1164+
1165+
Returns
1166+
-------
1167+
taken : IntervalIndex
1168+
"""
1169+
lindexer = self.left.get_indexer(other.left)
1170+
rindexer = self.right.get_indexer(other.right)
1171+
1172+
match = (lindexer == rindexer) & (lindexer != -1)
1173+
indexer = lindexer.take(match.nonzero()[0])
1174+
1175+
return self.take(indexer)
11261176

1177+
def _intersection_non_unique(self, other):
1178+
"""
1179+
Used when the IntervalIndex does have some common endpoints,
1180+
on either sides.
1181+
Return the intersection with another IntervalIndex.
1182+
1183+
Parameters
1184+
----------
1185+
other : IntervalIndex
1186+
1187+
Returns
1188+
-------
1189+
taken : IntervalIndex
1190+
"""
1191+
mask = np.zeros(len(self), dtype=bool)
1192+
1193+
if self.hasnans and other.hasnans:
1194+
first_nan_loc = np.arange(len(self))[self.isna()][0]
1195+
mask[first_nan_loc] = True
1196+
1197+
lmiss = other.left.get_indexer_non_unique(self.left)[1]
1198+
lmatch = np.setdiff1d(np.arange(len(self)), lmiss)
1199+
1200+
for i in lmatch:
1201+
potential = other.left.get_loc(self.left[i])
1202+
if is_scalar(potential):
1203+
if self.right[i] == other.right[potential]:
1204+
mask[i] = True
1205+
elif self.right[i] in other.right[potential]:
1206+
mask[i] = True
1207+
1208+
return self[mask]
1209+
1210+
def _setop(op_name, sort=None):
1211+
@SetopCheck(op_name=op_name)
1212+
def func(self, other, sort=sort):
11271213
result = getattr(self._multiindex, op_name)(other._multiindex,
11281214
sort=sort)
11291215
result_name = get_op_result_name(self, other)
@@ -1148,7 +1234,6 @@ def is_all_dates(self):
11481234
return False
11491235

11501236
union = _setop('union')
1151-
intersection = _setop('intersection', sort=False)
11521237
difference = _setop('difference')
11531238
symmetric_difference = _setop('symmetric_difference')
11541239

pandas/tests/indexes/interval/test_interval.py

-134
Original file line numberDiff line numberDiff line change
@@ -795,140 +795,6 @@ def test_non_contiguous(self, closed):
795795

796796
assert 1.5 not in index
797797

798-
@pytest.mark.parametrize("sort", [None, False])
799-
def test_union(self, closed, sort):
800-
index = self.create_index(closed=closed)
801-
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
802-
803-
expected = IntervalIndex.from_breaks(range(13), closed=closed)
804-
result = index[::-1].union(other, sort=sort)
805-
if sort is None:
806-
tm.assert_index_equal(result, expected)
807-
assert tm.equalContents(result, expected)
808-
809-
result = other[::-1].union(index, sort=sort)
810-
if sort is None:
811-
tm.assert_index_equal(result, expected)
812-
assert tm.equalContents(result, expected)
813-
814-
tm.assert_index_equal(index.union(index, sort=sort), index)
815-
tm.assert_index_equal(index.union(index[:1], sort=sort), index)
816-
817-
# GH 19101: empty result, same dtype
818-
index = IntervalIndex(np.array([], dtype='int64'), closed=closed)
819-
result = index.union(index, sort=sort)
820-
tm.assert_index_equal(result, index)
821-
822-
# GH 19101: empty result, different dtypes
823-
other = IntervalIndex(np.array([], dtype='float64'), closed=closed)
824-
result = index.union(other, sort=sort)
825-
tm.assert_index_equal(result, index)
826-
827-
@pytest.mark.parametrize("sort", [None, False])
828-
def test_intersection(self, closed, sort):
829-
index = self.create_index(closed=closed)
830-
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
831-
832-
expected = IntervalIndex.from_breaks(range(5, 11), closed=closed)
833-
result = index[::-1].intersection(other, sort=sort)
834-
if sort is None:
835-
tm.assert_index_equal(result, expected)
836-
assert tm.equalContents(result, expected)
837-
838-
result = other[::-1].intersection(index, sort=sort)
839-
if sort is None:
840-
tm.assert_index_equal(result, expected)
841-
assert tm.equalContents(result, expected)
842-
843-
tm.assert_index_equal(index.intersection(index, sort=sort), index)
844-
845-
# GH 19101: empty result, same dtype
846-
other = IntervalIndex.from_breaks(range(300, 314), closed=closed)
847-
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
848-
result = index.intersection(other, sort=sort)
849-
tm.assert_index_equal(result, expected)
850-
851-
# GH 19101: empty result, different dtypes
852-
breaks = np.arange(300, 314, dtype='float64')
853-
other = IntervalIndex.from_breaks(breaks, closed=closed)
854-
result = index.intersection(other, sort=sort)
855-
tm.assert_index_equal(result, expected)
856-
857-
@pytest.mark.parametrize("sort", [None, False])
858-
def test_difference(self, closed, sort):
859-
index = IntervalIndex.from_arrays([1, 0, 3, 2],
860-
[1, 2, 3, 4],
861-
closed=closed)
862-
result = index.difference(index[:1], sort=sort)
863-
expected = index[1:]
864-
if sort is None:
865-
expected = expected.sort_values()
866-
tm.assert_index_equal(result, expected)
867-
868-
# GH 19101: empty result, same dtype
869-
result = index.difference(index, sort=sort)
870-
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
871-
tm.assert_index_equal(result, expected)
872-
873-
# GH 19101: empty result, different dtypes
874-
other = IntervalIndex.from_arrays(index.left.astype('float64'),
875-
index.right, closed=closed)
876-
result = index.difference(other, sort=sort)
877-
tm.assert_index_equal(result, expected)
878-
879-
@pytest.mark.parametrize("sort", [None, False])
880-
def test_symmetric_difference(self, closed, sort):
881-
index = self.create_index(closed=closed)
882-
result = index[1:].symmetric_difference(index[:-1], sort=sort)
883-
expected = IntervalIndex([index[0], index[-1]])
884-
if sort is None:
885-
tm.assert_index_equal(result, expected)
886-
assert tm.equalContents(result, expected)
887-
888-
# GH 19101: empty result, same dtype
889-
result = index.symmetric_difference(index, sort=sort)
890-
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
891-
if sort is None:
892-
tm.assert_index_equal(result, expected)
893-
assert tm.equalContents(result, expected)
894-
895-
# GH 19101: empty result, different dtypes
896-
other = IntervalIndex.from_arrays(index.left.astype('float64'),
897-
index.right, closed=closed)
898-
result = index.symmetric_difference(other, sort=sort)
899-
tm.assert_index_equal(result, expected)
900-
901-
@pytest.mark.parametrize('op_name', [
902-
'union', 'intersection', 'difference', 'symmetric_difference'])
903-
@pytest.mark.parametrize("sort", [None, False])
904-
def test_set_incompatible_types(self, closed, op_name, sort):
905-
index = self.create_index(closed=closed)
906-
set_op = getattr(index, op_name)
907-
908-
# TODO: standardize return type of non-union setops type(self vs other)
909-
# non-IntervalIndex
910-
if op_name == 'difference':
911-
expected = index
912-
else:
913-
expected = getattr(index.astype('O'), op_name)(Index([1, 2, 3]))
914-
result = set_op(Index([1, 2, 3]), sort=sort)
915-
tm.assert_index_equal(result, expected)
916-
917-
# mixed closed
918-
msg = ('can only do set operations between two IntervalIndex objects '
919-
'that are closed on the same side')
920-
for other_closed in {'right', 'left', 'both', 'neither'} - {closed}:
921-
other = self.create_index(closed=other_closed)
922-
with pytest.raises(ValueError, match=msg):
923-
set_op(other, sort=sort)
924-
925-
# GH 19016: incompatible dtypes
926-
other = interval_range(Timestamp('20180101'), periods=9, closed=closed)
927-
msg = ('can only do {op} between two IntervalIndex objects that have '
928-
'compatible dtypes').format(op=op_name)
929-
with pytest.raises(TypeError, match=msg):
930-
set_op(other, sort=sort)
931-
932798
def test_isin(self, closed):
933799
index = self.create_index(closed=closed)
934800

0 commit comments

Comments
 (0)