@@ -7,18 +7,23 @@ def __init__(self, value: int = 0) -> None:
7
7
8
8
class PersistentSegmentTree :
9
9
def __init__ (self , arr : list [int ]) -> None :
10
- """
11
- Initialize the Persistent Segment Tree with the given array.
12
-
13
- >>> pst = PersistentSegmentTree([1, 2, 3])
14
- >>> pst.query(0, 0, 2)
15
- 6
16
- """
17
10
self .n = len (arr )
18
11
self .roots : list [Node ] = []
19
12
self .roots .append (self ._build (arr , 0 , self .n - 1 ))
20
13
21
14
def _build (self , arr : list [int ], start : int , end : int ) -> Node :
15
+ """
16
+ Builds a segment tree from the provided array.
17
+
18
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
19
+ >>> root = pst._build([1, 2, 3, 4], 0, 3)
20
+ >>> root.value # Sum of the whole array
21
+ 10
22
+ >>> root.left.value # Sum of the left half
23
+ 3
24
+ >>> root.right.value # Sum of the right half
25
+ 7
26
+ """
22
27
if start == end :
23
28
return Node (arr [start ])
24
29
mid = (start + end ) // 2
@@ -29,19 +34,26 @@ def _build(self, arr: list[int], start: int, end: int) -> Node:
29
34
return node
30
35
31
36
def update (self , version : int , index : int , value : int ) -> int :
32
- """
33
- Update the value at the given index and return the new version.
34
-
35
- >>> pst = PersistentSegmentTree([1, 2, 3])
36
- >>> version_1 = pst.update(0, 1, 5)
37
- >>> pst.query(version_1, 0, 2)
38
- 9
39
- """
40
37
new_root = self ._update (self .roots [version ], 0 , self .n - 1 , index , value )
41
38
self .roots .append (new_root )
42
39
return len (self .roots ) - 1
43
40
44
41
def _update (self , node : Node , start : int , end : int , index : int , value : int ) -> Node :
42
+ """
43
+ Updates the node for the specified index and value and returns the new node.
44
+
45
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
46
+ >>> old_root = pst.roots[0]
47
+ >>> new_root = pst._update(old_root, 0, 3, 1, 5) # Update index 1 to 5
48
+ >>> new_root.value # New sum after update
49
+ 13
50
+ >>> old_root.value # Old root remains unchanged
51
+ 10
52
+ >>> new_root.left.value # Updated left child
53
+ 6
54
+ >>> new_root.right.value # Right child remains the same
55
+ 7
56
+ """
45
57
if start == end :
46
58
return Node (value )
47
59
@@ -60,34 +72,33 @@ def _update(self, node: Node, start: int, end: int, index: int, value: int) -> N
60
72
return new_node
61
73
62
74
def query (self , version : int , left : int , right : int ) -> int :
63
- """
64
- Query the sum in the given range for the specified version.
65
-
66
- >>> pst = PersistentSegmentTree([1, 2, 3])
67
- >>> pst.query(0, 0, 2)
68
- 6
69
- >>> version_1 = pst.update(0, 1, 5)
70
- >>> pst.query(version_1, 0, 1)
71
- 6
72
- >>> pst.query(version_1, 0, 2)
73
- 9
74
- """
75
75
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
76
76
77
77
def _query (self , node : Node , start : int , end : int , left : int , right : int ) -> int :
78
+ """
79
+ Queries the sum of values in the range [left, right] for the given node.
80
+
81
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
82
+ >>> root = pst.roots[0]
83
+ >>> pst._query(root, 0, 3, 1, 2) # Sum of elements at index 1 and 2
84
+ 5
85
+ >>> pst._query(root, 0, 3, 0, 3) # Sum of all elements
86
+ 10
87
+ >>> pst._query(root, 0, 3, 2, 3) # Sum of elements at index 2 and 3
88
+ 7
89
+ """
78
90
if left > end or right < start :
79
91
return 0
80
92
if left <= start and right >= end :
81
93
return node .value
82
94
mid = (start + end ) // 2
83
- return self ._query (node .left , start , mid , left , right ) + self ._query (
84
- node .right , mid + 1 , end , left , right
85
- )
95
+ return (self ._query (node .left , start , mid , left , right ) +
96
+ self ._query (node .right , mid + 1 , end , left , right ))
86
97
87
98
99
+ # Running the doctests
88
100
if __name__ == "__main__" :
89
101
import doctest
90
-
91
102
print ("Running doctests..." )
92
103
result = doctest .testmod ()
93
104
print (f"Ran { result .attempted } tests, { result .failed } failed." )
0 commit comments