Skip to content

Commit c2f3ce3

Browse files
author
Jean-Francois Zinque
authored
BUG: MultiIndex intersection with sort=False does not preserve order (#31312)
1 parent 143b011 commit c2f3ce3

File tree

4 files changed

+89
-30
lines changed

4 files changed

+89
-30
lines changed

asv_bench/benchmarks/multiindex_object.py

+39
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,43 @@ def time_equals_non_object_index(self):
160160
self.mi_large_slow.equals(self.idx_non_object)
161161

162162

163+
class SetOperations:
164+
165+
params = [
166+
("monotonic", "non_monotonic"),
167+
("datetime", "int", "string"),
168+
("intersection", "union", "symmetric_difference"),
169+
]
170+
param_names = ["index_structure", "dtype", "method"]
171+
172+
def setup(self, index_structure, dtype, method):
173+
N = 10 ** 5
174+
level1 = range(1000)
175+
176+
level2 = date_range(start="1/1/2000", periods=N // 1000)
177+
dates_left = MultiIndex.from_product([level1, level2])
178+
179+
level2 = range(N // 1000)
180+
int_left = MultiIndex.from_product([level1, level2])
181+
182+
level2 = tm.makeStringIndex(N // 1000).values
183+
str_left = MultiIndex.from_product([level1, level2])
184+
185+
data = {
186+
"datetime": dates_left,
187+
"int": int_left,
188+
"string": str_left,
189+
}
190+
191+
if index_structure == "non_monotonic":
192+
data = {k: mi[::-1] for k, mi in data.items()}
193+
194+
data = {k: {"left": mi, "right": mi[:-1]} for k, mi in data.items()}
195+
self.left = data[dtype]["left"]
196+
self.right = data[dtype]["right"]
197+
198+
def time_operation(self, index_structure, dtype, method):
199+
getattr(self.left, method)(self.right)
200+
201+
163202
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.1.0.rst

+10
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ MultiIndex
176176
index=[["a", "a", "b", "b"], [1, 2, 1, 2]])
177177
# Rows are now ordered as the requested keys
178178
df.loc[(['b', 'a'], [2, 1]), :]
179+
180+
- Bug in :meth:`MultiIndex.intersection` was not guaranteed to preserve order when ``sort=False``. (:issue:`31325`)
181+
182+
.. ipython:: python
183+
184+
left = pd.MultiIndex.from_arrays([["b", "a"], [2, 1]])
185+
right = pd.MultiIndex.from_arrays([["a", "b", "c"], [1, 2, 3]])
186+
# Common elements are now guaranteed to be ordered by the left side
187+
left.intersection(right, sort=False)
188+
179189
-
180190

181191
I/O

pandas/core/indexes/multi.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -3314,9 +3314,23 @@ def intersection(self, other, sort=False):
33143314
if self.equals(other):
33153315
return self
33163316

3317-
self_tuples = self._ndarray_values
3318-
other_tuples = other._ndarray_values
3319-
uniq_tuples = set(self_tuples) & set(other_tuples)
3317+
lvals = self._ndarray_values
3318+
rvals = other._ndarray_values
3319+
3320+
uniq_tuples = None # flag whether _inner_indexer was succesful
3321+
if self.is_monotonic and other.is_monotonic:
3322+
try:
3323+
uniq_tuples = self._inner_indexer(lvals, rvals)[0]
3324+
sort = False # uniq_tuples is already sorted
3325+
except TypeError:
3326+
pass
3327+
3328+
if uniq_tuples is None:
3329+
other_uniq = set(rvals)
3330+
seen = set()
3331+
uniq_tuples = [
3332+
x for x in lvals if x in other_uniq and not (x in seen or seen.add(x))
3333+
]
33203334

33213335
if sort is None:
33223336
uniq_tuples = sorted(uniq_tuples)

pandas/tests/indexes/multi/test_setops.py

+23-27
Original file line numberDiff line numberDiff line change
@@ -19,45 +19,41 @@ def test_set_ops_error_cases(idx, case, sort, method):
1919

2020

2121
@pytest.mark.parametrize("sort", [None, False])
22-
def test_intersection_base(idx, sort):
23-
first = idx[:5]
24-
second = idx[:3]
25-
intersect = first.intersection(second, sort=sort)
22+
@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
23+
def test_intersection_base(idx, sort, klass):
24+
first = idx[2::-1] # first 3 elements reversed
25+
second = idx[:5]
2626

27-
if sort is None:
28-
tm.assert_index_equal(intersect, second.sort_values())
29-
assert tm.equalContents(intersect, second)
27+
if klass is not MultiIndex:
28+
second = klass(second.values)
3029

31-
# GH 10149
32-
cases = [klass(second.values) for klass in [np.array, Series, list]]
33-
for case in cases:
34-
result = first.intersection(case, sort=sort)
35-
if sort is None:
36-
tm.assert_index_equal(result, second.sort_values())
37-
assert tm.equalContents(result, second)
30+
intersect = first.intersection(second, sort=sort)
31+
if sort is None:
32+
expected = first.sort_values()
33+
else:
34+
expected = first
35+
tm.assert_index_equal(intersect, expected)
3836

3937
msg = "other must be a MultiIndex or a list of tuples"
4038
with pytest.raises(TypeError, match=msg):
4139
first.intersection([1, 2, 3], sort=sort)
4240

4341

4442
@pytest.mark.parametrize("sort", [None, False])
45-
def test_union_base(idx, sort):
46-
first = idx[3:]
43+
@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
44+
def test_union_base(idx, sort, klass):
45+
first = idx[::-1]
4746
second = idx[:5]
48-
everything = idx
47+
48+
if klass is not MultiIndex:
49+
second = klass(second.values)
50+
4951
union = first.union(second, sort=sort)
5052
if sort is None:
51-
tm.assert_index_equal(union, everything.sort_values())
52-
assert tm.equalContents(union, everything)
53-
54-
# GH 10149
55-
cases = [klass(second.values) for klass in [np.array, Series, list]]
56-
for case in cases:
57-
result = first.union(case, sort=sort)
58-
if sort is None:
59-
tm.assert_index_equal(result, everything.sort_values())
60-
assert tm.equalContents(result, everything)
53+
expected = first.sort_values()
54+
else:
55+
expected = first
56+
tm.assert_index_equal(union, expected)
6157

6258
msg = "other must be a MultiIndex or a list of tuples"
6359
with pytest.raises(TypeError, match=msg):

0 commit comments

Comments
 (0)