Skip to content

Commit aea824f

Browse files
authored
ENH: Keep dtypes in MultiIndex.union without NAs (#48505)
1 parent 68d6b47 commit aea824f

File tree

5 files changed

+76
-54
lines changed

5 files changed

+76
-54
lines changed

doc/source/whatsnew/v1.6.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Performance improvements
105105
~~~~~~~~~~~~~~~~~~~~~~~~
106106
- Performance improvement in :meth:`.GroupBy.median` for nullable dtypes (:issue:`37493`)
107107
- Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`)
108+
- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`)
108109
- Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`)
109110
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
110111
- Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`)
@@ -173,7 +174,7 @@ Missing
173174
MultiIndex
174175
^^^^^^^^^^
175176
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
176-
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`)
177+
- Bug in :meth:`MultiIndex.union` losing extension array dtype (:issue:`48498`, :issue:`48505`)
177178
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
178179
-
179180

pandas/_libs/lib.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def is_bool_array(values: np.ndarray, skipna: bool = ...): ...
5959
def fast_multiget(mapping: dict, keys: np.ndarray, default=...) -> np.ndarray: ...
6060
def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ...
6161
def fast_unique_multiple_list(lists: list, sort: bool | None = ...) -> list: ...
62-
def fast_unique_multiple(arrays: list, sort: bool = ...) -> list: ...
62+
def fast_unique_multiple(left: np.ndarray, right: np.ndarray) -> list: ...
6363
def map_infer(
6464
arr: np.ndarray,
6565
f: Callable[[Any], Any],

pandas/_libs/lib.pyx

+23-36
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from collections import abc
22
from decimal import Decimal
33
from enum import Enum
4-
import inspect
54
from typing import Literal
6-
import warnings
75

86
cimport cython
97
from cpython.datetime cimport (
@@ -31,8 +29,6 @@ from cython cimport (
3129
floating,
3230
)
3331

34-
from pandas.util._exceptions import find_stack_level
35-
3632
import_datetime()
3733

3834
import numpy as np
@@ -314,51 +310,42 @@ def item_from_zerodim(val: object) -> object:
314310

315311
@cython.wraparound(False)
316312
@cython.boundscheck(False)
317-
def fast_unique_multiple(list arrays, sort: bool = True):
313+
def fast_unique_multiple(ndarray left, ndarray right) -> list:
318314
"""
319-
Generate a list of unique values from a list of arrays.
315+
Generate a list of indices we have to add to the left to get the union
316+
of both arrays.
320317

321318
Parameters
322319
----------
323-
list : array-like
324-
List of array-like objects.
325-
sort : bool
326-
Whether or not to sort the resulting unique list.
320+
left : np.ndarray
321+
Left array that is used as base.
322+
right : np.ndarray
323+
Right array that is checked for values that are not in left.
324+
right cannot have duplicates.
327325

328326
Returns
329327
-------
330-
list of unique values
328+
list of indices that we have to add to the left array.
331329
"""
332330
cdef:
333-
ndarray[object] buf
334-
Py_ssize_t k = len(arrays)
335-
Py_ssize_t i, j, n
336-
list uniques = []
337-
dict table = {}
331+
Py_ssize_t j, n
332+
list indices = []
333+
set table = set()
338334
object val, stub = 0
339335

340-
for i in range(k):
341-
buf = arrays[i]
342-
n = len(buf)
343-
for j in range(n):
344-
val = buf[j]
345-
if val not in table:
346-
table[val] = stub
347-
uniques.append(val)
336+
n = len(left)
337+
for j in range(n):
338+
val = left[j]
339+
if val not in table:
340+
table.add(val)
348341

349-
if sort is None:
350-
try:
351-
uniques.sort()
352-
except TypeError:
353-
warnings.warn(
354-
"The values in the array are unorderable. "
355-
"Pass `sort=False` to suppress this warning.",
356-
RuntimeWarning,
357-
stacklevel=find_stack_level(inspect.currentframe()),
358-
)
359-
pass
342+
n = len(right)
343+
for j in range(n):
344+
val = right[j]
345+
if val not in table:
346+
indices.append(j)
360347

361-
return uniques
348+
return indices
362349

363350

364351
@cython.wraparound(False)

pandas/core/indexes/multi.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -3638,21 +3638,38 @@ def _union(self, other, sort) -> MultiIndex:
36383638
if (
36393639
any(-1 in code for code in self.codes)
36403640
and any(-1 in code for code in other.codes)
3641-
or self.has_duplicates
36423641
or other.has_duplicates
36433642
):
3644-
# This is only necessary if both sides have nans or one has dups,
3643+
# This is only necessary if both sides have nans or other has dups,
36453644
# fast_unique_multiple is faster
36463645
result = super()._union(other, sort)
36473646

36483647
if isinstance(result, MultiIndex):
36493648
return result
3649+
return MultiIndex.from_arrays(
3650+
zip(*result), sortorder=None, names=result_names
3651+
)
36503652

36513653
else:
36523654
rvals = other._values.astype(object, copy=False)
3653-
result = lib.fast_unique_multiple([self._values, rvals], sort=sort)
3655+
right_missing = lib.fast_unique_multiple(self._values, rvals)
3656+
if right_missing:
3657+
result = self.append(other.take(right_missing))
3658+
else:
3659+
result = self._get_reconciled_name_object(other)
36543660

3655-
return MultiIndex.from_arrays(zip(*result), sortorder=None, names=result_names)
3661+
if sort is None:
3662+
try:
3663+
result = result.sort_values()
3664+
except TypeError:
3665+
warnings.warn(
3666+
"The values in the array are unorderable. "
3667+
"Pass `sort=False` to suppress this warning.",
3668+
RuntimeWarning,
3669+
stacklevel=find_stack_level(inspect.currentframe()),
3670+
)
3671+
pass
3672+
return result
36563673

36573674
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
36583675
return is_object_dtype(dtype)

pandas/tests/indexes/multi/test_setops.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,6 @@ def test_union(idx, sort):
261261
assert result.equals(idx)
262262

263263

264-
@pytest.mark.xfail(
265-
# This test was commented out from Oct 2011 to Dec 2021, may no longer
266-
# be relevant.
267-
reason="Length of names must match number of levels in MultiIndex",
268-
raises=ValueError,
269-
)
270264
def test_union_with_regular_index(idx):
271265
other = Index(["A", "B", "C"])
272266

@@ -277,7 +271,9 @@ def test_union_with_regular_index(idx):
277271
msg = "The values in the array are unorderable"
278272
with tm.assert_produces_warning(RuntimeWarning, match=msg):
279273
result2 = idx.union(other)
280-
assert result.equals(result2)
274+
# This is more consistent now: if sorting fails, then we don't sort at all
275+
# in the MultiIndex case.
276+
assert not result.equals(result2)
281277

282278

283279
def test_intersection(idx, sort):
@@ -525,6 +521,26 @@ def test_union_nan_got_duplicated():
525521
tm.assert_index_equal(result, mi2)
526522

527523

524+
@pytest.mark.parametrize("val", [4, 1])
525+
def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
526+
# GH#48505
527+
528+
arr1 = Series([val, 2], dtype=any_numeric_ea_dtype)
529+
arr2 = Series([2, 1], dtype=any_numeric_ea_dtype)
530+
midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None])
531+
midx2 = MultiIndex.from_arrays([arr2, [2, 1]])
532+
result = midx.union(midx2)
533+
if val == 4:
534+
expected = MultiIndex.from_arrays(
535+
[Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]]
536+
)
537+
else:
538+
expected = MultiIndex.from_arrays(
539+
[Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]]
540+
)
541+
tm.assert_index_equal(result, expected)
542+
543+
528544
def test_union_duplicates(index, request):
529545
# GH#38977
530546
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
@@ -534,18 +550,19 @@ def test_union_duplicates(index, request):
534550
values = index.unique().values.tolist()
535551
mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
536552
mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])
537-
result = mi1.union(mi2)
553+
result = mi2.union(mi1)
538554
expected = mi2.sort_values()
555+
tm.assert_index_equal(result, expected)
556+
539557
if mi2.levels[0].dtype == np.uint64 and (mi2.get_level_values(0) < 2**63).all():
540558
# GH#47294 - union uses lib.fast_zip, converting data to Python integers
541559
# and loses type information. Result is then unsigned only when values are
542-
# sufficiently large to require unsigned dtype.
560+
# sufficiently large to require unsigned dtype. This happens only if other
561+
# has dups or both sides have missing values
543562
expected = expected.set_levels(
544563
[expected.levels[0].astype(int), expected.levels[1]]
545564
)
546-
tm.assert_index_equal(result, expected)
547-
548-
result = mi2.union(mi1)
565+
result = mi1.union(mi2)
549566
tm.assert_index_equal(result, expected)
550567

551568

0 commit comments

Comments
 (0)