Skip to content

Commit 1417f36

Browse files
authored
Remove cudf.Scalar from interval_range (#17844)
Towards #17843 Needed to change some `interval_range` test to account for pandas-dev/pandas#57268 Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Matthew Murray (https://github.com/Matt711) URL: #17844
1 parent d9b7a98 commit 1417f36

File tree

2 files changed

+49
-65
lines changed

2 files changed

+49
-65
lines changed

python/cudf/cudf/core/index.py

+26-17
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,14 @@
4848
from cudf.core.dtypes import IntervalDtype
4949
from cudf.core.join._join_helpers import _match_join_keys
5050
from cudf.core.mixins import BinaryOperand
51+
from cudf.core.scalar import pa_scalar_to_plc_scalar
5152
from cudf.core.single_column_frame import SingleColumnFrame
5253
from cudf.utils.docutils import copy_docstring
5354
from cudf.utils.dtypes import (
5455
SIZE_TYPE_DTYPE,
5556
_maybe_convert_to_default_type,
57+
cudf_dtype_from_pa_type,
58+
cudf_dtype_to_pa_type,
5659
find_common_type,
5760
is_mixed_with_object_dtype,
5861
)
@@ -3346,50 +3349,56 @@ def interval_range(
33463349
"freq, exactly three must be specified"
33473350
)
33483351

3349-
start = cudf.Scalar(start) if start is not None else start
3350-
end = cudf.Scalar(end) if end is not None else end
33513352
if periods is not None and not cudf.api.types.is_integer(periods):
33523353
warnings.warn(
33533354
"Non-integer 'periods' in cudf.date_range, and cudf.interval_range"
33543355
" are deprecated and will raise in a future version.",
33553356
FutureWarning,
33563357
)
3357-
periods = cudf.Scalar(int(periods)) if periods is not None else periods
3358-
freq = cudf.Scalar(freq) if freq is not None else freq
3359-
33603358
if start is None:
33613359
start = end - freq * periods
33623360
elif freq is None:
3363-
quotient, remainder = divmod((end - start).value, periods.value)
3361+
quotient, remainder = divmod(end - start, periods)
33643362
if remainder:
33653363
freq = (end - start) / periods
33663364
else:
3367-
freq = cudf.Scalar(int(quotient))
3365+
freq = int(quotient)
33683366
elif periods is None:
3369-
periods = cudf.Scalar(int((end - start) / freq))
3367+
periods = int((end - start) / freq)
33703368
elif end is None:
33713369
end = start + periods * freq
33723370

3371+
pa_start = pa.scalar(start)
3372+
pa_end = pa.scalar(end)
3373+
pa_freq = pa.scalar(freq)
3374+
33733375
if any(
3374-
not _is_non_decimal_numeric_dtype(x.dtype)
3375-
for x in (start, periods, freq, end)
3376+
not _is_non_decimal_numeric_dtype(cudf_dtype_from_pa_type(x.type))
3377+
for x in (pa_start, pa.scalar(periods), pa_freq, pa_end)
33763378
):
33773379
raise ValueError("start, end, periods, freq must be numeric values.")
33783380

3379-
periods = periods.astype("int64")
3380-
common_dtype = find_common_type((start.dtype, freq.dtype, end.dtype))
3381-
start = start.astype(common_dtype)
3382-
freq = freq.astype(common_dtype)
3381+
common_dtype = find_common_type(
3382+
(
3383+
cudf_dtype_from_pa_type(pa_start.type),
3384+
cudf_dtype_from_pa_type(pa_freq.type),
3385+
cudf_dtype_from_pa_type(pa_end.type),
3386+
)
3387+
)
3388+
pa_start = pa_start.cast(cudf_dtype_to_pa_type(common_dtype))
3389+
pa_freq = pa_freq.cast(cudf_dtype_to_pa_type(common_dtype))
33833390

33843391
with acquire_spill_lock():
33853392
bin_edges = libcudf.column.Column.from_pylibcudf(
33863393
plc.filling.sequence(
33873394
size=periods + 1,
3388-
init=start.device_value,
3389-
step=freq.device_value,
3395+
init=pa_scalar_to_plc_scalar(pa_start),
3396+
step=pa_scalar_to_plc_scalar(pa_freq),
33903397
)
33913398
)
3392-
return IntervalIndex.from_breaks(bin_edges, closed=closed, name=name)
3399+
return IntervalIndex.from_breaks(
3400+
bin_edges.astype(common_dtype), closed=closed, name=name
3401+
)
33933402

33943403

33953404
class IntervalIndex(Index):

python/cudf/cudf/tests/indexes/test_interval.py

+23-48
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
22
import numpy as np
33
import pandas as pd
44
import pyarrow as pa
@@ -30,7 +30,6 @@ def test_interval_to_arrow():
3030
np.int64,
3131
np.float32,
3232
np.float64,
33-
cudf.Scalar,
3433
]
3534

3635
PERIODS_TYPES = [
@@ -39,10 +38,23 @@ def test_interval_to_arrow():
3938
np.int16,
4039
np.int32,
4140
np.int64,
42-
cudf.Scalar,
4341
]
4442

4543

44+
def assert_with_pandas_2_bug(pindex, gindex):
45+
# pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268
46+
# using Series to use check_dtype
47+
if gindex.dtype.subtype.kind == "f":
48+
gindex = gindex.astype(
49+
cudf.IntervalDtype(subtype="float64", closed=gindex.dtype.closed)
50+
)
51+
elif gindex.dtype.subtype.kind == "i":
52+
gindex = gindex.astype(
53+
cudf.IntervalDtype(subtype="int64", closed=gindex.dtype.closed)
54+
)
55+
assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False)
56+
57+
4658
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
4759
@pytest.mark.parametrize("start", [0, 1, 2, 3])
4860
@pytest.mark.parametrize("end", [4, 5, 6, 7])
@@ -57,9 +69,7 @@ def test_interval_range_basic(start, end, closed):
5769
@pytest.mark.parametrize("end_t", INTERVAL_BOUNDARY_TYPES)
5870
def test_interval_range_dtype_basic(start_t, end_t):
5971
start, end = start_t(24), end_t(42)
60-
start_val = start.value if isinstance(start, cudf.Scalar) else start
61-
end_val = end.value if isinstance(end, cudf.Scalar) else end
62-
pindex = pd.interval_range(start=start_val, end=end_val, closed="left")
72+
pindex = pd.interval_range(start=start, end=end, closed="left")
6373
gindex = cudf.interval_range(start=start, end=end, closed="left")
6474

6575
assert_eq(pindex, gindex)
@@ -91,27 +101,11 @@ def test_interval_range_freq_basic(start, end, freq, closed):
91101
@pytest.mark.parametrize("freq_t", INTERVAL_BOUNDARY_TYPES)
92102
def test_interval_range_freq_basic_dtype(start_t, end_t, freq_t):
93103
start, end, freq = start_t(5), end_t(70), freq_t(3)
94-
start_val = start.value if isinstance(start, cudf.Scalar) else start
95-
end_val = end.value if isinstance(end, cudf.Scalar) else end
96-
freq_val = freq.value if isinstance(freq, cudf.Scalar) else freq
97-
pindex = pd.interval_range(
98-
start=start_val, end=end_val, freq=freq_val, closed="left"
99-
)
104+
pindex = pd.interval_range(start=start, end=end, freq=freq, closed="left")
100105
gindex = cudf.interval_range(
101106
start=start, end=end, freq=freq, closed="left"
102107
)
103-
if gindex.dtype.subtype.kind == "f":
104-
gindex = gindex.astype(
105-
cudf.IntervalDtype(subtype="float64", closed=gindex.dtype.closed)
106-
)
107-
elif gindex.dtype.subtype.kind == "i":
108-
gindex = gindex.astype(
109-
cudf.IntervalDtype(subtype="int64", closed=gindex.dtype.closed)
110-
)
111-
112-
# pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268
113-
# using Series to use check_dtype
114-
assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False)
108+
assert_with_pandas_2_bug(pindex, gindex)
115109

116110

117111
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
@@ -134,13 +128,8 @@ def test_interval_range_periods_basic(start, end, periods, closed):
134128
@pytest.mark.parametrize("periods_t", PERIODS_TYPES)
135129
def test_interval_range_periods_basic_dtype(start_t, end_t, periods_t):
136130
start, end, periods = start_t(0), end_t(4), periods_t(1)
137-
start_val = start.value if isinstance(start, cudf.Scalar) else start
138-
end_val = end.value if isinstance(end, cudf.Scalar) else end
139-
periods_val = (
140-
periods.value if isinstance(periods, cudf.Scalar) else periods
141-
)
142131
pindex = pd.interval_range(
143-
start=start_val, end=end_val, periods=periods_val, closed="left"
132+
start=start, end=end, periods=periods, closed="left"
144133
)
145134
gindex = cudf.interval_range(
146135
start=start, end=end, periods=periods, closed="left"
@@ -188,19 +177,13 @@ def test_interval_range_periods_freq_end(end, freq, periods, closed):
188177
@pytest.mark.parametrize("end_t", INTERVAL_BOUNDARY_TYPES)
189178
def test_interval_range_periods_freq_end_dtype(periods_t, freq_t, end_t):
190179
periods, freq, end = periods_t(2), freq_t(3), end_t(10)
191-
freq_val = freq.value if isinstance(freq, cudf.Scalar) else freq
192-
end_val = end.value if isinstance(end, cudf.Scalar) else end
193-
periods_val = (
194-
periods.value if isinstance(periods, cudf.Scalar) else periods
195-
)
196180
pindex = pd.interval_range(
197-
end=end_val, freq=freq_val, periods=periods_val, closed="left"
181+
end=end, freq=freq, periods=periods, closed="left"
198182
)
199183
gindex = cudf.interval_range(
200184
end=end, freq=freq, periods=periods, closed="left"
201185
)
202-
203-
assert_eq(pindex, gindex)
186+
assert_with_pandas_2_bug(pindex, gindex)
204187

205188

206189
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
@@ -223,21 +206,13 @@ def test_interval_range_periods_freq_start(start, freq, periods, closed):
223206
@pytest.mark.parametrize("start_t", INTERVAL_BOUNDARY_TYPES)
224207
def test_interval_range_periods_freq_start_dtype(periods_t, freq_t, start_t):
225208
periods, freq, start = periods_t(2), freq_t(3), start_t(9)
226-
freq_val = freq.value if isinstance(freq, cudf.Scalar) else freq
227-
start_val = start.value if isinstance(start, cudf.Scalar) else start
228-
periods_val = (
229-
periods.value if isinstance(periods, cudf.Scalar) else periods
230-
)
231209
pindex = pd.interval_range(
232-
start=start_val, freq=freq_val, periods=periods_val, closed="left"
210+
start=start, freq=freq, periods=periods, closed="left"
233211
)
234212
gindex = cudf.interval_range(
235213
start=start, freq=freq, periods=periods, closed="left"
236214
)
237-
238-
# pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268
239-
# using Series to use check_dtype
240-
assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False)
215+
assert_with_pandas_2_bug(pindex, gindex)
241216

242217

243218
@pytest.mark.parametrize("closed", ["right", "left", "both", "neither"])

0 commit comments

Comments
 (0)