Skip to content

Commit cd357c6

Browse files
jschendeljreback
authored andcommitted
BLD: Fix IntervalTree build warnings (#30560)
1 parent 7ecd9af commit cd357c6

File tree

2 files changed

+59
-8
lines changed

2 files changed

+59
-8
lines changed

pandas/_libs/intervaltree.pxi.in

+33-8
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
66

77
from pandas._libs.algos import is_monotonic
88

9-
ctypedef fused scalar_t:
10-
float64_t
11-
float32_t
9+
ctypedef fused int_scalar_t:
1210
int64_t
1311
int32_t
12+
float64_t
13+
float32_t
14+
15+
ctypedef fused uint_scalar_t:
1416
uint64_t
17+
float64_t
18+
float32_t
19+
20+
ctypedef fused scalar_t:
21+
int_scalar_t
22+
uint_scalar_t
1523

1624
# ----------------------------------------------------------------------
1725
# IntervalTree
@@ -128,7 +136,12 @@ cdef class IntervalTree(IntervalMixin):
128136
result = Int64Vector()
129137
old_len = 0
130138
for i in range(len(target)):
131-
self.root.query(result, target[i])
139+
try:
140+
self.root.query(result, target[i])
141+
except OverflowError:
142+
# overflow -> no match, which is already handled below
143+
pass
144+
132145
if result.data.n == old_len:
133146
result.append(-1)
134147
elif result.data.n > old_len + 1:
@@ -150,7 +163,12 @@ cdef class IntervalTree(IntervalMixin):
150163
missing = Int64Vector()
151164
old_len = 0
152165
for i in range(len(target)):
153-
self.root.query(result, target[i])
166+
try:
167+
self.root.query(result, target[i])
168+
except OverflowError:
169+
# overflow -> no match, which is already handled below
170+
pass
171+
154172
if result.data.n == old_len:
155173
result.append(-1)
156174
missing.append(i)
@@ -202,19 +220,26 @@ for dtype in ['float32', 'float64', 'int32', 'int64', 'uint64']:
202220
('neither', '<', '<')]:
203221
cmp_left_converse = '<' if cmp_left == '<=' else '<='
204222
cmp_right_converse = '<' if cmp_right == '<=' else '<='
223+
if dtype.startswith('int'):
224+
fused_prefix = 'int_'
225+
elif dtype.startswith('uint'):
226+
fused_prefix = 'uint_'
227+
elif dtype.startswith('float'):
228+
fused_prefix = ''
205229
nodes.append((dtype, dtype.title(),
206230
closed, closed.title(),
207231
cmp_left,
208232
cmp_right,
209233
cmp_left_converse,
210-
cmp_right_converse))
234+
cmp_right_converse,
235+
fused_prefix))
211236

212237
}}
213238

214239
NODE_CLASSES = {}
215240

216241
{{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
217-
cmp_left_converse, cmp_right_converse in nodes}}
242+
cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}
218243

219244
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
220245
"""Non-terminal node for an IntervalTree
@@ -317,7 +342,7 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
317342
@cython.wraparound(False)
318343
@cython.boundscheck(False)
319344
@cython.initializedcheck(False)
320-
cpdef query(self, Int64Vector result, scalar_t point):
345+
cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
321346
"""Recursively query this node and its sub-nodes for intervals that
322347
overlap with the query point.
323348
"""

pandas/tests/indexes/interval/test_interval_tree.py

+26
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ def test_get_indexer(self, tree):
6363
):
6464
tree.get_indexer(np.array([3.0]))
6565

66+
@pytest.mark.parametrize(
67+
"dtype, target_value", [("int64", 2 ** 63 + 1), ("uint64", -1)]
68+
)
69+
def test_get_indexer_overflow(self, dtype, target_value):
70+
left, right = np.array([0, 1], dtype=dtype), np.array([1, 2], dtype=dtype)
71+
tree = IntervalTree(left, right)
72+
73+
result = tree.get_indexer(np.array([target_value]))
74+
expected = np.array([-1], dtype="intp")
75+
tm.assert_numpy_array_equal(result, expected)
76+
6677
def test_get_indexer_non_unique(self, tree):
6778
indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5]))
6879

@@ -82,6 +93,21 @@ def test_get_indexer_non_unique(self, tree):
8293
expected = np.array([2], dtype="intp")
8394
tm.assert_numpy_array_equal(result, expected)
8495

96+
@pytest.mark.parametrize(
97+
"dtype, target_value", [("int64", 2 ** 63 + 1), ("uint64", -1)]
98+
)
99+
def test_get_indexer_non_unique_overflow(self, dtype, target_value):
100+
left, right = np.array([0, 2], dtype=dtype), np.array([1, 3], dtype=dtype)
101+
tree = IntervalTree(left, right)
102+
target = np.array([target_value])
103+
104+
result_indexer, result_missing = tree.get_indexer_non_unique(target)
105+
expected_indexer = np.array([-1], dtype="intp")
106+
tm.assert_numpy_array_equal(result_indexer, expected_indexer)
107+
108+
expected_missing = np.array([0], dtype="intp")
109+
tm.assert_numpy_array_equal(result_missing, expected_missing)
110+
85111
def test_duplicates(self, dtype):
86112
left = np.array([0, 0, 0], dtype=dtype)
87113
tree = IntervalTree(left, left + 1)

0 commit comments

Comments
 (0)