Skip to content

Commit e746d93

Browse files
authored
added persistent segment tree
1 parent 0457860 commit e746d93

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
class Node:
2+
def __init__(self, value=0):
3+
"""
4+
Initialize a segment tree node.
5+
6+
Args:
7+
value (int): The value of the node.
8+
"""
9+
self.value = value
10+
self.left = None
11+
self.right = None
12+
13+
14+
class PersistentSegmentTree:
15+
def __init__(self, arr):
16+
"""
17+
Initialize the persistent segment tree with the given array.
18+
19+
Args:
20+
arr (list): The initial array to build the segment tree.
21+
"""
22+
self.n = len(arr)
23+
self.roots = []
24+
self.roots.append(self._build(arr, 0, self.n - 1))
25+
26+
def _build(self, arr, start, end):
27+
"""
28+
Recursively build the segment tree.
29+
30+
Args:
31+
arr (list): The input array.
32+
start (int): The starting index of the segment.
33+
end (int): The ending index of the segment.
34+
35+
Returns:
36+
Node: The root node of the segment tree for the current segment.
37+
"""
38+
if start == end:
39+
return Node(arr[start])
40+
41+
mid = (start + end) // 2
42+
node = Node()
43+
node.left = self._build(arr, start, mid)
44+
node.right = self._build(arr, mid + 1, end)
45+
node.value = node.left.value + node.right.value
46+
return node
47+
48+
def update(self, version, index, value):
49+
"""
50+
Update the value at the specified index in the specified version.
51+
52+
Args:
53+
version (int): The version of the segment tree to update.
54+
index (int): The index to update.
55+
value (int): The new value to set at the index.
56+
57+
Returns:
58+
int: The index of the new version of the root node.
59+
"""
60+
new_root = self._update(self.roots[version], 0, self.n - 1, index, value)
61+
self.roots.append(new_root)
62+
return len(self.roots) - 1 # return the index of the new version
63+
64+
def _update(self, node, start, end, index, value):
65+
"""
66+
Recursively update the segment tree.
67+
68+
Args:
69+
node (Node): The current node of the segment tree.
70+
start (int): The starting index of the segment.
71+
end (int): The ending index of the segment.
72+
index (int): The index to update.
73+
value (int): The new value to set at the index.
74+
75+
Returns:
76+
Node: The new root node after the update.
77+
"""
78+
if start == end:
79+
new_node = Node(value)
80+
return new_node
81+
82+
mid = (start + end) // 2
83+
new_node = Node()
84+
if index <= mid:
85+
new_node.left = self._update(node.left, start, mid, index, value)
86+
new_node.right = node.right
87+
else:
88+
new_node.left = node.left
89+
new_node.right = self._update(node.right, mid + 1, end, index, value)
90+
91+
new_node.value = new_node.left.value + new_node.right.value
92+
return new_node
93+
94+
def query(self, version, left, right):
95+
"""
96+
Query the sum of values in the range [left, right] for the specified version.
97+
98+
Args:
99+
version (int): The version of the segment tree to query.
100+
left (int): The left index of the range.
101+
right (int): The right index of the range.
102+
103+
Returns:
104+
int: The sum of the values in the specified range.
105+
"""
106+
return self._query(self.roots[version], 0, self.n - 1, left, right)
107+
108+
def _query(self, node, start, end, left, right):
109+
"""
110+
Recursively query the segment tree.
111+
112+
Args:
113+
node (Node): The current node of the segment tree.
114+
start (int): The starting index of the segment.
115+
end (int): The ending index of the segment.
116+
left (int): The left index of the range.
117+
right (int): The right index of the range.
118+
119+
Returns:
120+
int: The sum of the values in the specified range.
121+
"""
122+
if right < start or end < left:
123+
return 0 # out of range
124+
125+
if left <= start and end <= right:
126+
return node.value # completely within range
127+
128+
mid = (start + end) // 2
129+
sum_left = self._query(node.left, start, mid, left, right)
130+
sum_right = self._query(node.right, mid + 1, end, left, right)
131+
return sum_left + sum_right
132+
133+
134+
# Example usage and doctests
135+
if __name__ == "__main__":
136+
import doctest
137+
138+
# Creating an initial array
139+
arr = [1, 2, 3, 4, 5]
140+
pst = PersistentSegmentTree(arr)
141+
142+
# Querying the initial version
143+
assert pst.query(0, 0, 4) == 15 # sum of [1, 2, 3, 4, 5]
144+
145+
# Updating index 2 to value 10 in version 0
146+
new_version = pst.update(0, 2, 10)
147+
148+
# Querying the updated version
149+
assert pst.query(new_version, 0, 4) == 22 # sum of [1, 2, 10, 4, 5]
150+
assert pst.query(0, 0, 4) == 15 # original version unchanged
151+

0 commit comments

Comments
 (0)