Skip to content

Commit dc4fddc

Browse files
authored
REF: Simplify Index.union (#41773)
1 parent a8b313f commit dc4fddc

File tree

2 files changed

+24
-31
lines changed

2 files changed

+24
-31
lines changed

pandas/core/indexes/base.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
is_float_dtype,
7878
is_hashable,
7979
is_integer,
80-
is_integer_dtype,
8180
is_interval_dtype,
8281
is_iterator,
8382
is_list_like,
@@ -2963,20 +2962,7 @@ def union(self, other, sort=None):
29632962
stacklevel=2,
29642963
)
29652964

2966-
dtype = find_common_type([self.dtype, other.dtype])
2967-
if self._is_numeric_dtype and other._is_numeric_dtype:
2968-
# Right now, we treat union(int, float) a bit special.
2969-
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion
2970-
# We may change union(int, float) to go to object.
2971-
# float | [u]int -> float (the special case)
2972-
# <T> | <T> -> T
2973-
# <T> | <U> -> object
2974-
if not (is_integer_dtype(self.dtype) and is_integer_dtype(other.dtype)):
2975-
dtype = np.dtype("float64")
2976-
else:
2977-
# one is int64 other is uint64
2978-
dtype = np.dtype("object")
2979-
2965+
dtype = self._find_common_type_compat(other)
29802966
left = self.astype(dtype, copy=False)
29812967
right = other.astype(dtype, copy=False)
29822968
return left.union(right, sort=sort)
@@ -5410,6 +5396,19 @@ def _find_common_type_compat(self, target) -> DtypeObj:
54105396
return IntervalDtype(np.float64, closed=self.closed)
54115397

54125398
target_dtype, _ = infer_dtype_from(target, pandas_dtype=True)
5399+
5400+
# special case: if one dtype is uint64 and the other a signed int, return object
5401+
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion
5402+
# Now it's:
5403+
# * float | [u]int -> float
5404+
# * uint64 | signed int -> object
5405+
# We may change union(float | [u]int) to go to object.
5406+
if self.dtype == "uint64" or target_dtype == "uint64":
5407+
if is_signed_integer_dtype(self.dtype) or is_signed_integer_dtype(
5408+
target_dtype
5409+
):
5410+
return np.dtype("object")
5411+
54135412
dtype = find_common_type([self.dtype, target_dtype])
54145413
if dtype.kind in ["i", "u"]:
54155414
# TODO: what about reversed with self being categorical?

pandas/tests/indexes/test_setops.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import pytest
1010

11-
from pandas.core.dtypes.common import is_dtype_equal
11+
from pandas.core.dtypes.cast import find_common_type
1212

1313
from pandas import (
1414
CategoricalIndex,
@@ -25,6 +25,7 @@
2525
import pandas._testing as tm
2626
from pandas.api.types import (
2727
is_datetime64tz_dtype,
28+
is_signed_integer_dtype,
2829
pandas_dtype,
2930
)
3031

@@ -48,7 +49,11 @@ def test_union_different_types(index_flat, index_flat2):
4849
idx1 = index_flat
4950
idx2 = index_flat2
5051

51-
type_pair = tuple(sorted([idx1.dtype.type, idx2.dtype.type], key=lambda x: str(x)))
52+
common_dtype = find_common_type([idx1.dtype, idx2.dtype])
53+
54+
any_uint64 = idx1.dtype == np.uint64 or idx2.dtype == np.uint64
55+
idx1_signed = is_signed_integer_dtype(idx1.dtype)
56+
idx2_signed = is_signed_integer_dtype(idx2.dtype)
5257

5358
# Union with a non-unique, non-monotonic index raises error
5459
# This applies to the boolean index
@@ -58,23 +63,12 @@ def test_union_different_types(index_flat, index_flat2):
5863
res1 = idx1.union(idx2)
5964
res2 = idx2.union(idx1)
6065

61-
if is_dtype_equal(idx1.dtype, idx2.dtype):
62-
assert res1.dtype == idx1.dtype
63-
assert res2.dtype == idx1.dtype
64-
65-
elif type_pair not in COMPATIBLE_INCONSISTENT_PAIRS:
66-
# A union with a CategoricalIndex (even as dtype('O')) and a
67-
# non-CategoricalIndex can only be made if both indices are monotonic.
68-
# This is true before this PR as well.
66+
if any_uint64 and (idx1_signed or idx2_signed):
6967
assert res1.dtype == np.dtype("O")
7068
assert res2.dtype == np.dtype("O")
71-
72-
elif idx1.dtype.kind in ["f", "i", "u"] and idx2.dtype.kind in ["f", "i", "u"]:
73-
assert res1.dtype == np.dtype("f8")
74-
assert res2.dtype == np.dtype("f8")
75-
7669
else:
77-
raise NotImplementedError
70+
assert res1.dtype == common_dtype
71+
assert res2.dtype == common_dtype
7872

7973

8074
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)