Skip to content

Commit ada767d

Browse files
authoredOct 19, 2024
Add files via upload
1 parent b01fbff commit ada767d

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed
 
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
class Node:
2+
def __init__(self, value: int = 0) -> None:
3+
self.value = value
4+
self.left = None
5+
self.right = None
6+
7+
8+
class PersistentSegmentTree:
9+
def __init__(self, arr: list[int]) -> None:
10+
self.n = len(arr)
11+
self.roots: list[Node] = []
12+
self.roots.append(self._build(arr, 0, self.n - 1))
13+
14+
def _build(self, arr: list[int], start: int, end: int) -> Node:
15+
"""
16+
Builds a segment tree from the provided array.
17+
18+
>>> pst = PersistentSegmentTree([1, 2, 3])
19+
>>> root = pst._build([1, 2, 3], 0, 2)
20+
>>> root.value # Sum of the whole array
21+
6
22+
"""
23+
if start == end:
24+
return Node(arr[start])
25+
mid = (start + end) // 2
26+
node = Node()
27+
node.left = self._build(arr, start, mid)
28+
node.right = self._build(arr, mid + 1, end)
29+
node.value = node.left.value + node.right.value
30+
return node
31+
32+
def update(self, version: int, index: int, value: int) -> int:
33+
"""
34+
Updates the segment tree with a new value at the specified index.
35+
36+
>>> pst = PersistentSegmentTree([1, 2, 3])
37+
>>> version_1 = pst.update(0, 1, 5)
38+
>>> pst.query(version_1, 0, 2) # Query sum from index 0 to 2
39+
9
40+
"""
41+
new_root = self._update(self.roots[version], 0, self.n - 1, index, value)
42+
self.roots.append(new_root)
43+
return len(self.roots) - 1 # return the index of the new version
44+
45+
def _update(self, node: Node, start: int, end: int, index: int, value: int) -> Node:
46+
if start == end:
47+
new_node = Node(value)
48+
return new_node
49+
mid = (start + end) // 2
50+
new_node = Node()
51+
if index <= mid:
52+
new_node.left = self._update(node.left, start, mid, index, value)
53+
new_node.right = node.right
54+
else:
55+
new_node.left = node.left
56+
new_node.right = self._update(node.right, mid + 1, end, index, value)
57+
new_node.value = new_node.left.value + new_node.right.value
58+
return new_node
59+
60+
def query(self, version: int, left: int, right: int) -> int:
61+
"""
62+
Queries the sum in the given range for the specified version.
63+
64+
>>> pst = PersistentSegmentTree([1, 2, 3])
65+
>>> version_1 = pst.update(0, 1, 5)
66+
>>> pst.query(version_1, 0, 1) # Query sum from index 0 to 1
67+
6
68+
>>> pst.query(version_1, 0, 2) # Query sum from index 0 to 2
69+
9
70+
"""
71+
return self._query(self.roots[version], 0, self.n - 1, left, right)
72+
73+
def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int:
74+
if left > end or right < start:
75+
return 0
76+
if left <= start and right >= end:
77+
return node.value
78+
mid = (start + end) // 2
79+
return (self._query(node.left, start, mid, left, right) +
80+
self._query(node.right, mid + 1, end, left, right))

0 commit comments

Comments
 (0)