1
1
from __future__ import annotations
2
2
3
+
3
4
class Node :
4
5
def __init__ (self , value : int = 0 ) -> None :
5
6
self .value : int = value
6
7
self .left : Node | None = None
7
8
self .right : Node | None = None
8
9
10
+
9
11
class PersistentSegmentTree :
10
12
def __init__ (self , arr : list [int ]) -> None :
11
13
self .n : int = len (arr )
@@ -33,7 +35,9 @@ def update(self, version: int, index: int, value: int) -> int:
33
35
self .roots .append (new_root )
34
36
return len (self .roots ) - 1
35
37
36
- def _update (self , node : Node | None , start : int , end : int , index : int , value : int ) -> Node :
38
+ def _update (
39
+ self , node : Node | None , start : int , end : int , index : int , value : int
40
+ ) -> Node :
37
41
"""
38
42
Updates the node for the specified index and value and returns the new node.
39
43
"""
@@ -53,7 +57,9 @@ def _update(self, node: Node | None, start: int, end: int, index: int, value: in
53
57
new_node .left = node .left # Ensure left node is the same as the original
54
58
new_node .right = self ._update (node .right , mid + 1 , end , index , value )
55
59
56
- new_node .value = new_node .left .value + (new_node .right .value if new_node .right else 0 )
60
+ new_node .value = new_node .left .value + (
61
+ new_node .right .value if new_node .right else 0
62
+ )
57
63
58
64
return new_node
59
65
@@ -63,7 +69,9 @@ def query(self, version: int, left: int, right: int) -> int:
63
69
"""
64
70
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
65
71
66
- def _query (self , node : Node | None , start : int , end : int , left : int , right : int ) -> int :
72
+ def _query (
73
+ self , node : Node | None , start : int , end : int , left : int , right : int
74
+ ) -> int :
67
75
"""
68
76
Queries the sum of values in the range [left, right] for the given node.
69
77
"""
@@ -72,12 +80,15 @@ def _query(self, node: Node | None, start: int, end: int, left: int, right: int)
72
80
if left <= start and right >= end :
73
81
return node .value
74
82
mid = (start + end ) // 2
75
- return (self ._query (node .left , start , mid , left , right ) +
76
- self ._query (node .right , mid + 1 , end , left , right ))
83
+ return self ._query (node .left , start , mid , left , right ) + self ._query (
84
+ node .right , mid + 1 , end , left , right
85
+ )
86
+
77
87
78
88
# Running the doctests
79
89
if __name__ == "__main__" :
80
90
import doctest
91
+
81
92
print ("Running doctests..." )
82
93
result = doctest .testmod ()
83
94
print (f"Ran { result .attempted } tests, { result .failed } failed." )
0 commit comments