Skip to content

Commit 04b44c4

Browse files
authored
Update persistent_segment_tree.py
1 parent ff5a9b8 commit 04b44c4

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

data_structures/persistent_segment_tree.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ def _build(self, arr: list[int], start: int, end: int) -> Node:
3434
return node
3535

3636
def update(self, version: int, index: int, value: int) -> int:
37+
"""
38+
Updates the value at the given index and returns the new version.
39+
40+
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
41+
>>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
42+
>>> pst.query(version_1, 0, 3) # Query sum of all elements in new version
43+
13
44+
>>> pst.query(0, 0, 3) # Original version remains unchanged
45+
10
46+
>>> version_2 = pst.update(version_1, 3, 6) # Update index 3 to 6 in version_1
47+
>>> pst.query(version_2, 0, 3) # Query sum of all elements in newest version
48+
15
49+
"""
3750
new_root = self._update(self.roots[version], 0, self.n - 1, index, value)
3851
self.roots.append(new_root)
3952
return len(self.roots) - 1
@@ -72,6 +85,20 @@ def _update(self, node: Node, start: int, end: int, index: int, value: int) -> N
7285
return new_node
7386

7487
def query(self, version: int, left: int, right: int) -> int:
88+
"""
89+
Queries the sum in the given range for the specified version.
90+
91+
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
92+
>>> pst.query(0, 0, 3) # Sum of all elements in original version
93+
10
94+
>>> pst.query(0, 1, 2) # Sum of elements at index 1 and 2 in original version
95+
5
96+
>>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
97+
>>> pst.query(version_1, 0, 3) # Sum of all elements in new version
98+
13
99+
>>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2 in new version
100+
8
101+
"""
75102
return self._query(self.roots[version], 0, self.n - 1, left, right)
76103

77104
def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int:
@@ -92,15 +119,13 @@ def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int
92119
if left <= start and right >= end:
93120
return node.value
94121
mid = (start + end) // 2
95-
return self._query(node.left, start, mid, left, right) + self._query(
96-
node.right, mid + 1, end, left, right
97-
)
122+
return (self._query(node.left, start, mid, left, right) +
123+
self._query(node.right, mid + 1, end, left, right))
98124

99125

100126
# Running the doctests
101127
if __name__ == "__main__":
102128
import doctest
103-
104129
print("Running doctests...")
105130
result = doctest.testmod()
106131
print(f"Ran {result.attempted} tests, {result.failed} failed.")

0 commit comments

Comments
 (0)