From 88923adfe6e78536e477d8019da7e4620ad7c954 Mon Sep 17 00:00:00 2001 From: tp Date: Sun, 18 Dec 2022 20:48:03 +0000 Subject: [PATCH 1/7] API:move use of maybe_convert_numeric_to_64bit to to also be used in IntervalIndex._engine --- pandas/core/indexes/interval.py | 41 +++++++++++-------- .../tests/indexes/interval/test_interval.py | 12 ++++-- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ae0b7dc5116cd..d7e39a2cf6bf0 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -31,6 +31,7 @@ DtypeObj, IntervalClosedType, npt, + NumpyIndexT, ) from pandas.errors import InvalidIndexError from pandas.util._decorators import ( @@ -47,6 +48,7 @@ ) from pandas.core.dtypes.common import ( ensure_platform_int, + is_array_like, is_datetime64tz_dtype, is_datetime_or_timedelta_dtype, is_dtype_equal, @@ -146,6 +148,24 @@ def _new_IntervalIndex(cls, d): return cls.from_arrays(**d) +def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: + # IntervalTree only supports 64 bit numpy array + + if not is_array_like(arr): + return arr + dtype = arr.dtype + if not np.issubclass_(dtype.type, np.number): + return arr + elif is_signed_integer_dtype(dtype) and dtype != np.int64: + return arr.astype(np.int64) + elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: + return arr.astype(np.uint64) + elif is_float_dtype(dtype) and dtype != np.float64: + return arr.astype(np.float64) + else: + return arr + + @Appender( _interval_shared_docs["class"] % { @@ -343,7 +363,9 @@ def from_tuples( @cache_readonly def _engine(self) -> IntervalTree: # type: ignore[override] left = self._maybe_convert_i8(self.left) + left = maybe_convert_numeric_to_64bit(left) right = self._maybe_convert_i8(self.right) + right = maybe_convert_numeric_to_64bit(right) return IntervalTree(left, right, closed=self.closed) def __contains__(self, key: Any) -> bool: @@ -520,13 +542,12 @@ def _maybe_convert_i8(self, key): The original key if no conversion occurred, int if converted scalar, Int64Index if converted list-like. """ - original = key if is_list_like(key): key = ensure_index(key) - key = self._maybe_convert_numeric_to_64bit(key) + key = maybe_convert_numeric_to_64bit(key) if not self._needs_i8_conversion(key): - return original + return key scalar = is_scalar(key) if is_interval_dtype(key) or isinstance(key, Interval): @@ -569,20 +590,6 @@ def _maybe_convert_i8(self, key): return key_i8 - def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index: - # IntervalTree only supports 64 bit numpy array - dtype = idx.dtype - if np.issubclass_(dtype.type, np.number): - return idx - elif is_signed_integer_dtype(dtype) and dtype != np.int64: - return idx.astype(np.int64) - elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: - return idx.astype(np.uint64) - elif is_float_dtype(dtype) and dtype != np.float64: - return idx.astype(np.float64) - else: - return idx - def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"): if not self.is_non_overlapping_monotonic: raise KeyError( diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 98c21fad1f8c2..1c82fe1b788ab 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -18,7 +18,10 @@ timedelta_range, ) import pandas._testing as tm -from pandas.core.api import Float64Index +from pandas.core.api import ( + Float64Index, + NumericIndex, +) import pandas.core.common as com @@ -435,9 +438,12 @@ def test_maybe_convert_i8_numeric(self, breaks, make_key): index = IntervalIndex.from_breaks(breaks) key = make_key(breaks) - # no conversion occurs for numeric result = index._maybe_convert_i8(key) - assert result is key + if not isinstance(result, NumericIndex): + assert result is key + else: + expected = NumericIndex(key) + tm.assert_index_equal(result, expected) @pytest.mark.parametrize( "breaks1, breaks2", From 471362b7e775edb45228152dfa165892678742e3 Mon Sep 17 00:00:00 2001 From: tp Date: Mon, 19 Dec 2022 08:47:31 +0000 Subject: [PATCH 2/7] move maybe_upcast_numeric_to_64bit to core.dtypes.cast --- pandas/core/arrays/interval.py | 8 ++++++-- pandas/core/dtypes/cast.py | 31 +++++++++++++++++++++++++++++++ pandas/core/indexes/interval.py | 30 +++++------------------------- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 2672e964736f0..805aff64420e6 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -46,7 +46,10 @@ from pandas.errors import IntCastingNaNError from pandas.util._decorators import Appender -from pandas.core.dtypes.cast import LossySetitemError +from pandas.core.dtypes.cast import ( + LossySetitemError, + maybe_upcast_numeric_to_64bit, +) from pandas.core.dtypes.common import ( is_categorical_dtype, is_dtype_equal, @@ -1787,5 +1790,6 @@ def _maybe_convert_platform_interval(values) -> ArrayLike: values = extract_array(values, extract_numpy=True) if not hasattr(values, "dtype"): - return np.asarray(values) + values = np.asarray(values) + values = maybe_upcast_numeric_to_64bit(values) return values diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index f3ce104aa4a3e..e7b31beaef779 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -35,6 +35,7 @@ ArrayLike, Dtype, DtypeObj, + NumpyIndexT, Scalar, npt, ) @@ -52,6 +53,7 @@ ensure_int64, ensure_object, ensure_str, + is_array_like, is_bool, is_bool_dtype, is_complex, @@ -65,6 +67,7 @@ is_numeric_dtype, is_object_dtype, is_scalar, + is_signed_integer_dtype, is_string_dtype, is_timedelta64_dtype, is_unsigned_integer_dtype, @@ -412,6 +415,34 @@ def trans(x): return result +def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: + """ + If array is a int/uint/float bit size lower than 64 bit, upcast it to 64 bit. + + Parameters + ---------- + arr : ndarray or ExtensionArray + + Returns + ------- + ndarray or ExtensionArray + """ + + if not is_array_like(arr): + return arr + dtype = arr.dtype + if not np.issubclass_(dtype.type, np.number): + return arr + elif is_signed_integer_dtype(dtype) and dtype != np.int64: + return arr.astype(np.int64) + elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: + return arr.astype(np.uint64) + elif is_float_dtype(dtype) and dtype != np.float64: + return arr.astype(np.float64) + else: + return arr + + def maybe_cast_pointwise_result( result: ArrayLike, dtype: DtypeObj, diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index d7e39a2cf6bf0..c6bd7b8aae980 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -31,7 +31,6 @@ DtypeObj, IntervalClosedType, npt, - NumpyIndexT, ) from pandas.errors import InvalidIndexError from pandas.util._decorators import ( @@ -45,10 +44,10 @@ infer_dtype_from_scalar, maybe_box_datetimelike, maybe_downcast_numeric, + maybe_upcast_numeric_to_64bit, ) from pandas.core.dtypes.common import ( ensure_platform_int, - is_array_like, is_datetime64tz_dtype, is_datetime_or_timedelta_dtype, is_dtype_equal, @@ -61,8 +60,6 @@ is_number, is_object_dtype, is_scalar, - is_signed_integer_dtype, - is_unsigned_integer_dtype, ) from pandas.core.dtypes.dtypes import IntervalDtype from pandas.core.dtypes.missing import is_valid_na_for_dtype @@ -148,24 +145,6 @@ def _new_IntervalIndex(cls, d): return cls.from_arrays(**d) -def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: - # IntervalTree only supports 64 bit numpy array - - if not is_array_like(arr): - return arr - dtype = arr.dtype - if not np.issubclass_(dtype.type, np.number): - return arr - elif is_signed_integer_dtype(dtype) and dtype != np.int64: - return arr.astype(np.int64) - elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: - return arr.astype(np.uint64) - elif is_float_dtype(dtype) and dtype != np.float64: - return arr.astype(np.float64) - else: - return arr - - @Appender( _interval_shared_docs["class"] % { @@ -362,10 +341,11 @@ def from_tuples( # "Union[IndexEngine, ExtensionEngine]" in supertype "Index" @cache_readonly def _engine(self) -> IntervalTree: # type: ignore[override] + # IntervalTree does not supports numpy array unless they are 64 bit left = self._maybe_convert_i8(self.left) - left = maybe_convert_numeric_to_64bit(left) + left = maybe_upcast_numeric_to_64bit(left) right = self._maybe_convert_i8(self.right) - right = maybe_convert_numeric_to_64bit(right) + right = maybe_upcast_numeric_to_64bit(right) return IntervalTree(left, right, closed=self.closed) def __contains__(self, key: Any) -> bool: @@ -544,7 +524,7 @@ def _maybe_convert_i8(self, key): """ if is_list_like(key): key = ensure_index(key) - key = maybe_convert_numeric_to_64bit(key) + key = maybe_upcast_numeric_to_64bit(key) if not self._needs_i8_conversion(key): return key From 8cc670a4261d829b6578592b977ab7c905d0abcf Mon Sep 17 00:00:00 2001 From: tp Date: Mon, 19 Dec 2022 09:34:15 +0000 Subject: [PATCH 3/7] update --- pandas/core/arrays/interval.py | 8 +++----- pandas/core/dtypes/cast.py | 4 ---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 805aff64420e6..f47ab83b6bde1 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -46,10 +46,7 @@ from pandas.errors import IntCastingNaNError from pandas.util._decorators import Appender -from pandas.core.dtypes.cast import ( - LossySetitemError, - maybe_upcast_numeric_to_64bit, -) +from pandas.core.dtypes.cast import LossySetitemError from pandas.core.dtypes.common import ( is_categorical_dtype, is_dtype_equal, @@ -1791,5 +1788,6 @@ def _maybe_convert_platform_interval(values) -> ArrayLike: if not hasattr(values, "dtype"): values = np.asarray(values) - values = maybe_upcast_numeric_to_64bit(values) + if is_integer_dtype(values) and values.dtype != np.int64: + values = values.astype(np.int64) return values diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index e7b31beaef779..7c0c05957f2ee 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -53,7 +53,6 @@ ensure_int64, ensure_object, ensure_str, - is_array_like, is_bool, is_bool_dtype, is_complex, @@ -427,9 +426,6 @@ def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: ------- ndarray or ExtensionArray """ - - if not is_array_like(arr): - return arr dtype = arr.dtype if not np.issubclass_(dtype.type, np.number): return arr From cc4859a813de3c53a375d5c28d9d9d24ca53527e Mon Sep 17 00:00:00 2001 From: tp Date: Mon, 19 Dec 2022 23:08:36 +0000 Subject: [PATCH 4/7] fix test_maybe_convert_i8_numeric --- .../tests/indexes/interval/test_interval.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 1c82fe1b788ab..f6f54d5577031 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -18,10 +18,7 @@ timedelta_range, ) import pandas._testing as tm -from pandas.core.api import ( - Float64Index, - NumericIndex, -) +from pandas.core.api import Float64Index import pandas.core.common as com @@ -417,33 +414,33 @@ def test_maybe_convert_i8_nat(self, breaks): result = index._maybe_convert_i8(to_convert) tm.assert_index_equal(result, expected) - @pytest.mark.parametrize( - "breaks", - [np.arange(5, dtype="int64"), np.arange(5, dtype="float64")], - ids=lambda x: str(x.dtype), - ) + def test_maybe_convert_i8_numeric(self, any_real_numpy_dtype): + # GH 20636 + breaks = np.arange(5, dtype=any_real_numpy_dtype) + index = IntervalIndex.from_breaks(breaks) + + result = index._maybe_convert_i8(breaks) + expected = Index(breaks) + tm.assert_index_equal(result, expected) + @pytest.mark.parametrize( "make_key", [ IntervalIndex.from_breaks, lambda breaks: Interval(breaks[0], breaks[1]), - lambda breaks: breaks, lambda breaks: breaks[0], - list, ], - ids=["IntervalIndex", "Interval", "Index", "scalar", "list"], + ids=["IntervalIndex", "Interval", "scalar"], ) - def test_maybe_convert_i8_numeric(self, breaks, make_key): + def test_maybe_convert_i8_numeric_identical(self, make_key, any_real_numpy_dtype): # GH 20636 + breaks = np.arange(5, dtype=any_real_numpy_dtype) index = IntervalIndex.from_breaks(breaks) key = make_key(breaks) + # test if _maybe_convert_i8 won't change key if an Interval or IntervalIndex result = index._maybe_convert_i8(key) - if not isinstance(result, NumericIndex): - assert result is key - else: - expected = NumericIndex(key) - tm.assert_index_equal(result, expected) + assert result is key @pytest.mark.parametrize( "breaks1, breaks2", From 2692932cd996f7aaac446899aefb877446a0bcd2 Mon Sep 17 00:00:00 2001 From: tp Date: Mon, 19 Dec 2022 23:20:44 +0000 Subject: [PATCH 5/7] fix test_maybe_convert_i8_numeric II --- pandas/tests/indexes/interval/test_interval.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index f6f54d5577031..776f465c6d1b4 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -414,13 +414,18 @@ def test_maybe_convert_i8_nat(self, breaks): result = index._maybe_convert_i8(to_convert) tm.assert_index_equal(result, expected) - def test_maybe_convert_i8_numeric(self, any_real_numpy_dtype): + @pytest.mark.parametrize( + "make_key", [lambda breaks: breaks, list], + ids=["lambda", "list"], + ) + def test_maybe_convert_i8_numeric(self, make_key, any_real_numpy_dtype): # GH 20636 breaks = np.arange(5, dtype=any_real_numpy_dtype) index = IntervalIndex.from_breaks(breaks) + key = make_key(breaks) - result = index._maybe_convert_i8(breaks) - expected = Index(breaks) + result = index._maybe_convert_i8(key) + expected = Index(key) tm.assert_index_equal(result, expected) @pytest.mark.parametrize( From d26cb7a49f2f11c88fbf85e3b16c6e4c68f8680e Mon Sep 17 00:00:00 2001 From: tp Date: Mon, 19 Dec 2022 23:35:01 +0000 Subject: [PATCH 6/7] fix precommit --- pandas/tests/indexes/interval/test_interval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 776f465c6d1b4..aded8a9b59ab8 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -415,7 +415,8 @@ def test_maybe_convert_i8_nat(self, breaks): tm.assert_index_equal(result, expected) @pytest.mark.parametrize( - "make_key", [lambda breaks: breaks, list], + "make_key", + [lambda breaks: breaks, list], ids=["lambda", "list"], ) def test_maybe_convert_i8_numeric(self, make_key, any_real_numpy_dtype): From 880d51f3f9cfd3643c873ff76fc0c98b659a2753 Mon Sep 17 00:00:00 2001 From: tp Date: Tue, 10 Jan 2023 00:51:45 +0000 Subject: [PATCH 7/7] don't short-circuit --- pandas/core/dtypes/cast.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 7c0c05957f2ee..b3f2426256ccf 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -427,9 +427,7 @@ def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: ndarray or ExtensionArray """ dtype = arr.dtype - if not np.issubclass_(dtype.type, np.number): - return arr - elif is_signed_integer_dtype(dtype) and dtype != np.int64: + if is_signed_integer_dtype(dtype) and dtype != np.int64: return arr.astype(np.int64) elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: return arr.astype(np.uint64)