Skip to content

Commit f81f0df

Browse files
Terji PetersenTerji Petersen
Terji Petersen
authored and
Terji Petersen
committed
IntervalIndex
1 parent 255b1a7 commit f81f0df

File tree

2 files changed

+64
-25
lines changed

2 files changed

+64
-25
lines changed

pandas/core/indexes/interval.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
is_number,
6060
is_object_dtype,
6161
is_scalar,
62+
is_signed_integer_dtype,
63+
is_unsigned_integer_dtype,
6264
)
6365
from pandas.core.dtypes.dtypes import IntervalDtype
6466
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -340,8 +342,8 @@ def from_tuples(
340342
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
341343
@cache_readonly
342344
def _engine(self) -> IntervalTree: # type: ignore[override]
343-
left = self._maybe_convert_i8(self.left)
344-
right = self._maybe_convert_i8(self.right)
345+
left = self._maybe_convert_to_64bit_if_numeric(self.left)
346+
right = self._maybe_convert_to_64bit_if_numeric(self.right)
345347
return IntervalTree(left, right, closed=self.closed)
346348

347349
def __contains__(self, key: Any) -> bool:
@@ -501,6 +503,18 @@ def _needs_i8_conversion(self, key) -> bool:
501503
i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex)
502504
return isinstance(key, i8_types)
503505

506+
def _maybe_convert_to_64bit_if_numeric(self, key):
507+
key = self._maybe_convert_i8(key)
508+
dtype = key.dtype
509+
if is_signed_integer_dtype(dtype) and dtype != "int64":
510+
return key.astype(np.int64)
511+
elif is_unsigned_integer_dtype(dtype) and dtype != "uint64":
512+
return key.astype(np.uint64)
513+
elif is_float_dtype(dtype) and dtype != "float64":
514+
return key.astype(np.float64)
515+
else:
516+
return key
517+
504518
def _maybe_convert_i8(self, key):
505519
"""
506520
Maybe convert a given key to its equivalent i8 value(s). Used as a

pandas/tests/indexes/interval/test_constructors.py

+48-23
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
timedelta_range,
1919
)
2020
import pandas._testing as tm
21-
from pandas.api.types import is_unsigned_integer_dtype
22-
from pandas.core.api import (
23-
Float64Index,
24-
Int64Index,
25-
UInt64Index,
21+
from pandas.api.types import (
22+
is_float_dtype,
23+
is_signed_integer_dtype,
24+
is_unsigned_integer_dtype,
2625
)
26+
from pandas.core.api import NumericIndex
2727
from pandas.core.arrays import IntervalArray
2828
import pandas.core.common as com
2929

@@ -50,9 +50,17 @@ def _skip_test_constructor(self, dtype):
5050
[
5151
[3, 14, 15, 92, 653],
5252
np.arange(10, dtype="int64"),
53-
Int64Index(range(-10, 11)),
54-
UInt64Index(range(10, 31)),
55-
Float64Index(np.arange(20, 30, 0.5)),
53+
NumericIndex(range(-10, 11), dtype=np.int64),
54+
NumericIndex(range(-10, 11), dtype=np.int32),
55+
NumericIndex(range(-10, 11), dtype=np.int16),
56+
NumericIndex(range(-10, 11), dtype=np.int8),
57+
NumericIndex(range(10, 31), dtype=np.uint64),
58+
NumericIndex(range(10, 31), dtype=np.uint32),
59+
NumericIndex(range(10, 31), dtype=np.uint16),
60+
NumericIndex(range(10, 31), dtype=np.uint8),
61+
NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64),
62+
NumericIndex(np.arange(20, 30, 0.5), dtype=np.float32),
63+
NumericIndex(np.arange(20, 30, 0.5), dtype=np.float16),
5664
date_range("20180101", periods=10),
5765
date_range("20180101", periods=10, tz="US/Eastern"),
5866
timedelta_range("1 day", periods=10),
@@ -81,10 +89,10 @@ def test_constructor(self, constructor, breaks, closed, name, use_dtype):
8189
@pytest.mark.parametrize(
8290
"breaks, subtype",
8391
[
84-
(Int64Index([0, 1, 2, 3, 4]), "float64"),
85-
(Int64Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
86-
(Int64Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
87-
(Float64Index([0, 1, 2, 3, 4]), "int64"),
92+
(Index([0, 1, 2, 3, 4]), "float64"),
93+
(Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
94+
(Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
95+
(Index([0, 1, 2, 3, 4], dtype=np.float64), "int64"),
8896
(date_range("2017-01-01", periods=5), "int64"),
8997
(timedelta_range("1 day", periods=5), "int64"),
9098
],
@@ -103,9 +111,18 @@ def test_constructor_dtype(self, constructor, breaks, subtype):
103111
@pytest.mark.parametrize(
104112
"breaks",
105113
[
106-
Int64Index([0, 1, 2, 3, 4]),
107-
UInt64Index([0, 1, 2, 3, 4]),
108-
Float64Index([0, 1, 2, 3, 4]),
114+
NumericIndex([0, 1, 2, 3, 4], dtype=np.int64),
115+
NumericIndex([0, 1, 2, 3, 4], dtype=np.int32),
116+
NumericIndex([0, 1, 2, 3, 4], dtype=np.int32),
117+
NumericIndex([0, 1, 2, 3, 4], dtype=np.int16),
118+
NumericIndex([0, 1, 2, 3, 4], dtype=np.int8),
119+
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint64),
120+
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint32),
121+
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint16),
122+
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint8),
123+
NumericIndex([0, 1, 2, 3, 4], dtype=np.float64),
124+
NumericIndex([0, 1, 2, 3, 4], dtype=np.float32),
125+
NumericIndex([0, 1, 2, 3, 4], dtype=np.float16),
109126
date_range("2017-01-01", periods=5),
110127
timedelta_range("1 day", periods=5),
111128
],
@@ -262,8 +279,8 @@ def test_mixed_float_int(self, left_subtype, right_subtype):
262279
right = np.arange(1, 10, dtype=right_subtype)
263280
result = IntervalIndex.from_arrays(left, right)
264281

265-
expected_left = Float64Index(left)
266-
expected_right = Float64Index(right)
282+
expected_left = Index(left, dtype=np.float64)
283+
expected_right = Index(right, dtype=np.float64)
267284
expected_subtype = np.float64
268285

269286
tm.assert_index_equal(result.left, expected_left)
@@ -313,8 +330,13 @@ class TestFromTuples(ConstructorTests):
313330
"""Tests specific to IntervalIndex.from_tuples"""
314331

315332
def _skip_test_constructor(self, dtype):
316-
if is_unsigned_integer_dtype(dtype):
317-
return True, "tuples don't have a dtype"
333+
msg = f"tuples don't have a dtype, so constructor won't see dtype {dtype}"
334+
if is_signed_integer_dtype(dtype) and dtype != "int64":
335+
return True, msg
336+
elif is_float_dtype(dtype) and dtype != "float64":
337+
return True, msg
338+
elif is_unsigned_integer_dtype(dtype):
339+
return True, msg
318340
else:
319341
return False, ""
320342

@@ -366,10 +388,13 @@ class TestClassConstructors(ConstructorTests):
366388
"""Tests specific to the IntervalIndex/Index constructors"""
367389

368390
def _skip_test_constructor(self, dtype):
369-
# get_kwargs_from_breaks in TestFromTuples and TestClassconstructors just return
370-
# tuples of ints, so IntervalIndex can't know the original dtype
371-
if is_unsigned_integer_dtype(dtype):
372-
return True, "tuples don't have a dtype"
391+
msg = f"tuples don't have a dtype, so constructor won't see dtype {dtype}"
392+
if is_signed_integer_dtype(dtype) and dtype != "int64":
393+
return True, msg
394+
elif is_float_dtype(dtype) and dtype != "float64":
395+
return True, msg
396+
elif is_unsigned_integer_dtype(dtype):
397+
return True, msg
373398
else:
374399
return False, ""
375400

0 commit comments

Comments
 (0)