Skip to content

Commit 639a9c2

Browse files
authored
CLN: Implement multiindex handling for get_op_result_name (#38323)
* CLN: Implement multiindex handling for get_op_result_name * Change import order * Override method * Move import * Remove import * Fix merge issue * Move methods
1 parent 1fa0c4c commit 639a9c2

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed

pandas/core/indexes/base.py

-1
Original file line numberDiff line numberDiff line change
@@ -2580,7 +2580,6 @@ def __nonzero__(self):
25802580
# --------------------------------------------------------------------
25812581
# Set Operation Methods
25822582

2583-
@final
25842583
def _get_reconciled_name_object(self, other):
25852584
"""
25862585
If the result of a set operation will be self,

pandas/core/indexes/multi.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -3593,6 +3593,34 @@ def _union(self, other, sort):
35933593
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
35943594
return is_object_dtype(dtype)
35953595

3596+
def _get_reconciled_name_object(self, other):
3597+
"""
3598+
If the result of a set operation will be self,
3599+
return self, unless the names change, in which
3600+
case make a shallow copy of self.
3601+
"""
3602+
names = self._maybe_match_names(other)
3603+
if self.names != names:
3604+
return self.rename(names)
3605+
return self
3606+
3607+
def _maybe_match_names(self, other):
3608+
"""
3609+
Try to find common names to attach to the result of an operation between
3610+
a and b. Return a consensus list of names if they match at least partly
3611+
or None if they have completely different names.
3612+
"""
3613+
if len(self.names) != len(other.names):
3614+
return None
3615+
names = []
3616+
for a_name, b_name in zip(self.names, other.names):
3617+
if a_name == b_name:
3618+
names.append(a_name)
3619+
else:
3620+
# TODO: what if they both have np.nan for their names?
3621+
names.append(None)
3622+
return names
3623+
35963624
def intersection(self, other, sort=False):
35973625
"""
35983626
Form the intersection of two MultiIndex objects.
@@ -3616,12 +3644,12 @@ def intersection(self, other, sort=False):
36163644
"""
36173645
self._validate_sort_keyword(sort)
36183646
self._assert_can_do_setop(other)
3619-
other, result_names = self._convert_can_do_setop(other)
3647+
other, _ = self._convert_can_do_setop(other)
36203648

36213649
if self.equals(other):
36223650
if self.has_duplicates:
3623-
return self.unique().rename(result_names)
3624-
return self.rename(result_names)
3651+
return self.unique()._get_reconciled_name_object(other)
3652+
return self._get_reconciled_name_object(other)
36253653

36263654
return self._intersection(other, sort=sort)
36273655

pandas/core/ops/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _maybe_match_name(a, b):
9393
"""
9494
Try to find a name to attach to the result of an operation between
9595
a and b. If only one of these has a `name` attribute, return that
96-
name. Otherwise return a consensus name if they match of None if
96+
name. Otherwise return a consensus name if they match or None if
9797
they have different names.
9898
9999
Parameters

pandas/tests/indexes/multi/test_setops.py

+31
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,29 @@ def test_intersect_with_duplicates(tuples, exp_tuples):
421421
tm.assert_index_equal(result, expected)
422422

423423

424+
@pytest.mark.parametrize(
425+
"data, names, expected",
426+
[
427+
((1,), None, None),
428+
((1,), ["a"], None),
429+
((1,), ["b"], None),
430+
((1, 2), ["c", "d"], [None, None]),
431+
((1, 2), ["b", "a"], [None, None]),
432+
((1, 2, 3), ["a", "b", "c"], None),
433+
((1, 2), ["a", "c"], ["a", None]),
434+
((1, 2), ["c", "b"], [None, "b"]),
435+
((1, 2), ["a", "b"], ["a", "b"]),
436+
((1, 2), [None, "b"], [None, "b"]),
437+
],
438+
)
439+
def test_maybe_match_names(data, names, expected):
440+
# GH#38323
441+
mi = pd.MultiIndex.from_tuples([], names=["a", "b"])
442+
mi2 = pd.MultiIndex.from_tuples([data], names=names)
443+
result = mi._maybe_match_names(mi2)
444+
assert result == expected
445+
446+
424447
def test_intersection_equal_different_names():
425448
# GH#30302
426449
mi1 = MultiIndex.from_arrays([[1, 2], [3, 4]], names=["c", "b"])
@@ -429,3 +452,11 @@ def test_intersection_equal_different_names():
429452
result = mi1.intersection(mi2)
430453
expected = MultiIndex.from_arrays([[1, 2], [3, 4]], names=[None, "b"])
431454
tm.assert_index_equal(result, expected)
455+
456+
457+
def test_intersection_different_names():
458+
# GH#38323
459+
mi = MultiIndex.from_arrays([[1], [3]], names=["c", "b"])
460+
mi2 = MultiIndex.from_arrays([[1], [3]])
461+
result = mi.intersection(mi2)
462+
tm.assert_index_equal(result, mi2)

0 commit comments

Comments
 (0)