diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index ee5085fd9ad89..1c2fa280465e8 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -168,6 +168,7 @@ Missing MultiIndex ^^^^^^^^^^ - Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`) +- Bug in :meth:`MultiIndex.union` losing extension array dtype (:issue:`48498`) - Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`) - diff --git a/pandas/conftest.py b/pandas/conftest.py index 6f31e2a11486a..84a0fa7331c17 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -1569,6 +1569,31 @@ def any_real_numpy_dtype(request): return request.param +@pytest.fixture( + params=tm.ALL_REAL_NUMPY_DTYPES + tm.ALL_INT_EA_DTYPES + tm.FLOAT_EA_DTYPES +) +def any_real_numeric_dtype(request): + """ + Parameterized fixture for any (purely) real numeric dtype. + + * int + * 'int8' + * 'uint8' + * 'int16' + * 'uint16' + * 'int32' + * 'uint32' + * 'int64' + * 'uint64' + * float + * 'float32' + * 'float64' + + and associated ea dtypes. 
+ """ + return request.param + + @pytest.fixture(params=tm.ALL_NUMPY_DTYPES) def any_numpy_dtype(request): """ diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 77de1636a92eb..1054b839149dc 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1830,6 +1830,7 @@ def safe_sort( "Only list-like objects are allowed to be passed to safe_sort as values" ) original_values = values + is_mi = isinstance(original_values, ABCMultiIndex) if not isinstance(values, (np.ndarray, ABCExtensionArray)): # don't convert to string types @@ -1851,7 +1852,11 @@ def safe_sort( else: try: sorter = values.argsort() - ordered = values.take(sorter) + if is_mi: + # Operate on original object instead of casted array (MultiIndex) + ordered = original_values.take(sorter) + else: + ordered = values.take(sorter) except TypeError: # Previous sorters failed or were not applicable, try `_sort_mixed` # which would work, but which fails for special case of 1d arrays diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 78e1f713644dd..ee16857337df9 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3398,8 +3398,8 @@ def _union(self, other: Index, sort): elif not other.is_unique: # other has duplicates - result = algos.union_with_duplicates(lvals, rvals) - return _maybe_try_sort(result, sort) + result_dups = algos.union_with_duplicates(lvals, rvals) + return _maybe_try_sort(result_dups, sort) # Self may have duplicates; other already checked as unique # find indexes of things in "other" that are not in "self" @@ -3409,11 +3409,17 @@ def _union(self, other: Index, sort): else: missing = algos.unique1d(self.get_indexer_non_unique(other)[1]) - if len(missing) > 0: - other_diff = rvals.take(missing) - result = concat_compat((lvals, other_diff)) + result: Index | MultiIndex | ArrayLike + if self._is_multi: + # Preserve MultiIndex to avoid losing dtypes + result = self.append(other.take(missing)) + else: - result = 
lvals + if len(missing) > 0: + other_diff = rvals.take(missing) + result = concat_compat((lvals, other_diff)) + else: + result = lvals if not self.is_monotonic_increasing or not other.is_monotonic_increasing: # if both are monotonic then result should already be sorted diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 13c3b9200371f..1b35cc03f6fdd 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3644,6 +3644,10 @@ def _union(self, other, sort) -> MultiIndex: # This is only necessary if both sides have nans or one has dups, # fast_unique_multiple is faster result = super()._union(other, sort) + + if isinstance(result, MultiIndex): + return result + else: rvals = other._values.astype(object, copy=False) result = lib.fast_unique_multiple([self._values, rvals], sort=sort) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index e15809a751b73..7383d5a551e7b 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -549,6 +549,36 @@ def test_union_duplicates(index, request): tm.assert_index_equal(result, expected) +def test_union_keep_dtype_precision(any_real_numeric_dtype): + # GH#48498 + arr1 = Series([4, 1, 1], dtype=any_real_numeric_dtype) + arr2 = Series([1, 4], dtype=any_real_numeric_dtype) + midx = MultiIndex.from_arrays([arr1, [2, 1, 1]], names=["a", None]) + midx2 = MultiIndex.from_arrays([arr2, [1, 2]], names=["a", None]) + + result = midx.union(midx2) + expected = MultiIndex.from_arrays( + ([Series([1, 1, 4], dtype=any_real_numeric_dtype), [1, 1, 2]]), + names=["a", None], + ) + tm.assert_index_equal(result, expected) + + +def test_union_keep_ea_dtype_with_na(any_numeric_ea_dtype): + # GH#48498 + + arr1 = Series([4, pd.NA], dtype=any_numeric_ea_dtype) + arr2 = Series([1, pd.NA], dtype=any_numeric_ea_dtype) + midx = MultiIndex.from_arrays([arr1, [2, 1]], names=["a", None]) + midx2 = 
MultiIndex.from_arrays([arr2, [1, 2]]) + result = midx.union(midx2) + # Expected is actually off and should contain (1, 1) too. See GH#37222 + expected = MultiIndex.from_arrays( + [Series([4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [2, 1, 2]] + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( "levels1, levels2, codes1, codes2, names", [