Skip to content

Commit 2fd1fe7

Browse files
committed
API/REGR: Convert to float for index union
This restores the 0.24.x behavior of Index.union(other) between Float and (U)Int indexes: the result of such a union is now a floating dtype again.

| left  | right | output of left.union(right) |
| ----- | ----- | --------------------------- |
| int   | float | float64                     |
| int   | uint  | object                      |
| float | uint  | float64                     |

See pandas-dev#26778 (comment). Closes pandas-dev#26778
1 parent 8ea2d08 commit 2fd1fe7

File tree

3 files changed

+123
-5
lines changed

3 files changed

+123
-5
lines changed

pandas/core/indexes/numeric.py

+67-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
from pandas.core.dtypes.common import (
99
is_bool, is_bool_dtype, is_dtype_equal, is_extension_array_dtype, is_float,
10-
is_integer_dtype, is_scalar, needs_i8_conversion, pandas_dtype)
10+
is_float_dtype, is_integer_dtype, is_scalar, needs_i8_conversion,
11+
pandas_dtype)
1112
import pandas.core.dtypes.concat as _concat
12-
from pandas.core.dtypes.generic import ABCInt64Index, ABCRangeIndex
13+
from pandas.core.dtypes.generic import (
14+
ABCFloat64Index, ABCInt64Index, ABCRangeIndex, ABCUInt64Index)
1315
from pandas.core.dtypes.missing import isna
1416

1517
from pandas.core import algorithms
@@ -123,6 +125,21 @@ def insert(self, loc, item):
123125
item = self._na_value
124126
return super().insert(loc, item)
125127

128+
def _union(self, other, sort):
    """
    Union with ``other``, casting both sides to float when one operand has
    an integer dtype and the other a float dtype.

    Parameters
    ----------
    other : object exposing ``dtype`` and ``astype``
        The operand to combine with ``self``.
    sort
        Passed through unchanged to the delegated ``_union`` call.

    Returns
    -------
    The result of ``_union`` on the float-cast operands when the dtypes mix
    integer and float; otherwise the parent class's ``_union`` result.
    """
    # Dtype rules:
    #   float | [u]int -> float
    #   <T>   | <T>    -> T
    #   <T>   | <U>    -> object
    own, theirs = self.dtype, other.dtype
    int_then_float = is_integer_dtype(own) and is_float_dtype(theirs)
    float_then_int = is_integer_dtype(theirs) and is_float_dtype(own)
    if not (int_then_float or float_then_int):
        return super()._union(other, sort)

    # Cast both operands to float before delegating.
    float_self = self.astype("float")
    float_other = other.astype("float")
    return float_self._union(float_other, sort)
142+
126143

127144
_num_index_shared_docs['class_descr'] = """
128145
Immutable ndarray implementing an ordered, sliceable set. The basic object
@@ -225,10 +242,24 @@ def _assert_safe_casting(cls, data, subarr):
225242
def _is_compatible_with_other(self, other):
    """
    Check whether ``other`` is compatible with this index.

    Returns the parent class's verdict when it is truthy. Otherwise returns
    True only if both ``self`` and ``other`` are Int64, Float64 or Range
    indexes.
    """
    parent_verdict = super()._is_compatible_with_other(other)
    if parent_verdict:
        return parent_verdict

    accepted = (ABCInt64Index, ABCFloat64Index, ABCRangeIndex)
    for candidate in (self, other):
        if not isinstance(type(candidate), accepted):
            return False
    return True
231250

251+
def _union(self, other, sort):
    """
    Union with ``other``; an integer/float dtype mix is resolved by casting
    both operands to float first.

    ``sort`` is forwarded unchanged to whichever ``_union`` ends up being
    called.
    """
    # NOTE(review): same logic as the other ``_union`` definitions in this
    # file; presumably it could be shared via a common base class -- confirm
    # against the class hierarchy.
    dtype_pair = (self.dtype, other.dtype)
    mixes_int_and_float = any(
        is_integer_dtype(first) and is_float_dtype(second)
        for first, second in (dtype_pair, dtype_pair[::-1])
    )
    if mixes_int_and_float:
        casted_self = self.astype("float")
        casted_other = other.astype("float")
        return casted_self._union(casted_other, sort)
    return super()._union(other, sort)
262+
232263

233264
Int64Index._add_numeric_methods()
234265
Int64Index._add_logical_methods()
@@ -301,6 +332,29 @@ def _assert_safe_casting(cls, data, subarr):
301332
raise TypeError('Unsafe NumPy casting, you must '
302333
'explicitly cast')
303334

335+
def _is_compatible_with_other(self, other):
    """
    Check whether ``other`` is compatible with this index.

    Truthy when the parent class reports compatibility, or when both
    ``self`` and ``other`` are UInt64 or Float64 indexes.
    """
    # Note: ABCInt64Index is not among the accepted types here.
    # TODO: deduplicate with Int64Index.
    # TODO: which classes need this override? Int, UInt, Float? Range?
    accepted = (ABCUInt64Index, ABCFloat64Index)
    inherited = super()._is_compatible_with_other(other)
    return inherited or all(
        isinstance(type(idx), accepted) for idx in (self, other)
    )
345+
346+
def _union(self, other, sort):
    """
    Union with ``other``, upcasting to float for integer/float dtype pairs.

    When exactly one side is integer-typed and the other float-typed, both
    are cast with ``astype("float")`` and ``_union`` is called on the cast
    operands. In every other case the parent implementation is used.
    ``sort`` is forwarded as-is.
    """
    # NOTE(review): same logic as the other ``_union`` definitions in this
    # file -- a candidate for de-duplication; confirm the class hierarchy.
    if (
        (is_integer_dtype(self.dtype) and is_float_dtype(other.dtype))
        or (is_integer_dtype(other.dtype) and is_float_dtype(self.dtype))
    ):
        left, right = [idx.astype("float") for idx in (self, other)]
        return left._union(right, sort)
    return super()._union(other, sort)
357+
304358

305359
UInt64Index._add_numeric_methods()
306360
UInt64Index._add_logical_methods()
@@ -447,6 +501,16 @@ def isin(self, values, level=None):
447501
self._validate_index_level(level)
448502
return algorithms.isin(np.array(self), values)
449503

504+
def _is_compatible_with_other(self, other):
    """
    Check whether ``other`` is compatible with this index.

    Falls back to a type check when the parent class does not report
    compatibility: both ``self`` and ``other`` must be Int64, Float64,
    UInt64 or Range indexes.
    """
    verdict = super()._is_compatible_with_other(other)
    if not verdict:
        numeric_kinds = (
            ABCInt64Index,
            ABCFloat64Index,
            ABCUInt64Index,
            ABCRangeIndex,
        )
        verdict = (
            isinstance(type(self), numeric_kinds)
            and isinstance(type(other), numeric_kinds)
        )
    return verdict
513+
450514

451515
Float64Index._add_numeric_methods()
452516
Float64Index._add_logical_methods_disabled()

pandas/tests/indexes/test_numeric.py

+24
Original file line numberDiff line numberDiff line change
@@ -1118,3 +1118,27 @@ def test_join_outer(self):
11181118
tm.assert_index_equal(res, eres)
11191119
tm.assert_numpy_array_equal(lidx, elidx)
11201120
tm.assert_numpy_array_equal(ridx, eridx)
1121+
1122+
1123+
@pytest.mark.parametrize("dtype", ['int64', 'uint64'])
def test_int_float_union_dtype(dtype):
    """[u]int | float -> float, whichever operand ``union`` is called on."""
    index = pd.Index([0, 2, 3], dtype=dtype)
    other = pd.Float64Index([0.5, 1.5])
    expected = pd.Float64Index([0.0, 0.5, 1.5, 2.0, 3.0])

    # Integer index on the left. These assertions were previously commented
    # out, which left this direction of the union untested.
    result = index.union(other)
    tm.assert_index_equal(result, expected)

    # Float index on the left.
    result = other.union(index)
    tm.assert_index_equal(result, expected)
1134+
1135+
1136+
def test_range_float_union_dtype():
    """RangeIndex | Float64Index gives a Float64Index in both directions."""
    range_index = pd.RangeIndex(start=0, stop=3)
    float_index = pd.Float64Index([0.5, 1.5])
    expected = pd.Float64Index([0.0, 0.5, 1, 1.5, 2.0])

    # Range on the left first, then float on the left.
    operand_orders = ((range_index, float_index), (float_index, range_index))
    for left, right in operand_orders:
        tm.assert_index_equal(left.union(right), expected)

pandas/tests/indexes/test_setops.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
from pandas.core.dtypes.common import is_dtype_equal
1111

1212
import pandas as pd
13-
from pandas import Int64Index, RangeIndex
13+
from pandas import Float64Index, Int64Index, RangeIndex, UInt64Index
14+
from pandas.api.types import pandas_dtype
1415
from pandas.tests.indexes.conftest import indices_list
1516
import pandas.util.testing as tm
1617

1718
# Pairs of index classes with different dtypes that are nevertheless expected
# to be union-compatible, mapped to the factories used to build one index of
# each class.
COMPATIBLE_INCONSISTENT_PAIRS = {
    (Int64Index, RangeIndex): (tm.makeIntIndex, tm.makeRangeIndex),
    (Float64Index, Int64Index): (tm.makeFloatIndex, tm.makeIntIndex),
    # Build an actual RangeIndex here; using makeIntIndex would merely repeat
    # the (Float64Index, Int64Index) case above.
    (Float64Index, RangeIndex): (tm.makeFloatIndex, tm.makeRangeIndex),
    (Float64Index, UInt64Index): (tm.makeFloatIndex, tm.makeUIntIndex),
}
2024

2125

@@ -74,3 +78,29 @@ def test_compatible_inconsistent_pairs(idx_fact1, idx_fact2):
7478

7579
assert res1.dtype in (idx1.dtype, idx2.dtype)
7680
assert res2.dtype in (idx1.dtype, idx2.dtype)
81+
82+
83+
@pytest.mark.parametrize('left, right, expected', [
    ('int64', 'int64', 'int64'),
    ('int64', 'uint64', 'object'),
    ('int64', 'float64', 'float64'),
    ('uint64', 'float64', 'float64'),
    ('uint64', 'uint64', 'uint64'),
    ('float64', 'float64', 'float64'),
    ('datetime64[ns]', 'int64', 'object'),
    ('datetime64[ns]', 'uint64', 'object'),
    ('datetime64[ns]', 'float64', 'object'),
    ('datetime64[ns, CET]', 'int64', 'object'),
    ('datetime64[ns, CET]', 'uint64', 'object'),
    ('datetime64[ns, CET]', 'float64', 'object'),
    ('Period[D]', 'int64', 'object'),
    ('Period[D]', 'uint64', 'object'),
    ('Period[D]', 'float64', 'object'),
])
def test_union_dtypes(left, right, expected):
    """The ``|`` of two empty indexes has the expected result dtype."""
    left_dtype = pandas_dtype(left)
    right_dtype = pandas_dtype(right)
    # Empty indexes suffice: only the dtype of the result is checked.
    empty_left = pd.Index([], dtype=left_dtype)
    empty_right = pd.Index([], dtype=right_dtype)
    assert (empty_left | empty_right).dtype == expected

0 commit comments

Comments
 (0)