diff --git a/asv_bench/benchmarks/multiindex_object.py b/asv_bench/benchmarks/multiindex_object.py index 0e188c58012fa..793f0c7c03c77 100644 --- a/asv_bench/benchmarks/multiindex_object.py +++ b/asv_bench/benchmarks/multiindex_object.py @@ -160,4 +160,43 @@ def time_equals_non_object_index(self): self.mi_large_slow.equals(self.idx_non_object) +class SetOperations: + + params = [ + ("monotonic", "non_monotonic"), + ("datetime", "int", "string"), + ("intersection", "union", "symmetric_difference"), + ] + param_names = ["index_structure", "dtype", "method"] + + def setup(self, index_structure, dtype, method): + N = 10 ** 5 + level1 = range(1000) + + level2 = date_range(start="1/1/2000", periods=N // 1000) + dates_left = MultiIndex.from_product([level1, level2]) + + level2 = range(N // 1000) + int_left = MultiIndex.from_product([level1, level2]) + + level2 = tm.makeStringIndex(N // 1000).values + str_left = MultiIndex.from_product([level1, level2]) + + data = { + "datetime": dates_left, + "int": int_left, + "string": str_left, + } + + if index_structure == "non_monotonic": + data = {k: mi[::-1] for k, mi in data.items()} + + data = {k: {"left": mi, "right": mi[:-1]} for k, mi in data.items()} + self.left = data[dtype]["left"] + self.right = data[dtype]["right"] + + def time_operation(self, index_structure, dtype, method): + getattr(self.left, method)(self.right) + + from .pandas_vb_common import setup # noqa: F401 isort:skip diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 40abb8f83de2f..aeed59f5e80f2 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -175,6 +175,16 @@ MultiIndex index=[["a", "a", "b", "b"], [1, 2, 1, 2]]) # Rows are now ordered as the requested keys df.loc[(['b', 'a'], [2, 1]), :] + +- Bug in :meth:`MultiIndex.intersection` was not guaranteed to preserve order when ``sort=False``. (:issue:`31325`) + +.. ipython:: python + + left = pd.MultiIndex.from_arrays([["b", "a"], [2, 1]]) + right = pd.MultiIndex.from_arrays([["a", "b", "c"], [1, 2, 3]]) + # Common elements are now guaranteed to be ordered by the left side + left.intersection(right, sort=False) + - I/O diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 94d6564d372c7..4edfa078dd919 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3314,9 +3314,23 @@ def intersection(self, other, sort=False): if self.equals(other): return self - self_tuples = self._ndarray_values - other_tuples = other._ndarray_values - uniq_tuples = set(self_tuples) & set(other_tuples) + lvals = self._ndarray_values + rvals = other._ndarray_values + + uniq_tuples = None # flag whether _inner_indexer was succesful + if self.is_monotonic and other.is_monotonic: + try: + uniq_tuples = self._inner_indexer(lvals, rvals)[0] + sort = False # uniq_tuples is already sorted + except TypeError: + pass + + if uniq_tuples is None: + other_uniq = set(rvals) + seen = set() + uniq_tuples = [ + x for x in lvals if x in other_uniq and not (x in seen or seen.add(x)) + ] if sort is None: uniq_tuples = sorted(uniq_tuples) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index f949db537de67..627127f7b5b53 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -19,22 +19,20 @@ def test_set_ops_error_cases(idx, case, sort, method): @pytest.mark.parametrize("sort", [None, False]) -def test_intersection_base(idx, sort): - first = idx[:5] - second = idx[:3] - intersect = first.intersection(second, sort=sort) +@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list]) +def test_intersection_base(idx, sort, klass): + first = idx[2::-1] # first 3 elements reversed + second = idx[:5] - if sort is None: - tm.assert_index_equal(intersect, second.sort_values()) - assert tm.equalContents(intersect, second) + if klass is not MultiIndex: + second = klass(second.values) - # GH 10149 - cases = [klass(second.values) for klass in [np.array, Series, list]] - for case in cases: - result = first.intersection(case, sort=sort) - if sort is None: - tm.assert_index_equal(result, second.sort_values()) - assert tm.equalContents(result, second) + intersect = first.intersection(second, sort=sort) + if sort is None: + expected = first.sort_values() + else: + expected = first + tm.assert_index_equal(intersect, expected) msg = "other must be a MultiIndex or a list of tuples" with pytest.raises(TypeError, match=msg): @@ -42,22 +40,20 @@ def test_intersection_base(idx, sort): @pytest.mark.parametrize("sort", [None, False]) -def test_union_base(idx, sort): - first = idx[3:] +@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list]) +def test_union_base(idx, sort, klass): + first = idx[::-1] second = idx[:5] - everything = idx + + if klass is not MultiIndex: + second = klass(second.values) + union = first.union(second, sort=sort) if sort is None: - tm.assert_index_equal(union, everything.sort_values()) - assert tm.equalContents(union, everything) - - # GH 10149 - cases = [klass(second.values) for klass in [np.array, Series, list]] - for case in cases: - result = first.union(case, sort=sort) - if sort is None: - tm.assert_index_equal(result, everything.sort_values()) - assert tm.equalContents(result, everything) + expected = first.sort_values() + else: + expected = first + tm.assert_index_equal(union, expected) msg = "other must be a MultiIndex or a list of tuples" with pytest.raises(TypeError, match=msg):