@@ -4,6 +4,7 @@ def __init__(self, value: int = 0) -> None:
4
4
self .left = None
5
5
self .right = None
6
6
7
+
7
8
class PersistentSegmentTree :
8
9
def __init__ (self , arr : list [int ]) -> None :
9
10
self .n = len (arr )
@@ -13,7 +14,7 @@ def __init__(self, arr: list[int]) -> None:
13
14
def _build (self , arr : list [int ], start : int , end : int ) -> Node :
14
15
"""
15
16
Builds a segment tree from the provided array.
16
-
17
+
17
18
>>> pst = PersistentSegmentTree([1, 2, 3])
18
19
>>> root = pst._build([1, 2, 3], 0, 2)
19
20
>>> root.value # Sum of the whole array
@@ -31,7 +32,7 @@ def _build(self, arr: list[int], start: int, end: int) -> Node:
31
32
def update (self , version : int , index : int , value : int ) -> int :
32
33
"""
33
34
Updates the segment tree with a new value at the specified index.
34
-
35
+
35
36
>>> pst = PersistentSegmentTree([1, 2, 3])
36
37
>>> version_1 = pst.update(0, 1, 5)
37
38
>>> pst.query(version_1, 0, 2) # Query sum from index 0 to 2
@@ -59,7 +60,7 @@ def _update(self, node: Node, start: int, end: int, index: int, value: int) -> N
59
60
def query (self , version : int , left : int , right : int ) -> int :
60
61
"""
61
62
Queries the sum in the given range for the specified version.
62
-
63
+
63
64
>>> pst = PersistentSegmentTree([1, 2, 3])
64
65
>>> version_1 = pst.update(0, 1, 5)
65
66
>>> pst.query(version_1, 0, 1) # Query sum from index 0 to 1
@@ -75,5 +76,6 @@ def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int
75
76
if left <= start and right >= end :
76
77
return node .value
77
78
mid = (start + end ) // 2
78
- return (self ._query (node .left , start , mid , left , right ) +
79
- self ._query (node .right , mid + 1 , end , left , right ))
79
+ return self ._query (node .left , start , mid , left , right ) + self ._query (
80
+ node .right , mid + 1 , end , left , right
81
+ )
0 commit comments