Skip to content

Commit 8a2b8e2

Browse files
Backport PR #36927: BUG: Fix duplicates in intersection of multiindexes (#38155)
Co-authored-by: patrick <[email protected]>
1 parent 45c1016 commit 8a2b8e2

File tree

9 files changed

+60
-10
lines changed

9 files changed

+60
-10
lines changed

doc/source/whatsnew/v1.1.5.rst

+1
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ Fixed regressions
2323
- Fixed regression in :meth:`DataFrame.groupby` aggregation with out-of-bounds datetime objects in an object-dtype column (:issue:`36003`)
2424
- Fixed regression in ``df.groupby(..).rolling(..)`` with the resulting :class:`MultiIndex` when grouping by a label that is in the index (:issue:`37641`)
2525
- Fixed regression in :meth:`DataFrame.fillna` not filling ``NaN`` after other operations such as :meth:`DataFrame.pivot` (:issue:`36495`).
26+
- Fixed regression in :meth:`MultiIndex.intersection` returning duplicates when at least one of the indexes had duplicates (:issue:`36915`)
2627

2728
.. ---------------------------------------------------------------------------
2829

pandas/core/indexes/base.py

+6 -3
Original file line number | Diff line number | Diff line change
@@ -2654,7 +2654,7 @@ def intersection(self, other, sort=False):
26542654
self._assert_can_do_setop(other)
26552655
other = ensure_index(other)
26562656

2657-
if self.equals(other):
2657+
if self.equals(other) and not self.has_duplicates:
26582658
return self._get_reconciled_name_object(other)
26592659

26602660
if not is_dtype_equal(self.dtype, other.dtype):
@@ -2672,7 +2672,7 @@ def intersection(self, other, sort=False):
26722672
except TypeError:
26732673
pass
26742674
else:
2675-
return self._wrap_setop_result(other, result)
2675+
return self._wrap_setop_result(other, algos.unique1d(result))
26762676

26772677
try:
26782678
indexer = Index(rvals).get_indexer(lvals)
@@ -2683,13 +2683,16 @@ def intersection(self, other, sort=False):
26832683
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
26842684
indexer = indexer[indexer != -1]
26852685

2686-
taken = other.take(indexer)
2686+
taken = other.take(indexer).unique()
26872687
res_name = get_op_result_name(self, other)
26882688

26892689
if sort is None:
26902690
taken = algos.safe_sort(taken.values)
26912691
return self._shallow_copy(taken, name=res_name)
26922692

2693+
# Intersection has to be unique
2694+
assert algos.unique(taken._values).shape == taken._values.shape
2695+
26932696
taken.name = res_name
26942697
return taken
26952698

pandas/core/indexes/multi.py

+6 -2
Original file line number | Diff line number | Diff line change
@@ -3398,6 +3398,8 @@ def intersection(self, other, sort=False):
33983398
other, result_names = self._convert_can_do_setop(other)
33993399

34003400
if self.equals(other):
3401+
if self.has_duplicates:
3402+
return self.unique()
34013403
return self
34023404

34033405
if not is_object_dtype(other.dtype):
@@ -3416,10 +3418,12 @@ def intersection(self, other, sort=False):
34163418
uniq_tuples = None # flag whether _inner_indexer was successful
34173419
if self.is_monotonic and other.is_monotonic:
34183420
try:
3419-
uniq_tuples = self._inner_indexer(lvals, rvals)[0]
3420-
sort = False # uniq_tuples is already sorted
3421+
inner_tuples = self._inner_indexer(lvals, rvals)[0]
3422+
sort = False # inner_tuples is already sorted
34213423
except TypeError:
34223424
pass
3425+
else:
3426+
uniq_tuples = algos.unique(inner_tuples)
34233427

34243428
if uniq_tuples is None:
34253429
other_uniq = set(rvals)

pandas/core/ops/__init__.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -539,7 +539,11 @@ def _should_reindex_frame_op(
539539
if fill_value is None and level is None and axis is default_axis:
540540
# TODO: any other cases we should handle here?
541541
cols = left.columns.intersection(right.columns)
542-
if not (cols.equals(left.columns) and cols.equals(right.columns)):
542+
543+
# Intersection is always unique so we have to check the unique columns
544+
left_uniques = left.columns.unique()
545+
right_uniques = right.columns.unique()
546+
if not (cols.equals(left_uniques) and cols.equals(right_uniques)):
543547
return True
544548

545549
return False

pandas/core/reshape/merge.py

+7 -2
Original file line number | Diff line number | Diff line change
@@ -1209,7 +1209,9 @@ def _validate_specification(self):
12091209
raise MergeError("Must pass left_on or left_index=True")
12101210
else:
12111211
# use the common columns
1212-
common_cols = self.left.columns.intersection(self.right.columns)
1212+
left_cols = self.left.columns
1213+
right_cols = self.right.columns
1214+
common_cols = left_cols.intersection(right_cols)
12131215
if len(common_cols) == 0:
12141216
raise MergeError(
12151217
"No common columns to perform merge on. "
@@ -1218,7 +1220,10 @@ def _validate_specification(self):
12181220
f"left_index={self.left_index}, "
12191221
f"right_index={self.right_index}"
12201222
)
1221-
if not common_cols.is_unique:
1223+
if (
1224+
not left_cols.join(common_cols, how="inner").is_unique
1225+
or not right_cols.join(common_cols, how="inner").is_unique
1226+
):
12221227
raise MergeError(f"Data columns not unique: {repr(common_cols)}")
12231228
self.left_on = self.right_on = common_cols
12241229
elif self.on is not None:

pandas/tests/indexes/multi/test_setops.py

+23
Original file line number | Diff line number | Diff line change
@@ -375,3 +375,26 @@ def test_setops_disallow_true(method):
375375

376376
with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
377377
getattr(idx1, method)(idx2, sort=True)
378+
379+
380+
@pytest.mark.parametrize(
381+
("tuples", "exp_tuples"),
382+
[
383+
([("val1", "test1")], [("val1", "test1")]),
384+
([("val1", "test1"), ("val1", "test1")], [("val1", "test1")]),
385+
(
386+
[("val2", "test2"), ("val1", "test1")],
387+
[("val2", "test2"), ("val1", "test1")],
388+
),
389+
],
390+
)
391+
def test_intersect_with_duplicates(tuples, exp_tuples):
392+
# GH#36915
393+
left = MultiIndex.from_tuples(tuples, names=["first", "second"])
394+
right = MultiIndex.from_tuples(
395+
[("val1", "test1"), ("val1", "test1"), ("val2", "test2")],
396+
names=["first", "second"],
397+
)
398+
result = left.intersection(right)
399+
expected = MultiIndex.from_tuples(exp_tuples, names=["first", "second"])
400+
tm.assert_index_equal(result, expected)

pandas/tests/indexes/test_base.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -688,7 +688,7 @@ def test_intersection_monotonic(self, index2, keeps_name, sort):
688688

689689
@pytest.mark.parametrize(
690690
"index2,expected_arr",
691-
[(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B", "A"])],
691+
[(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B"])],
692692
)
693693
def test_intersection_non_monotonic_non_unique(self, index2, expected_arr, sort):
694694
# non-monotonic non-unique

pandas/tests/indexes/test_setops.py

+10
Original file line number | Diff line number | Diff line change
@@ -95,3 +95,13 @@ def test_union_dtypes(left, right, expected):
9595
b = pd.Index([], dtype=right)
9696
result = (a | b).dtype
9797
assert result == expected
98+
99+
100+
@pytest.mark.parametrize("values", [[1, 2, 2, 3], [3, 3]])
101+
def test_intersection_duplicates(values):
102+
# GH#31326
103+
a = pd.Index(values)
104+
b = pd.Index([3, 3])
105+
result = a.intersection(b)
106+
expected = pd.Index([3])
107+
tm.assert_index_equal(result, expected)

pandas/tests/reshape/merge/test_merge.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -742,7 +742,7 @@ def test_overlapping_columns_error_message(self):
742742

743743
# #2649, #10639
744744
df2.columns = ["key1", "foo", "foo"]
745-
msg = r"Data columns not unique: Index\(\['foo', 'foo'\], dtype='object'\)"
745+
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)"
746746
with pytest.raises(MergeError, match=msg):
747747
merge(df, df2)
748748

0 commit comments

Comments
 (0)