diff --git a/pandas/_libs/intervaltree.pxi.in b/pandas/_libs/intervaltree.pxi.in index 333c05f7c0dc5..316c9e5b7e5f0 100644 --- a/pandas/_libs/intervaltree.pxi.in +++ b/pandas/_libs/intervaltree.pxi.in @@ -6,12 +6,20 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in from pandas._libs.algos import is_monotonic -ctypedef fused scalar_t: - float64_t - float32_t +ctypedef fused int_scalar_t: int64_t int32_t + float64_t + float32_t + +ctypedef fused uint_scalar_t: uint64_t + float64_t + float32_t + +ctypedef fused scalar_t: + int_scalar_t + uint_scalar_t # ---------------------------------------------------------------------- # IntervalTree @@ -128,7 +136,12 @@ cdef class IntervalTree(IntervalMixin): result = Int64Vector() old_len = 0 for i in range(len(target)): - self.root.query(result, target[i]) + try: + self.root.query(result, target[i]) + except OverflowError: + # overflow -> no match, which is already handled below + pass + if result.data.n == old_len: result.append(-1) elif result.data.n > old_len + 1: @@ -150,7 +163,12 @@ cdef class IntervalTree(IntervalMixin): missing = Int64Vector() old_len = 0 for i in range(len(target)): - self.root.query(result, target[i]) + try: + self.root.query(result, target[i]) + except OverflowError: + # overflow -> no match, which is already handled below + pass + if result.data.n == old_len: result.append(-1) missing.append(i) @@ -202,19 +220,26 @@ for dtype in ['float32', 'float64', 'int32', 'int64', 'uint64']: ('neither', '<', '<')]: cmp_left_converse = '<' if cmp_left == '<=' else '<=' cmp_right_converse = '<' if cmp_right == '<=' else '<=' + if dtype.startswith('int'): + fused_prefix = 'int_' + elif dtype.startswith('uint'): + fused_prefix = 'uint_' + elif dtype.startswith('float'): + fused_prefix = '' nodes.append((dtype, dtype.title(), closed, closed.title(), cmp_left, cmp_right, cmp_left_converse, - cmp_right_converse)) + cmp_right_converse, + fused_prefix)) }} NODE_CLASSES = {} {{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right, - cmp_left_converse, cmp_right_converse in nodes}} + cmp_left_converse, cmp_right_converse, fused_prefix in nodes}} cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode: """Non-terminal node for an IntervalTree @@ -317,7 +342,7 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cpdef query(self, Int64Vector result, scalar_t point): + cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. """ diff --git a/pandas/tests/indexes/interval/test_interval_tree.py b/pandas/tests/indexes/interval/test_interval_tree.py index f2fca34e083c2..695a98777eadb 100644 --- a/pandas/tests/indexes/interval/test_interval_tree.py +++ b/pandas/tests/indexes/interval/test_interval_tree.py @@ -63,6 +63,17 @@ def test_get_indexer(self, tree): ): tree.get_indexer(np.array([3.0])) + @pytest.mark.parametrize( + "dtype, target_value", [("int64", 2 ** 63 + 1), ("uint64", -1)] + ) + def test_get_indexer_overflow(self, dtype, target_value): + left, right = np.array([0, 1], dtype=dtype), np.array([1, 2], dtype=dtype) + tree = IntervalTree(left, right) + + result = tree.get_indexer(np.array([target_value])) + expected = np.array([-1], dtype="intp") + tm.assert_numpy_array_equal(result, expected) + def test_get_indexer_non_unique(self, tree): indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5])) @@ -82,6 +93,21 @@ def test_get_indexer_non_unique(self, tree): expected = np.array([2], dtype="intp") tm.assert_numpy_array_equal(result, expected) + @pytest.mark.parametrize( + "dtype, target_value", [("int64", 2 ** 63 + 1), ("uint64", -1)] + ) + def test_get_indexer_non_unique_overflow(self, dtype, target_value): + left, right = np.array([0, 2], dtype=dtype), np.array([1, 3], dtype=dtype) + tree = IntervalTree(left, right) + target = np.array([target_value]) + + result_indexer, result_missing = tree.get_indexer_non_unique(target) + expected_indexer = np.array([-1], dtype="intp") + tm.assert_numpy_array_equal(result_indexer, expected_indexer) + + expected_missing = np.array([0], dtype="intp") + tm.assert_numpy_array_equal(result_missing, expected_missing) + def test_duplicates(self, dtype): left = np.array([0, 0, 0], dtype=dtype) tree = IntervalTree(left, left + 1)