Skip to content

Commit 0dcfc51

Browse files
VISWESWARAN1998mroeschke
authored andcommitted
Fix for issue pandas-dev#57268 - ENH: Preserve input start/end type in interval… (pandas-dev#57399)
* Fix for issue pandas-dev#57268 - ENH: Preserve input start/end type in interval_range * issue pandas-dev#57268 - github actions resolution * Use generated datatype from breaks * Ruff - Pre-commit issue fix * Fix for issue pandas-dev#57268 - floating point support * int - float dtype compatability * whatsnew documentation update * OS based varaible access * Fixing failed unit test cases * pytest - interval passsed * Python backwards compatability * Pytest * Fixing PyLint and mypy issues * dtype specification * Conditional statement simplification * remove redundant code blocks * Changing whatsnew to interval section * Passing expected in parameterize * Update doc/source/whatsnew/v3.0.0.rst Co-authored-by: Matthew Roeschke <[email protected]> * Update pandas/core/indexes/interval.py Co-authored-by: Matthew Roeschke <[email protected]> --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 2fa5f22 commit 0dcfc51

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

doc/source/whatsnew/v3.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ Strings
211211

212212
Interval
213213
^^^^^^^^
214-
-
214+
- Bug in :func:`interval_range` where start and end numeric types were always cast to 64 bit (:issue:`57268`)
215215
-
216216

217217
Indexing

pandas/core/indexes/interval.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1101,9 +1101,23 @@ def interval_range(
11011101
breaks: np.ndarray | TimedeltaIndex | DatetimeIndex
11021102

11031103
if is_number(endpoint):
1104+
dtype: np.dtype = np.dtype("int64")
11041105
if com.all_not_none(start, end, freq):
1106+
if (
1107+
isinstance(start, (float, np.float16))
1108+
or isinstance(end, (float, np.float16))
1109+
or isinstance(freq, (float, np.float16))
1110+
):
1111+
dtype = np.dtype("float64")
1112+
elif (
1113+
isinstance(start, (np.integer, np.floating))
1114+
and isinstance(end, (np.integer, np.floating))
1115+
and start.dtype == end.dtype
1116+
):
1117+
dtype = start.dtype
11051118
# 0.1 ensures we capture end
11061119
breaks = np.arange(start, end + (freq * 0.1), freq)
1120+
breaks = maybe_downcast_numeric(breaks, dtype)
11071121
else:
11081122
# compute the period/start/end if unspecified (at most one)
11091123
if periods is None:
@@ -1122,7 +1136,7 @@ def interval_range(
11221136
# expected "ndarray[Any, Any]" [
11231137
breaks = maybe_downcast_numeric(
11241138
breaks, # type: ignore[arg-type]
1125-
np.dtype("int64"),
1139+
dtype,
11261140
)
11271141
else:
11281142
# delegate to the appropriate range function
@@ -1131,4 +1145,9 @@ def interval_range(
11311145
else:
11321146
breaks = timedelta_range(start=start, end=end, periods=periods, freq=freq)
11331147

1134-
return IntervalIndex.from_breaks(breaks, name=name, closed=closed)
1148+
return IntervalIndex.from_breaks(
1149+
breaks,
1150+
name=name,
1151+
closed=closed,
1152+
dtype=IntervalDtype(subtype=breaks.dtype, closed=closed),
1153+
)

pandas/tests/indexes/interval/test_interval_range.py

+14
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,20 @@ def test_float_subtype(self, start, end, freq):
220220
expected = "int64" if is_integer(start + end) else "float64"
221221
assert result == expected
222222

223+
@pytest.mark.parametrize(
224+
"start, end, expected",
225+
[
226+
(np.int8(1), np.int8(10), np.dtype("int8")),
227+
(np.int8(1), np.float16(10), np.dtype("float64")),
228+
(np.float32(1), np.float32(10), np.dtype("float32")),
229+
(1, 10, np.dtype("int64")),
230+
(1, 10.0, np.dtype("float64")),
231+
],
232+
)
233+
def test_interval_dtype(self, start, end, expected):
234+
result = interval_range(start=start, end=end).dtype.subtype
235+
assert result == expected
236+
223237
def test_interval_range_fractional_period(self):
224238
# float value for periods
225239
expected = interval_range(start=0, periods=10)

0 commit comments

Comments
 (0)