diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 2672e964736f0..f47ab83b6bde1 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1787,5 +1787,7 @@ def _maybe_convert_platform_interval(values) -> ArrayLike: values = extract_array(values, extract_numpy=True) if not hasattr(values, "dtype"): - return np.asarray(values) + values = np.asarray(values) + if is_integer_dtype(values) and values.dtype != np.int64: + values = values.astype(np.int64) return values diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index f3ce104aa4a3e..b3f2426256ccf 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -35,6 +35,7 @@ ArrayLike, Dtype, DtypeObj, + NumpyIndexT, Scalar, npt, ) @@ -65,6 +66,7 @@ is_numeric_dtype, is_object_dtype, is_scalar, + is_signed_integer_dtype, is_string_dtype, is_timedelta64_dtype, is_unsigned_integer_dtype, @@ -412,6 +414,29 @@ def trans(x): return result +def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: + """ + If array is an int/uint/float dtype narrower than 64 bits, upcast it to 64 bits.
+ + Parameters + ---------- + arr : ndarray or Index + + Returns + ------- + ndarray or Index + """ + dtype = arr.dtype + if is_signed_integer_dtype(dtype) and dtype != np.int64: + return arr.astype(np.int64) + elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: + return arr.astype(np.uint64) + elif is_float_dtype(dtype) and dtype != np.float64: + return arr.astype(np.float64) + else: + return arr + + def maybe_cast_pointwise_result( result: ArrayLike, dtype: DtypeObj, diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ae0b7dc5116cd..c6bd7b8aae980 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -44,6 +44,7 @@ infer_dtype_from_scalar, maybe_box_datetimelike, maybe_downcast_numeric, + maybe_upcast_numeric_to_64bit, ) from pandas.core.dtypes.common import ( ensure_platform_int, @@ -59,8 +60,6 @@ is_number, is_object_dtype, is_scalar, - is_signed_integer_dtype, - is_unsigned_integer_dtype, ) from pandas.core.dtypes.dtypes import IntervalDtype from pandas.core.dtypes.missing import is_valid_na_for_dtype @@ -342,8 +341,11 @@ def from_tuples( # "Union[IndexEngine, ExtensionEngine]" in supertype "Index" @cache_readonly def _engine(self) -> IntervalTree: # type: ignore[override] + # IntervalTree does not support numpy arrays unless they are 64 bit left = self._maybe_convert_i8(self.left) + left = maybe_upcast_numeric_to_64bit(left) right = self._maybe_convert_i8(self.right) + right = maybe_upcast_numeric_to_64bit(right) return IntervalTree(left, right, closed=self.closed) def __contains__(self, key: Any) -> bool: @@ -520,13 +522,12 @@ def _maybe_convert_i8(self, key): The original key if no conversion occurred, int if converted scalar, Int64Index if converted list-like. 
""" - original = key if is_list_like(key): key = ensure_index(key) - key = self._maybe_convert_numeric_to_64bit(key) + key = maybe_upcast_numeric_to_64bit(key) if not self._needs_i8_conversion(key): - return original + return key scalar = is_scalar(key) if is_interval_dtype(key) or isinstance(key, Interval): @@ -569,20 +570,6 @@ def _maybe_convert_i8(self, key): return key_i8 - def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index: - # IntervalTree only supports 64 bit numpy array - dtype = idx.dtype - if np.issubclass_(dtype.type, np.number): - return idx - elif is_signed_integer_dtype(dtype) and dtype != np.int64: - return idx.astype(np.int64) - elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: - return idx.astype(np.uint64) - elif is_float_dtype(dtype) and dtype != np.float64: - return idx.astype(np.float64) - else: - return idx - def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"): if not self.is_non_overlapping_monotonic: raise KeyError( diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 98c21fad1f8c2..aded8a9b59ab8 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -415,27 +415,36 @@ def test_maybe_convert_i8_nat(self, breaks): tm.assert_index_equal(result, expected) @pytest.mark.parametrize( - "breaks", - [np.arange(5, dtype="int64"), np.arange(5, dtype="float64")], - ids=lambda x: str(x.dtype), + "make_key", + [lambda breaks: breaks, list], + ids=["lambda", "list"], ) + def test_maybe_convert_i8_numeric(self, make_key, any_real_numpy_dtype): + # GH 20636 + breaks = np.arange(5, dtype=any_real_numpy_dtype) + index = IntervalIndex.from_breaks(breaks) + key = make_key(breaks) + + result = index._maybe_convert_i8(key) + expected = Index(key) + tm.assert_index_equal(result, expected) + @pytest.mark.parametrize( "make_key", [ IntervalIndex.from_breaks, lambda breaks: Interval(breaks[0], 
breaks[1]), - lambda breaks: breaks, lambda breaks: breaks[0], - list, ], - ids=["IntervalIndex", "Interval", "Index", "scalar", "list"], + ids=["IntervalIndex", "Interval", "scalar"], ) - def test_maybe_convert_i8_numeric(self, breaks, make_key): + def test_maybe_convert_i8_numeric_identical(self, make_key, any_real_numpy_dtype): # GH 20636 + breaks = np.arange(5, dtype=any_real_numpy_dtype) index = IntervalIndex.from_breaks(breaks) key = make_key(breaks) - # no conversion occurs for numeric + # _maybe_convert_i8 must return Interval/IntervalIndex/scalar keys unchanged result = index._maybe_convert_i8(key) assert result is key