Skip to content

Commit ec94b6c

Browse files
phoflnoatamir
authored andcommitted
BUG: MultiIndex.union losing extension array dtype (pandas-dev#48498)
* BUG: MultiIndex.union losing extension array dtype * Add gh ref * Fix tests * Fix typing * Add note
1 parent c47e212 commit ec94b6c

File tree

6 files changed

+78
-7
lines changed

6 files changed

+78
-7
lines changed

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ Missing
170170
MultiIndex
171171
^^^^^^^^^^
172172
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
173+
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`)
173174
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
174175
-
175176

pandas/conftest.py

+25
Original file line numberDiff line numberDiff line change
@@ -1569,6 +1569,31 @@ def any_real_numpy_dtype(request):
15691569
return request.param
15701570

15711571

1572+
@pytest.fixture(
1573+
params=tm.ALL_REAL_NUMPY_DTYPES + tm.ALL_INT_EA_DTYPES + tm.FLOAT_EA_DTYPES
1574+
)
1575+
def any_real_numeric_dtype(request):
1576+
"""
1577+
Parameterized fixture for any (purely) real numeric dtype.
1578+
1579+
* int
1580+
* 'int8'
1581+
* 'uint8'
1582+
* 'int16'
1583+
* 'uint16'
1584+
* 'int32'
1585+
* 'uint32'
1586+
* 'int64'
1587+
* 'uint64'
1588+
* float
1589+
* 'float32'
1590+
* 'float64'
1591+
1592+
and associated ea dtypes.
1593+
"""
1594+
return request.param
1595+
1596+
15721597
@pytest.fixture(params=tm.ALL_NUMPY_DTYPES)
15731598
def any_numpy_dtype(request):
15741599
"""

pandas/core/algorithms.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1830,6 +1830,7 @@ def safe_sort(
18301830
"Only list-like objects are allowed to be passed to safe_sort as values"
18311831
)
18321832
original_values = values
1833+
is_mi = isinstance(original_values, ABCMultiIndex)
18331834

18341835
if not isinstance(values, (np.ndarray, ABCExtensionArray)):
18351836
# don't convert to string types
@@ -1851,7 +1852,11 @@ def safe_sort(
18511852
else:
18521853
try:
18531854
sorter = values.argsort()
1854-
ordered = values.take(sorter)
1855+
if is_mi:
1856+
# Operate on original object instead of casted array (MultiIndex)
1857+
ordered = original_values.take(sorter)
1858+
else:
1859+
ordered = values.take(sorter)
18551860
except TypeError:
18561861
# Previous sorters failed or were not applicable, try `_sort_mixed`
18571862
# which would work, but which fails for special case of 1d arrays

pandas/core/indexes/base.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -3398,8 +3398,8 @@ def _union(self, other: Index, sort):
33983398

33993399
elif not other.is_unique:
34003400
# other has duplicates
3401-
result = algos.union_with_duplicates(lvals, rvals)
3402-
return _maybe_try_sort(result, sort)
3401+
result_dups = algos.union_with_duplicates(lvals, rvals)
3402+
return _maybe_try_sort(result_dups, sort)
34033403

34043404
# Self may have duplicates; other already checked as unique
34053405
# find indexes of things in "other" that are not in "self"
@@ -3409,11 +3409,17 @@ def _union(self, other: Index, sort):
34093409
else:
34103410
missing = algos.unique1d(self.get_indexer_non_unique(other)[1])
34113411

3412-
if len(missing) > 0:
3413-
other_diff = rvals.take(missing)
3414-
result = concat_compat((lvals, other_diff))
3412+
result: Index | MultiIndex | ArrayLike
3413+
if self._is_multi:
3414+
# Preserve MultiIndex to avoid losing dtypes
3415+
result = self.append(other.take(missing))
3416+
34153417
else:
3416-
result = lvals
3418+
if len(missing) > 0:
3419+
other_diff = rvals.take(missing)
3420+
result = concat_compat((lvals, other_diff))
3421+
else:
3422+
result = lvals
34173423

34183424
if not self.is_monotonic_increasing or not other.is_monotonic_increasing:
34193425
# if both are monotonic then result should already be sorted

pandas/core/indexes/multi.py

+4
Original file line numberDiff line numberDiff line change
@@ -3644,6 +3644,10 @@ def _union(self, other, sort) -> MultiIndex:
36443644
# This is only necessary if both sides have nans or one has dups,
36453645
# fast_unique_multiple is faster
36463646
result = super()._union(other, sort)
3647+
3648+
if isinstance(result, MultiIndex):
3649+
return result
3650+
36473651
else:
36483652
rvals = other._values.astype(object, copy=False)
36493653
result = lib.fast_unique_multiple([self._values, rvals], sort=sort)

pandas/tests/indexes/multi/test_setops.py

+30
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,36 @@ def test_union_duplicates(index, request):
549549
tm.assert_index_equal(result, expected)
550550

551551

552+
def test_union_keep_dtype_precision(any_real_numeric_dtype):
553+
# GH#48498
554+
arr1 = Series([4, 1, 1], dtype=any_real_numeric_dtype)
555+
arr2 = Series([1, 4], dtype=any_real_numeric_dtype)
556+
midx = MultiIndex.from_arrays([arr1, [2, 1, 1]], names=["a", None])
557+
midx2 = MultiIndex.from_arrays([arr2, [1, 2]], names=["a", None])
558+
559+
result = midx.union(midx2)
560+
expected = MultiIndex.from_arrays(
561+
([Series([1, 1, 4], dtype=any_real_numeric_dtype), [1, 1, 2]]),
562+
names=["a", None],
563+
)
564+
tm.assert_index_equal(result, expected)
565+
566+
567+
def test_union_keep_ea_dtype_with_na(any_numeric_ea_dtype):
568+
# GH#48498
569+
570+
arr1 = Series([4, pd.NA], dtype=any_numeric_ea_dtype)
571+
arr2 = Series([1, pd.NA], dtype=any_numeric_ea_dtype)
572+
midx = MultiIndex.from_arrays([arr1, [2, 1]], names=["a", None])
573+
midx2 = MultiIndex.from_arrays([arr2, [1, 2]])
574+
result = midx.union(midx2)
575+
# Expected is actually off and should contain (1, 1) too. See GH#37222
576+
expected = MultiIndex.from_arrays(
577+
[Series([4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [2, 1, 2]]
578+
)
579+
tm.assert_index_equal(result, expected)
580+
581+
552582
@pytest.mark.parametrize(
553583
"levels1, levels2, codes1, codes2, names",
554584
[

0 commit comments

Comments
 (0)