Skip to content

Commit cd1a22e

Browse files
jbrockmendelJulianWgs
authored andcommitted
REF: stronger typing in IntervalTree (pandas-dev#41814)
1 parent 6d92086 commit cd1a22e

File tree

1 file changed

+39
-29
lines changed

1 file changed

+39
-29
lines changed

pandas/_libs/intervaltree.pxi.in

+39-29
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ cdef class IntervalTree(IntervalMixin):
3131
we are emulating the IndexEngine interface
3232
"""
3333
cdef readonly:
34-
object left, right, root, dtype
34+
ndarray left, right
35+
IntervalNode root
36+
object dtype
3537
str closed
3638
object _is_overlapping, _left_sorter, _right_sorter
3739

@@ -203,6 +205,41 @@ cdef sort_values_and_indices(all_values, all_indices, subset):
203205
# Nodes
204206
# ----------------------------------------------------------------------
205207

208+
@cython.internal
209+
cdef class IntervalNode:
210+
cdef readonly:
211+
int64_t n_elements, n_center, leaf_size
212+
bint is_leaf_node
213+
214+
def __repr__(self) -> str:
215+
if self.is_leaf_node:
216+
return (
217+
f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"
218+
)
219+
else:
220+
n_left = self.left_node.n_elements
221+
n_right = self.right_node.n_elements
222+
n_center = self.n_elements - n_left - n_right
223+
return (
224+
f"<{type(self).__name__}: "
225+
f"pivot {self.pivot}, {self.n_elements} elements "
226+
f"({n_left} left, {n_right} right, {n_center} overlapping)>"
227+
)
228+
229+
def counts(self):
230+
"""
231+
Inspect counts on this node
232+
useful for debugging purposes
233+
"""
234+
if self.is_leaf_node:
235+
return self.n_elements
236+
else:
237+
m = len(self.center_left_values)
238+
l = self.left_node.counts()
239+
r = self.right_node.counts()
240+
return (m, (l, r))
241+
242+
206243
# we need specialized nodes and leaves to optimize for different dtype and
207244
# closed values
208245

@@ -240,7 +277,7 @@ NODE_CLASSES = {}
240277

241278

242279
@cython.internal
243-
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
280+
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode):
244281
"""Non-terminal node for an IntervalTree
245282

246283
Categorizes intervals by those that fall to the left, those that fall to
@@ -252,8 +289,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
252289
int64_t[:] center_left_indices, center_right_indices, indices
253290
{{dtype}}_t min_left, max_right
254291
{{dtype}}_t pivot
255-
int64_t n_elements, n_center, leaf_size
256-
bint is_leaf_node
257292

258293
def __init__(self,
259294
ndarray[{{dtype}}_t, ndim=1] left,
@@ -381,31 +416,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
381416
else:
382417
result.extend(self.center_left_indices)
383418

384-
def __repr__(self) -> str:
385-
if self.is_leaf_node:
386-
return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: '
387-
'%s elements (terminal)>' % self.n_elements)
388-
else:
389-
n_left = self.left_node.n_elements
390-
n_right = self.right_node.n_elements
391-
n_center = self.n_elements - n_left - n_right
392-
return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: '
393-
'pivot %s, %s elements (%s left, %s right, %s '
394-
'overlapping)>' % (self.pivot, self.n_elements,
395-
n_left, n_right, n_center))
396-
397-
def counts(self):
398-
"""
399-
Inspect counts on this node
400-
useful for debugging purposes
401-
"""
402-
if self.is_leaf_node:
403-
return self.n_elements
404-
else:
405-
m = len(self.center_left_values)
406-
l = self.left_node.counts()
407-
r = self.right_node.counts()
408-
return (m, (l, r))
409419

410420
NODE_CLASSES['{{dtype}}',
411421
'{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode

0 commit comments

Comments
 (0)