@@ -34,6 +34,19 @@ def _build(self, arr: list[int], start: int, end: int) -> Node:
34
34
return node
35
35
36
36
def update (self , version : int , index : int , value : int ) -> int :
37
+ """
38
+ Updates the value at the given index and returns the new version.
39
+
40
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
41
+ >>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
42
+ >>> pst.query(version_1, 0, 3) # Query sum of all elements in new version
43
+ 13
44
+ >>> pst.query(0, 0, 3) # Original version remains unchanged
45
+ 10
46
+ >>> version_2 = pst.update(version_1, 3, 6) # Update index 3 to 6 in version_1
47
+ >>> pst.query(version_2, 0, 3) # Query sum of all elements in newest version
48
+ 15
49
+ """
37
50
new_root = self ._update (self .roots [version ], 0 , self .n - 1 , index , value )
38
51
self .roots .append (new_root )
39
52
return len (self .roots ) - 1
@@ -72,6 +85,20 @@ def _update(self, node: Node, start: int, end: int, index: int, value: int) -> N
72
85
return new_node
73
86
74
87
def query (self , version : int , left : int , right : int ) -> int :
88
+ """
89
+ Queries the sum in the given range for the specified version.
90
+
91
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
92
+ >>> pst.query(0, 0, 3) # Sum of all elements in original version
93
+ 10
94
+ >>> pst.query(0, 1, 2) # Sum of elements at index 1 and 2 in original version
95
+ 5
96
+ >>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
97
+ >>> pst.query(version_1, 0, 3) # Sum of all elements in new version
98
+ 13
99
+ >>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2 in new version
100
+ 8
101
+ """
75
102
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
76
103
77
104
def _query (self , node : Node , start : int , end : int , left : int , right : int ) -> int :
@@ -92,15 +119,13 @@ def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int
92
119
if left <= start and right >= end :
93
120
return node .value
94
121
mid = (start + end ) // 2
95
- return self ._query (node .left , start , mid , left , right ) + self ._query (
96
- node .right , mid + 1 , end , left , right
97
- )
122
+ return (self ._query (node .left , start , mid , left , right ) +
123
+ self ._query (node .right , mid + 1 , end , left , right ))
98
124
99
125
100
126
# Running the doctests
101
127
if __name__ == "__main__" :
102
128
import doctest
103
-
104
129
print ("Running doctests..." )
105
130
result = doctest .testmod ()
106
131
print (f"Ran { result .attempted } tests, { result .failed } failed." )
0 commit comments