Skip to content

BLD: Fix IntervalTree build warnings #30560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions pandas/_libs/intervaltree.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in

from pandas._libs.algos import is_monotonic

ctypedef fused scalar_t:
float64_t
float32_t
ctypedef fused int_scalar_t:
int64_t
int32_t
float64_t
float32_t
Copy link
Member Author

@jschendel jschendel Dec 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can drop support for float32 and int32 dtypes here, which should help reduce build time. There is no practical way to actually get an IntervalTree of these dtypes since IntervalIndex is backed by 2 indexes and we don't have a Float32Index or Int32Index. Will save this for a followup though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

I'm confused by the naming here: why are the float dtypes included in int_scalar_t/uint_scalar_t?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I couldn't think of a good way to name those. Idea is that int_scalar_t/uint_scalar_t are things that are comparable to int/uint, and we want to be able to compare against float in both cases to determine things like 0.5 being in the interval (0, 1).

I tried it without including floats and got "no matching type signature" errors when trying stuff like IntervalTree[int].get_indexer([0.5]).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn’t comparisons to float be very inaccurate in the range at and above 2 ** 63, where we need unsigned integers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, comparisons will be inaccurate above 2**61, so the issue is also present on the upper/lower ends of the int64 range. Comparisons are still valid below this so probably something we'll have to live with, e.g. the same holds for comparisons against UInt64Index but the behavior is still allowed there.


ctypedef fused uint_scalar_t:
uint64_t
float64_t
float32_t

ctypedef fused scalar_t:
int_scalar_t
uint_scalar_t

# ----------------------------------------------------------------------
# IntervalTree
Expand Down Expand Up @@ -128,7 +136,12 @@ cdef class IntervalTree(IntervalMixin):
result = Int64Vector()
old_len = 0
for i in range(len(target)):
self.root.query(result, target[i])
try:
self.root.query(result, target[i])
except OverflowError:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add these as operations like IntervalTree[uint64].get_indexer([-1]) will raise an OverflowError. See the associated tests I added.

# overflow -> no match, which is already handled below
pass

if result.data.n == old_len:
result.append(-1)
elif result.data.n > old_len + 1:
Expand All @@ -150,7 +163,12 @@ cdef class IntervalTree(IntervalMixin):
missing = Int64Vector()
old_len = 0
for i in range(len(target)):
self.root.query(result, target[i])
try:
self.root.query(result, target[i])
except OverflowError:
# overflow -> no match, which is already handled below
pass

if result.data.n == old_len:
result.append(-1)
missing.append(i)
Expand Down Expand Up @@ -202,19 +220,26 @@ for dtype in ['float32', 'float64', 'int32', 'int64', 'uint64']:
('neither', '<', '<')]:
cmp_left_converse = '<' if cmp_left == '<=' else '<='
cmp_right_converse = '<' if cmp_right == '<=' else '<='
if dtype.startswith('int'):
fused_prefix = 'int_'
elif dtype.startswith('uint'):
fused_prefix = 'uint_'
elif dtype.startswith('float'):
fused_prefix = ''
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Open to a better naming convention for the fused types to make this less gross.

nodes.append((dtype, dtype.title(),
closed, closed.title(),
cmp_left,
cmp_right,
cmp_left_converse,
cmp_right_converse))
cmp_right_converse,
fused_prefix))

}}

NODE_CLASSES = {}

{{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
cmp_left_converse, cmp_right_converse in nodes}}
cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}

cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
"""Non-terminal node for an IntervalTree
Expand Down Expand Up @@ -317,7 +342,7 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
@cython.wraparound(False)
@cython.boundscheck(False)
@cython.initializedcheck(False)
cpdef query(self, Int64Vector result, scalar_t point):
cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
"""Recursively query this node and its sub-nodes for intervals that
overlap with the query point.
"""
Expand Down
26 changes: 26 additions & 0 deletions pandas/tests/indexes/interval/test_interval_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ def test_get_indexer(self, tree):
):
tree.get_indexer(np.array([3.0]))

# Each case pairs a tree dtype with a target the test name describes as an
# overflow: 2 ** 63 + 1 for an int64 tree and -1 for a uint64 tree.
@pytest.mark.parametrize(
"dtype, target_value", [("int64", 2 ** 63 + 1), ("uint64", -1)]
)
def test_get_indexer_overflow(self, dtype, target_value):
# Build a small tree over the intervals [0, 1] and [1, 2] in the given dtype.
left, right = np.array([0, 1], dtype=dtype), np.array([1, 2], dtype=dtype)
tree = IntervalTree(left, right)

# The overflowing target must be reported as "no match" (-1) rather than
# propagating an error out of get_indexer.
result = tree.get_indexer(np.array([target_value]))
expected = np.array([-1], dtype="intp")
tm.assert_numpy_array_equal(result, expected)

def test_get_indexer_non_unique(self, tree):
indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5]))

Expand All @@ -82,6 +93,21 @@ def test_get_indexer_non_unique(self, tree):
expected = np.array([2], dtype="intp")
tm.assert_numpy_array_equal(result, expected)

# Each case pairs a tree dtype with a target the test name describes as an
# overflow: 2 ** 63 + 1 for an int64 tree and -1 for a uint64 tree.
@pytest.mark.parametrize(
"dtype, target_value", [("int64", 2 ** 63 + 1), ("uint64", -1)]
)
def test_get_indexer_non_unique_overflow(self, dtype, target_value):
# Build a small tree over the intervals [0, 1] and [2, 3] in the given dtype.
left, right = np.array([0, 2], dtype=dtype), np.array([1, 3], dtype=dtype)
tree = IntervalTree(left, right)
target = np.array([target_value])

# The overflowing target must yield -1 in the indexer ...
result_indexer, result_missing = tree.get_indexer_non_unique(target)
expected_indexer = np.array([-1], dtype="intp")
tm.assert_numpy_array_equal(result_indexer, expected_indexer)

# ... and its position (0) must be listed among the missing targets.
expected_missing = np.array([0], dtype="intp")
tm.assert_numpy_array_equal(result_missing, expected_missing)

def test_duplicates(self, dtype):
left = np.array([0, 0, 0], dtype=dtype)
tree = IntervalTree(left, left + 1)
Expand Down