Skip to content

Commit 82270b7

Browse files
authored
Update persistent_segment_tree.py
1 parent 5de6184 commit 82270b7

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

data_structures/persistent_segment_tree.py

+42-31
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@ def __init__(self, value: int = 0) -> None:
77

88
class PersistentSegmentTree:
99
def __init__(self, arr: list[int]) -> None:
10-
"""
11-
Initialize the Persistent Segment Tree with the given array.
12-
13-
>>> pst = PersistentSegmentTree([1, 2, 3])
14-
>>> pst.query(0, 0, 2)
15-
6
16-
"""
1710
self.n = len(arr)
1811
self.roots: list[Node] = []
1912
self.roots.append(self._build(arr, 0, self.n - 1))
2013

2114
def _build(self, arr: list[int], start: int, end: int) -> Node:
15+
"""
16+
Builds a segment tree from the provided array.
17+
18+
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
19+
>>> root = pst._build([1, 2, 3, 4], 0, 3)
20+
>>> root.value # Sum of the whole array
21+
10
22+
>>> root.left.value # Sum of the left half
23+
3
24+
>>> root.right.value # Sum of the right half
25+
7
26+
"""
2227
if start == end:
2328
return Node(arr[start])
2429
mid = (start + end) // 2
@@ -29,19 +34,26 @@ def _build(self, arr: list[int], start: int, end: int) -> Node:
2934
return node
3035

3136
def update(self, version: int, index: int, value: int) -> int:
32-
"""
33-
Update the value at the given index and return the new version.
34-
35-
>>> pst = PersistentSegmentTree([1, 2, 3])
36-
>>> version_1 = pst.update(0, 1, 5)
37-
>>> pst.query(version_1, 0, 2)
38-
9
39-
"""
4037
new_root = self._update(self.roots[version], 0, self.n - 1, index, value)
4138
self.roots.append(new_root)
4239
return len(self.roots) - 1
4340

4441
def _update(self, node: Node, start: int, end: int, index: int, value: int) -> Node:
42+
"""
43+
Updates the node for the specified index and value and returns the new node.
44+
45+
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
46+
>>> old_root = pst.roots[0]
47+
>>> new_root = pst._update(old_root, 0, 3, 1, 5) # Update index 1 to 5
48+
>>> new_root.value # New sum after update
49+
13
50+
>>> old_root.value # Old root remains unchanged
51+
10
52+
>>> new_root.left.value # Updated left child
53+
6
54+
>>> new_root.right.value # Right child remains the same
55+
7
56+
"""
4557
if start == end:
4658
return Node(value)
4759

@@ -60,34 +72,33 @@ def _update(self, node: Node, start: int, end: int, index: int, value: int) -> N
6072
return new_node
6173

6274
def query(self, version: int, left: int, right: int) -> int:
63-
"""
64-
Query the sum in the given range for the specified version.
65-
66-
>>> pst = PersistentSegmentTree([1, 2, 3])
67-
>>> pst.query(0, 0, 2)
68-
6
69-
>>> version_1 = pst.update(0, 1, 5)
70-
>>> pst.query(version_1, 0, 1)
71-
6
72-
>>> pst.query(version_1, 0, 2)
73-
9
74-
"""
7575
return self._query(self.roots[version], 0, self.n - 1, left, right)
7676

7777
def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int:
78+
"""
79+
Queries the sum of values in the range [left, right] for the given node.
80+
81+
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
82+
>>> root = pst.roots[0]
83+
>>> pst._query(root, 0, 3, 1, 2) # Sum of elements at index 1 and 2
84+
5
85+
>>> pst._query(root, 0, 3, 0, 3) # Sum of all elements
86+
10
87+
>>> pst._query(root, 0, 3, 2, 3) # Sum of elements at index 2 and 3
88+
7
89+
"""
7890
if left > end or right < start:
7991
return 0
8092
if left <= start and right >= end:
8193
return node.value
8294
mid = (start + end) // 2
83-
return self._query(node.left, start, mid, left, right) + self._query(
84-
node.right, mid + 1, end, left, right
85-
)
95+
return (self._query(node.left, start, mid, left, right) +
96+
self._query(node.right, mid + 1, end, left, right))
8697

8798

99+
# Running the doctests
88100
if __name__ == "__main__":
89101
import doctest
90-
91102
print("Running doctests...")
92103
result = doctest.testmod()
93104
print(f"Ran {result.attempted} tests, {result.failed} failed.")

0 commit comments

Comments
 (0)