diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index d726f69286469..4223ad3f826a2 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -215,7 +215,7 @@ MultiIndex - Bug in :class:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`) - Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`) - Bug in :meth:`MultiIndex.intersection` losing extension array (:issue:`48604`) -- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`) +- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`, :issue:`48591`) - Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`) - Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`) - diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 926c79d87e5a0..a670ddf8a54aa 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1970,22 +1970,34 @@ def _sort_tuples( return original_values[indexer] +@overload def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike: + ... + + +@overload +def union_with_duplicates(lvals: MultiIndex, rvals: MultiIndex) -> MultiIndex: + ... + + +def union_with_duplicates( + lvals: ArrayLike | MultiIndex, rvals: ArrayLike | MultiIndex +) -> ArrayLike | MultiIndex: """ Extracts the union from lvals and rvals with respect to duplicates and nans in both arrays. Parameters ---------- - lvals: np.ndarray or ExtensionArray + lvals: np.ndarray or ExtensionArray or MultiIndex left values which is ordered in front. - rvals: np.ndarray or ExtensionArray + rvals: np.ndarray or ExtensionArray or MultiIndex right values ordered after lvals. Returns ------- - np.ndarray or ExtensionArray - Containing the unsorted union of both arrays. + np.ndarray or ExtensionArray or MultiIndex + Containing the unsorted union of both inputs.
Notes ----- @@ -1995,8 +2007,11 @@ def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike: l_count = value_counts(lvals, dropna=False) r_count = value_counts(rvals, dropna=False) l_count, r_count = l_count.align(r_count, fill_value=0) - unique_array = unique(concat_compat([lvals, rvals])) - unique_array = ensure_wrapped_if_datetimelike(unique_array) + if isinstance(lvals, ABCMultiIndex): + unique_array = lvals.append(rvals).unique() + else: + unique_array = unique(concat_compat([lvals, rvals])) + unique_array = ensure_wrapped_if_datetimelike(unique_array) for i, value in enumerate(unique_array): indexer += [i] * int(max(l_count.at[value], r_count.at[value])) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 7dc04474cbcd8..0a42619664d15 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3437,7 +3437,7 @@ def _union(self, other: Index, sort): elif not other.is_unique: # other has duplicates result_dups = algos.union_with_duplicates(lvals, rvals) - return _maybe_try_sort(result_dups, sort) + return maybe_try_sort(result_dups, sort) # Self may have duplicates; other already checked as unique # find indexes of things in "other" that are not in "self" @@ -3461,7 +3461,7 @@ def _union(self, other: Index, sort): if not self.is_monotonic_increasing or not other.is_monotonic_increasing: # if both are monotonic then result should already be sorted - result = _maybe_try_sort(result, sort) + result = maybe_try_sort(result, sort) return result @@ -3581,7 +3581,7 @@ def _intersection(self, other: Index, sort: bool = False): return ensure_wrapped_if_datetimelike(res) res_values = self._intersection_via_get_indexer(other, sort=sort) - res_values = _maybe_try_sort(res_values, sort) + res_values = maybe_try_sort(res_values, sort) return res_values def _wrap_intersection_result(self, other, result): @@ -3690,7 +3690,7 @@ def _difference(self, other, sort): the_diff = this.take(label_diff) else: the_diff = 
this._values.take(label_diff) - the_diff = _maybe_try_sort(the_diff, sort) + the_diff = maybe_try_sort(the_diff, sort) return the_diff @@ -3767,7 +3767,7 @@ def symmetric_difference(self, other, result_name=None, sort=None): right_diff = other.take(right_indexer) res_values = left_diff.append(right_diff) - result = _maybe_try_sort(res_values, sort) + result = maybe_try_sort(res_values, sort) if not self._is_multi: return Index(result, name=result_name, dtype=res_values.dtype) @@ -7564,7 +7564,7 @@ def unpack_nested_dtype(other: _IndexT) -> _IndexT: return other -def _maybe_try_sort(result, sort): +def maybe_try_sort(result, sort): if sort is None: try: result = algos.safe_sort(result) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 26dd957ff4d57..ffa71b54c30fa 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3674,20 +3674,16 @@ def equal_levels(self, other: MultiIndex) -> bool: def _union(self, other, sort) -> MultiIndex: other, result_names = self._convert_can_do_setop(other) - if ( - any(-1 in code for code in self.codes) - and any(-1 in code for code in other.codes) - or other.has_duplicates + if other.has_duplicates: + result_dups = algos.union_with_duplicates(self, other) + return ibase.maybe_try_sort(result_dups, sort) + + elif any(-1 in code for code in self.codes) and any( + -1 in code for code in other.codes ): - # This is only necessary if both sides have nans or other has dups, + # This is only necessary if both sides have nans, # fast_unique_multiple is faster - result = super()._union(other, sort) - - if isinstance(result, MultiIndex): - return result - return MultiIndex.from_arrays( - zip(*result), sortorder=None, names=result_names - ) + return super()._union(other, sort) else: rvals = other._values.astype(object, copy=False) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index 718ac407d4a3f..36e4b7d94bb65 100644 --- 
a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -578,6 +578,19 @@ def test_union_keep_ea_dtype(any_numeric_ea_dtype, val): tm.assert_index_equal(result, expected) +def test_union_keep_ea_dtype_duplicates_right(any_numeric_ea_dtype): + # GH#48591 + arr1 = Series([4, 2], dtype=any_numeric_ea_dtype) + arr2 = Series([2, 1, 1], dtype=any_numeric_ea_dtype) + midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None]) + midx2 = MultiIndex.from_arrays([arr2, [2, 1, 1]]) + result = midx.union(midx2) + expected = MultiIndex.from_arrays( + [Series([1, 1, 2, 4], dtype=any_numeric_ea_dtype), [1, 1, 2, 1]] + ) + tm.assert_index_equal(result, expected) + + def test_union_duplicates(index, request): # GH#38977 if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)): @@ -590,15 +603,6 @@ def test_union_duplicates(index, request): result = mi2.union(mi1) expected = mi2.sort_values() tm.assert_index_equal(result, expected) - - if mi2.levels[0].dtype == np.uint64 and (mi2.get_level_values(0) < 2**63).all(): - # GH#47294 - union uses lib.fast_zip, converting data to Python integers - # and loses type information. Result is then unsigned only when values are - # sufficiently large to require unsigned dtype. This happens only if other - # has dups or one of both have missing values - expected = expected.set_levels( - [expected.levels[0].astype(int), expected.levels[1]] - ) result = mi1.union(mi2) tm.assert_index_equal(result, expected)