1
1
class Node :
2
- def __init__ (self , value = 0 ):
3
- """
4
- Initialize a segment tree node.
5
-
6
- Args:
7
- value (int): The value of the node.
8
- """
2
+ def __init__ (self , value : int = 0 ) -> None :
9
3
self .value = value
10
4
self .left = None
11
5
self .right = None
12
6
13
-
14
7
class PersistentSegmentTree :
15
- def __init__ (self , arr ):
16
- """
17
- Initialize the persistent segment tree with the given array.
18
-
19
- Args:
20
- arr (list): The initial array to build the segment tree.
21
- """
8
+ def __init__ (self , arr : list [int ]) -> None :
22
9
self .n = len (arr )
23
- self .roots = []
10
+ self .roots : list [ Node ] = []
24
11
self .roots .append (self ._build (arr , 0 , self .n - 1 ))
25
12
26
- def _build (self , arr , start , end ) :
13
+ def _build (self , arr : list [ int ] , start : int , end : int ) -> Node :
27
14
"""
28
- Recursively build the segment tree.
29
-
30
- Args:
31
- arr (list): The input array.
32
- start (int): The starting index of the segment.
33
- end (int): The ending index of the segment.
34
-
35
- Returns:
36
- Node: The root node of the segment tree for the current segment.
15
+ Builds a segment tree from the provided array.
16
+
17
+ >>> pst = PersistentSegmentTree([1, 2, 3])
18
+ >>> root = pst._build([1, 2, 3], 0, 2)
19
+ >>> root.value # Sum of the whole array
20
+ 6
37
21
"""
38
22
if start == end :
39
23
return Node (arr [start ])
40
-
41
24
mid = (start + end ) // 2
42
25
node = Node ()
43
26
node .left = self ._build (arr , start , mid )
44
27
node .right = self ._build (arr , mid + 1 , end )
45
28
node .value = node .left .value + node .right .value
46
29
return node
47
30
48
- def update (self , version , index , value ) :
31
+ def update (self , version : int , index : int , value : int ) -> int :
49
32
"""
50
- Update the value at the specified index in the specified version.
51
-
52
- Args:
53
- version (int): The version of the segment tree to update.
54
- index (int): The index to update.
55
- value (int): The new value to set at the index.
56
-
57
- Returns:
58
- int: The index of the new version of the root node.
33
+ Updates the segment tree with a new value at the specified index.
34
+
35
+ >>> pst = PersistentSegmentTree([1, 2, 3])
36
+ >>> version_1 = pst.update(0, 1, 5)
37
+ >>> pst.query(version_1, 0, 2) # Query sum from index 0 to 2
38
+ 9
59
39
"""
60
40
new_root = self ._update (self .roots [version ], 0 , self .n - 1 , index , value )
61
41
self .roots .append (new_root )
62
42
return len (self .roots ) - 1 # return the index of the new version
63
43
64
- def _update (self , node , start , end , index , value ):
65
- """
66
- Recursively update the segment tree.
67
-
68
- Args:
69
- node (Node): The current node of the segment tree.
70
- start (int): The starting index of the segment.
71
- end (int): The ending index of the segment.
72
- index (int): The index to update.
73
- value (int): The new value to set at the index.
74
-
75
- Returns:
76
- Node: The new root node after the update.
77
- """
44
+ def _update (self , node : Node , start : int , end : int , index : int , value : int ) -> Node :
78
45
if start == end :
79
46
new_node = Node (value )
80
47
return new_node
81
-
82
48
mid = (start + end ) // 2
83
49
new_node = Node ()
84
50
if index <= mid :
@@ -87,64 +53,27 @@ def _update(self, node, start, end, index, value):
87
53
else :
88
54
new_node .left = node .left
89
55
new_node .right = self ._update (node .right , mid + 1 , end , index , value )
90
-
91
56
new_node .value = new_node .left .value + new_node .right .value
92
57
return new_node
93
58
94
- def query (self , version , left , right ) :
59
+ def query (self , version : int , left : int , right : int ) -> int :
95
60
"""
96
- Query the sum of values in the range [left, right] for the specified version.
97
-
98
- Args:
99
- version (int): The version of the segment tree to query.
100
- left (int): The left index of the range.
101
- right (int): The right index of the range.
102
-
103
- Returns:
104
- int: The sum of the values in the specified range.
61
+ Queries the sum in the given range for the specified version.
62
+
63
+ >>> pst = PersistentSegmentTree([1, 2, 3])
64
+ >>> version_1 = pst.update(0, 1, 5)
65
+ >>> pst.query(version_1, 0, 1) # Query sum from index 0 to 1
66
+ 6
67
+ >>> pst.query(version_1, 0, 2) # Query sum from index 0 to 2
68
+ 9
105
69
"""
106
70
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
107
71
108
- def _query (self , node , start , end , left , right ):
109
- """
110
- Recursively query the segment tree.
111
-
112
- Args:
113
- node (Node): The current node of the segment tree.
114
- start (int): The starting index of the segment.
115
- end (int): The ending index of the segment.
116
- left (int): The left index of the range.
117
- right (int): The right index of the range.
118
-
119
- Returns:
120
- int: The sum of the values in the specified range.
121
- """
122
- if right < start or end < left :
123
- return 0 # out of range
124
-
125
- if left <= start and end <= right :
126
- return node .value # completely within range
127
-
72
+ def _query (self , node : Node , start : int , end : int , left : int , right : int ) -> int :
73
+ if left > end or right < start :
74
+ return 0
75
+ if left <= start and right >= end :
76
+ return node .value
128
77
mid = (start + end ) // 2
129
- sum_left = self ._query (node .left , start , mid , left , right )
130
- sum_right = self ._query (node .right , mid + 1 , end , left , right )
131
- return sum_left + sum_right
132
-
133
-
134
- # Example usage and doctests
135
- if __name__ == "__main__" :
136
- import doctest
137
-
138
- # Creating an initial array
139
- arr = [1 , 2 , 3 , 4 , 5 ]
140
- pst = PersistentSegmentTree (arr )
141
-
142
- # Querying the initial version
143
- assert pst .query (0 , 0 , 4 ) == 15 # sum of [1, 2, 3, 4, 5]
144
-
145
- # Updating index 2 to value 10 in version 0
146
- new_version = pst .update (0 , 2 , 10 )
147
-
148
- # Querying the updated version
149
- assert pst .query (new_version , 0 , 4 ) == 22 # sum of [1, 2, 10, 4, 5]
150
- assert pst .query (0 , 0 , 4 ) == 15 # original version unchanged
78
+ return (self ._query (node .left , start , mid , left , right ) +
79
+ self ._query (node .right , mid + 1 , end , left , right ))
0 commit comments