@@ -31,7 +31,9 @@ cdef class IntervalTree(IntervalMixin):
31
31
we are emulating the IndexEngine interface
32
32
"""
33
33
cdef readonly:
34
- object left, right, root, dtype
34
+ ndarray left, right
35
+ IntervalNode root
36
+ object dtype
35
37
str closed
36
38
object _is_overlapping, _left_sorter, _right_sorter
37
39
@@ -203,6 +205,41 @@ cdef sort_values_and_indices(all_values, all_indices, subset):
203
205
# Nodes
204
206
# ----------------------------------------------------------------------
205
207
208
+ @cython.internal
209
+ cdef class IntervalNode:
210
+ cdef readonly:
211
+ int64_t n_elements, n_center, leaf_size
212
+ bint is_leaf_node
213
+
214
+ def __repr__(self) -> str:
215
+ if self.is_leaf_node:
216
+ return (
217
+ f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"
218
+ )
219
+ else:
220
+ n_left = self.left_node.n_elements
221
+ n_right = self.right_node.n_elements
222
+ n_center = self.n_elements - n_left - n_right
223
+ return (
224
+ f"<{type(self).__name__}: "
225
+ f"pivot {self.pivot}, {self.n_elements} elements "
226
+ f"({n_left} left, {n_right} right, {n_center} overlapping)>"
227
+ )
228
+
229
+ def counts(self):
230
+ """
231
+ Inspect counts on this node
232
+ useful for debugging purposes
233
+ """
234
+ if self.is_leaf_node:
235
+ return self.n_elements
236
+ else:
237
+ m = len(self.center_left_values)
238
+ l = self.left_node.counts()
239
+ r = self.right_node.counts()
240
+ return (m, (l, r))
241
+
242
+
206
243
# we need specialized nodes and leaves to optimize for different dtype and
207
244
# closed values
208
245
@@ -240,7 +277,7 @@ NODE_CLASSES = {}
240
277
241
278
242
279
@cython.internal
243
- cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
280
+ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode) :
244
281
"""Non-terminal node for an IntervalTree
245
282
246
283
Categorizes intervals by those that fall to the left, those that fall to
@@ -252,8 +289,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
252
289
int64_t[:] center_left_indices, center_right_indices, indices
253
290
{{dtype}}_t min_left, max_right
254
291
{{dtype}}_t pivot
255
- int64_t n_elements, n_center, leaf_size
256
- bint is_leaf_node
257
292
258
293
def __init__(self,
259
294
ndarray[{{dtype}}_t, ndim=1] left,
@@ -381,31 +416,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
381
416
else:
382
417
result.extend(self.center_left_indices)
383
418
384
- def __repr__(self) -> str:
385
- if self.is_leaf_node:
386
- return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: '
387
- '%s elements (terminal)>' % self.n_elements)
388
- else:
389
- n_left = self.left_node.n_elements
390
- n_right = self.right_node.n_elements
391
- n_center = self.n_elements - n_left - n_right
392
- return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: '
393
- 'pivot %s, %s elements (%s left, %s right, %s '
394
- 'overlapping)>' % (self.pivot, self.n_elements,
395
- n_left, n_right, n_center))
396
-
397
- def counts(self):
398
- """
399
- Inspect counts on this node
400
- useful for debugging purposes
401
- """
402
- if self.is_leaf_node:
403
- return self.n_elements
404
- else:
405
- m = len(self.center_left_values)
406
- l = self.left_node.counts()
407
- r = self.right_node.counts()
408
- return (m, (l, r))
409
419
410
420
NODE_CLASSES['{{dtype}}',
411
421
'{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode
0 commit comments