Skip to content

API: ensure IntervalIndex.left/right are 64bit if numeric, part II #50195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,5 +1787,7 @@ def _maybe_convert_platform_interval(values) -> ArrayLike:
values = extract_array(values, extract_numpy=True)

if not hasattr(values, "dtype"):
return np.asarray(values)
values = np.asarray(values)
if is_integer_dtype(values) and values.dtype != np.int64:
values = values.astype(np.int64)
return values
25 changes: 25 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ArrayLike,
Dtype,
DtypeObj,
NumpyIndexT,
Scalar,
npt,
)
Expand Down Expand Up @@ -65,6 +66,7 @@
is_numeric_dtype,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
is_string_dtype,
is_timedelta64_dtype,
is_unsigned_integer_dtype,
Expand Down Expand Up @@ -412,6 +414,29 @@ def trans(x):
return result


def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
    """
    Cast an int/uint/float array to the 64-bit dtype of the same kind.

    Parameters
    ----------
    arr : ndarray or ExtensionArray
        Array whose ``dtype`` decides whether a cast is performed.

    Returns
    -------
    ndarray or ExtensionArray
        ``arr.astype(...)`` to ``int64``, ``uint64`` or ``float64`` when
        ``arr`` has a signed-integer, unsigned-integer or float dtype that
        differs from that 64-bit dtype. In every other case ``arr`` itself
        is returned, without a copy.
    """
    current = arr.dtype

    # Choose the 64-bit counterpart for the kind of number held by ``arr``;
    # dtypes outside these three families are handed back untouched.
    if is_signed_integer_dtype(current):
        wide = np.int64
    elif is_unsigned_integer_dtype(current):
        wide = np.uint64
    elif is_float_dtype(current):
        wide = np.float64
    else:
        return arr

    # Already at the 64-bit dtype: return the very same object.
    return arr.astype(wide) if current != wide else arr


def maybe_cast_pointwise_result(
result: ArrayLike,
dtype: DtypeObj,
Expand Down
25 changes: 6 additions & 19 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
infer_dtype_from_scalar,
maybe_box_datetimelike,
maybe_downcast_numeric,
maybe_upcast_numeric_to_64bit,
)
from pandas.core.dtypes.common import (
ensure_platform_int,
Expand All @@ -59,8 +60,6 @@
is_number,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
is_unsigned_integer_dtype,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.missing import is_valid_na_for_dtype
Expand Down Expand Up @@ -342,8 +341,11 @@ def from_tuples(
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
@cache_readonly
def _engine(self) -> IntervalTree:  # type: ignore[override]
    """
    Build the IntervalTree for this index from its left and right endpoints.

    Each endpoint array is first passed through ``_maybe_convert_i8`` and
    then through ``maybe_upcast_numeric_to_64bit`` before being handed to
    ``IntervalTree`` together with ``self.closed``.
    """
    # IntervalTree does not support numpy arrays unless they are 64 bit,
    # so every side is upcast right after its i8 conversion.
    left = maybe_upcast_numeric_to_64bit(self._maybe_convert_i8(self.left))
    right = maybe_upcast_numeric_to_64bit(self._maybe_convert_i8(self.right))
    return IntervalTree(left, right, closed=self.closed)

def __contains__(self, key: Any) -> bool:
Expand Down Expand Up @@ -520,13 +522,12 @@ def _maybe_convert_i8(self, key):
The original key if no conversion occurred, int if converted scalar,
Int64Index if converted list-like.
"""
original = key
if is_list_like(key):
key = ensure_index(key)
key = self._maybe_convert_numeric_to_64bit(key)
key = maybe_upcast_numeric_to_64bit(key)

if not self._needs_i8_conversion(key):
return original
return key

scalar = is_scalar(key)
if is_interval_dtype(key) or isinstance(key, Interval):
Expand Down Expand Up @@ -569,20 +570,6 @@ def _maybe_convert_i8(self, key):

return key_i8

def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
    """
    Cast a numeric index to the 64-bit dtype of its kind.

    Parameters
    ----------
    idx : Index
        Index whose ``dtype`` decides whether a cast is performed.

    Returns
    -------
    Index
        ``idx.astype(...)`` to ``int64``, ``uint64`` or ``float64`` when
        ``idx`` has a signed-integer, unsigned-integer or float dtype that
        differs from that 64-bit dtype; otherwise ``idx`` itself.

    Notes
    -----
    An earlier version returned ``idx`` immediately whenever
    ``dtype.type`` was a ``np.number`` subclass. That test is true for
    every dtype the branches below handle, so the casts could never run
    and narrow arrays reached IntervalTree unchanged. The guard also
    relied on ``np.issubclass_``, which no longer exists in NumPy 2.0.
    """
    # IntervalTree only supports 64 bit numpy array
    dtype = idx.dtype
    if is_signed_integer_dtype(dtype) and dtype != np.int64:
        return idx.astype(np.int64)
    elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
        return idx.astype(np.uint64)
    elif is_float_dtype(dtype) and dtype != np.float64:
        return idx.astype(np.float64)
    else:
        # Already 64 bit, or not an int/uint/float dtype: nothing to do.
        return idx

def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
if not self.is_non_overlapping_monotonic:
raise KeyError(
Expand Down
25 changes: 17 additions & 8 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,27 +415,36 @@ def test_maybe_convert_i8_nat(self, breaks):
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize(
"breaks",
[np.arange(5, dtype="int64"), np.arange(5, dtype="float64")],
ids=lambda x: str(x.dtype),
"make_key",
[lambda breaks: breaks, list],
ids=["lambda", "list"],
)
def test_maybe_convert_i8_numeric(self, make_key, any_real_numpy_dtype):
# Check _maybe_convert_i8 on a key built from numeric breaks: the result
# must compare equal (via tm.assert_index_equal) to Index(key).
# NOTE(review): any_real_numpy_dtype is presumably a pytest fixture that
# supplies the dtype under test; make_key comes from the parametrize
# decorator above -- confirm both against the full test module.
# GH 20636
breaks = np.arange(5, dtype=any_real_numpy_dtype)
index = IntervalIndex.from_breaks(breaks)
# Turn the raw breaks into the key that is handed to the index.
key = make_key(breaks)

result = index._maybe_convert_i8(key)
expected = Index(key)
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize(
"make_key",
[
IntervalIndex.from_breaks,
lambda breaks: Interval(breaks[0], breaks[1]),
lambda breaks: breaks,
lambda breaks: breaks[0],
list,
],
ids=["IntervalIndex", "Interval", "Index", "scalar", "list"],
ids=["IntervalIndex", "Interval", "scalar"],
)
def test_maybe_convert_i8_numeric(self, breaks, make_key):
def test_maybe_convert_i8_numeric_identical(self, make_key, any_real_numpy_dtype):
# Check that _maybe_convert_i8 returns the very object it was given
# (identity, not just equality) for a key built from numeric breaks.
# NOTE(review): any_real_numpy_dtype is presumably a pytest fixture that
# supplies the dtype under test; make_key comes from the parametrize
# decorator above -- confirm both against the full test module.
# GH 20636
breaks = np.arange(5, dtype=any_real_numpy_dtype)
index = IntervalIndex.from_breaks(breaks)
key = make_key(breaks)

# No conversion occurs for numeric data: _maybe_convert_i8 must hand back
# the key unchanged when it is an Interval or IntervalIndex (the
# parametrize ids above also list a scalar case).
result = index._maybe_convert_i8(key)
assert result is key

Expand Down