diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index 405b8cc0a5ded..af1180541453d 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -178,6 +178,7 @@ MultiIndex - Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`) - Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`) - Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`) +- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`) - I/O diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 52150eafd7783..aefeca5e6d576 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3711,31 +3711,23 @@ def symmetric_difference(self, other, result_name=None, sort=None): left_indexer = np.setdiff1d( np.arange(this.size), common_indexer, assume_unique=True ) - left_diff = this._values.take(left_indexer) + left_diff = this.take(left_indexer) # {other} minus {this} right_indexer = (indexer == -1).nonzero()[0] - right_diff = other._values.take(right_indexer) + right_diff = other.take(right_indexer) - res_values = concat_compat([left_diff, right_diff]) - res_values = _maybe_try_sort(res_values, sort) - - # pass dtype so we retain object dtype - result = Index(res_values, name=result_name, dtype=res_values.dtype) + res_values = left_diff.append(right_diff) + result = _maybe_try_sort(res_values, sort) - if self._is_multi: - self = cast("MultiIndex", self) + if not self._is_multi: + return Index(result, name=result_name, dtype=res_values.dtype) + else: + left_diff = cast("MultiIndex", left_diff) if len(result) == 0: - # On equal symmetric_difference MultiIndexes the difference is empty. - # Therefore, an empty MultiIndex is returned GH#13490 - return type(self)( - levels=[[] for _ in range(self.nlevels)], - codes=[[] for _ in range(self.nlevels)], - names=result.name, - ) - return type(self).from_tuples(result, names=result.name) - - return result + # result might be an Index, if other was an Index + return left_diff.remove_unused_levels().set_names(result_name) + return result.set_names(result_name) @final def _assert_can_do_setop(self, other) -> bool: diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index ce310a75e8e45..014deda340547 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -440,6 +440,22 @@ def test_setops_disallow_true(method): getattr(idx1, method)(idx2, sort=True) +@pytest.mark.parametrize("val", [pd.NA, 5]) +def test_symmetric_difference_keeping_ea_dtype(any_numeric_ea_dtype, val): + # GH#48607 + midx = MultiIndex.from_arrays( + [Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None] + ) + midx2 = MultiIndex.from_arrays( + [Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]] + ) + result = midx.symmetric_difference(midx2) + expected = MultiIndex.from_arrays( + [Series([1, 1, val], dtype=any_numeric_ea_dtype), [1, 2, 3]] + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( ("tuples", "exp_tuples"), [