
BUG: MultiIndex.union losing dtype when right has duplicates #48591


Closed — wants to merge 8 commits
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.6.0.rst
@@ -176,7 +176,7 @@ Missing
MultiIndex
^^^^^^^^^^
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
- Bug in :meth:`MultiIndex.union` losing extension array dtype (:issue:`48498`, :issue:`48505`)
- Bug in :meth:`MultiIndex.union` losing extension array dtype (:issue:`48498`, :issue:`48505`, :issue:`48591`)
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
-

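For orientation, here is a minimal sketch of the behaviour this PR targets, modelled on the new test further down (the `Int64` dtype and the example values are illustrative, not taken verbatim from the report): before the fix, a union whose right MultiIndex contains duplicates dropped the extension array dtype of the first level.

from pandas import MultiIndex, Series

left = MultiIndex.from_arrays([Series([4, 2], dtype="Int64"), [1, 2]], names=["a", None])
right = MultiIndex.from_arrays([Series([2, 1, 1], dtype="Int64"), [2, 1, 1]])

result = left.union(right)
# With this fix the extension dtype survives; before, the level lost it.
print(result.levels[0].dtype)  # expected: Int64
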
27 changes: 21 additions & 6 deletions pandas/core/algorithms.py
@@ -1968,22 +1968,34 @@ def _sort_tuples(
return original_values[indexer]


@overload
def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:
...


@overload
def union_with_duplicates(lvals: MultiIndex, rvals: MultiIndex) -> MultiIndex:
...


def union_with_duplicates(
lvals: ArrayLike | MultiIndex, rvals: ArrayLike | MultiIndex
) -> ArrayLike | MultiIndex:
"""
Extracts the union from lvals and rvals with respect to duplicates and nans in
both arrays.

Parameters
----------
lvals: np.ndarray or ExtensionArray
lvals: np.ndarray or ExtensionArray or MultiIndex
left values which is ordered in front.
rvals: np.ndarray or ExtensionArray
rvals: np.ndarray or ExtensionArray or MultiIndex
right values ordered after lvals.

Returns
-------
np.ndarray or ExtensionArray
Containing the unsorted union of both arrays.
Index or MultiIndex
Containing the unsorted union of both Indexes.

Notes
-----
@@ -1993,8 +2005,11 @@ def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:
l_count = value_counts(lvals, dropna=False)
r_count = value_counts(rvals, dropna=False)
l_count, r_count = l_count.align(r_count, fill_value=0)
unique_array = unique(concat_compat([lvals, rvals]))
unique_array = ensure_wrapped_if_datetimelike(unique_array)
if isinstance(lvals, ABCMultiIndex):
unique_array = lvals.append(rvals).unique()
else:
unique_array = unique(concat_compat([lvals, rvals]))
unique_array = ensure_wrapped_if_datetimelike(unique_array)

for i, value in enumerate(unique_array):
indexer += [i] * int(max(l_count.at[value], r_count.at[value]))
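For reference, the duplicate handling in `union_with_duplicates` follows multiset-union semantics: each value ends up repeated max(left count, right count) times, with distinct values taken in first-seen order across `lvals` then `rvals`. A pure-Python sketch of that idea (the helper name and the plain-list types are only for illustration; the real code counts via `value_counts` and, after this PR, appends and uniques MultiIndexes directly):

from collections import Counter

def multiset_union(left, right):
    # Count occurrences on each side, then emit every distinct value
    # max(left_count, right_count) times, keeping first-seen order.
    l_count, r_count = Counter(left), Counter(right)
    uniques = dict.fromkeys(list(left) + list(right))  # ordered "set"
    return [v for v in uniques for _ in range(max(l_count[v], r_count[v]))]

print(multiset_union([4, 2], [2, 1, 1]))  # [4, 2, 1, 1]
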
12 changes: 6 additions & 6 deletions pandas/core/indexes/base.py
@@ -3405,7 +3405,7 @@ def _union(self, other: Index, sort):
elif not other.is_unique:
# other has duplicates
result_dups = algos.union_with_duplicates(lvals, rvals)
return _maybe_try_sort(result_dups, sort)
return maybe_try_sort(result_dups, sort)

# Self may have duplicates; other already checked as unique
# find indexes of things in "other" that are not in "self"
Expand All @@ -3429,7 +3429,7 @@ def _union(self, other: Index, sort):

if not self.is_monotonic_increasing or not other.is_monotonic_increasing:
# if both are monotonic then result should already be sorted
result = _maybe_try_sort(result, sort)
result = maybe_try_sort(result, sort)

return result

@@ -3542,7 +3542,7 @@ def _intersection(self, other: Index, sort=False):
return ensure_wrapped_if_datetimelike(res)

res_values = self._intersection_via_get_indexer(other, sort=sort)
res_values = _maybe_try_sort(res_values, sort)
res_values = maybe_try_sort(res_values, sort)
return res_values

def _wrap_intersection_result(self, other, result):
@@ -3641,7 +3641,7 @@ def _difference(self, other, sort):

label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)
the_diff = this._values.take(label_diff)
the_diff = _maybe_try_sort(the_diff, sort)
the_diff = maybe_try_sort(the_diff, sort)

return the_diff

@@ -3718,7 +3718,7 @@ def symmetric_difference(self, other, result_name=None, sort=None):
right_diff = other._values.take(right_indexer)

res_values = concat_compat([left_diff, right_diff])
res_values = _maybe_try_sort(res_values, sort)
res_values = maybe_try_sort(res_values, sort)

# pass dtype so we retain object dtype
result = Index(res_values, name=result_name, dtype=res_values.dtype)
@@ -7521,7 +7521,7 @@ def unpack_nested_dtype(other: _IndexT) -> _IndexT:
return other


def _maybe_try_sort(result, sort):
def maybe_try_sort(result, sort):
if sort is None:
try:
result = algos.safe_sort(result)
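The rename from `_maybe_try_sort` to `maybe_try_sort` is what lets `pandas/core/indexes/multi.py` reuse the helper as `ibase.maybe_try_sort` in the hunk below. Its `sort=None` contract is roughly: attempt a sort, but fall back to the unsorted result when the values are not comparable. A simplified sketch of that behaviour (the real helper goes through `algos.safe_sort` and warns rather than failing silently):

def maybe_try_sort_sketch(values, sort):
    # sort=None means "sort if possible"; incomparable values are returned as-is.
    if sort is None:
        try:
            return sorted(values)
        except TypeError:
            return list(values)
    return list(values)

print(maybe_try_sort_sketch([3, 1, 2], sort=None))    # [1, 2, 3]
print(maybe_try_sort_sketch([3, "a", 2], sort=None))  # [3, 'a', 2] (unchanged)
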
20 changes: 8 additions & 12 deletions pandas/core/indexes/multi.py
@@ -3648,20 +3648,16 @@ def equal_levels(self, other: MultiIndex) -> bool:

def _union(self, other, sort) -> MultiIndex:
other, result_names = self._convert_can_do_setop(other)
if (
any(-1 in code for code in self.codes)
and any(-1 in code for code in other.codes)
or other.has_duplicates
if other.has_duplicates:
result_dups = algos.union_with_duplicates(self, other)
return ibase.maybe_try_sort(result_dups, sort)

elif any(-1 in code for code in self.codes) and any(
-1 in code for code in other.codes
):
# This is only necessary if both sides have nans or other has dups,
# This is only necessary if both sides have nans,
# fast_unique_multiple is faster
result = super()._union(other, sort)

if isinstance(result, MultiIndex):
return result
return MultiIndex.from_arrays(
zip(*result), sortorder=None, names=result_names
)
return super()._union(other, sort)

else:
rvals = other._values.astype(object, copy=False)
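In terms of observable behaviour, the new first branch means a union whose right operand contains duplicate tuples now routes through `algos.union_with_duplicates`, so duplicates are kept with max-per-value multiplicity (and, per the tests below, extension dtypes survive). An illustrative example with made-up values, not taken from the test suite:

from pandas import MultiIndex

mi1 = MultiIndex.from_arrays([[1, 2], ["a", "b"]])
mi2 = MultiIndex.from_arrays([[1, 1, 2], ["a", "a", "b"]])  # (1, "a") appears twice

# other (mi2) has duplicates -> union_with_duplicates path
result = mi1.union(mi2)
print(list(result))  # expected: [(1, 'a'), (1, 'a'), (2, 'b')]
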
22 changes: 13 additions & 9 deletions pandas/tests/indexes/multi/test_setops.py
@@ -541,6 +541,19 @@ def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
tm.assert_index_equal(result, expected)


def test_union_keep_ea_dtype_duplicates_right(any_numeric_ea_dtype):
# GH#48591
arr1 = Series([4, 2], dtype=any_numeric_ea_dtype)
arr2 = Series([2, 1, 1], dtype=any_numeric_ea_dtype)
Review comment (Member):
Could you test with NA in this array too?

Reply (phofl — Member, Author), Sep 19, 2022:
Will add, but it depends on #48608 and #48626; this currently fails because the sorting does not work right now.

midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None])
midx2 = MultiIndex.from_arrays([arr2, [2, 1, 1]])
result = midx.union(midx2)
expected = MultiIndex.from_arrays(
[Series([1, 1, 2, 4], dtype=any_numeric_ea_dtype), [1, 1, 2, 1]]
)
tm.assert_index_equal(result, expected)


def test_union_duplicates(index, request):
# GH#38977
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
@@ -553,15 +566,6 @@ def test_union_duplicates(index, request):
result = mi2.union(mi1)
expected = mi2.sort_values()
tm.assert_index_equal(result, expected)

if mi2.levels[0].dtype == np.uint64 and (mi2.get_level_values(0) < 2**63).all():
# GH#47294 - union uses lib.fast_zip, converting data to Python integers
# and loses type information. Result is then unsigned only when values are
# sufficiently large to require unsigned dtype. This happens only if other
# has dups or one of both have missing values
expected = expected.set_levels(
[expected.levels[0].astype(int), expected.levels[1]]
)
result = mi1.union(mi2)
tm.assert_index_equal(result, expected)
