1
1
from __future__ import annotations
2
2
3
-
4
3
class Node :
5
4
def __init__ (self , value : int = 0 ) -> None :
6
5
self .value : int = value
7
6
self .left : Node | None = None
8
7
self .right : Node | None = None
9
8
10
-
11
9
class PersistentSegmentTree :
12
10
def __init__ (self , arr : list [int ]) -> None :
13
11
self .n : int = len (arr )
@@ -54,12 +52,22 @@ def update(self, version: int, index: int, value: int) -> int:
54
52
self .roots .append (new_root )
55
53
return len (self .roots ) - 1
56
54
57
- def _update (
58
- self , node : Node | None , start : int , end : int , index : int , value : int
59
- ) -> Node :
60
- if node is None :
61
- raise ValueError ("Cannot update a None node" )
55
+ def _update (self , node : Node , start : int , end : int , index : int , value : int ) -> Node :
56
+ """
57
+ Updates the node for the specified index and value and returns the new node.
62
58
59
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
60
+ >>> old_root = pst.roots[0]
61
+ >>> new_root = pst._update(old_root, 0, 3, 1, 5) # Update index 1 to 5
62
+ >>> new_root.value # New sum after update
63
+ 13
64
+ >>> old_root.value # Old root remains unchanged
65
+ 10
66
+ >>> new_root.left.value # Updated left child
67
+ 6
68
+ >>> new_root.right.value # Right child remains the same
69
+ 7
70
+ """
63
71
if start == end :
64
72
return Node (value )
65
73
@@ -68,14 +76,12 @@ def _update(
68
76
69
77
if index <= mid :
70
78
new_node .left = self ._update (node .left , start , mid , index , value )
71
- new_node .right = node .right
79
+ new_node .right = node .right # Ensure right node is the same as the original
72
80
else :
73
- new_node .left = node .left
81
+ new_node .left = node .left # Ensure left node is the same as the original
74
82
new_node .right = self ._update (node .right , mid + 1 , end , index , value )
75
83
76
- new_node .value = new_node .left .value + (
77
- new_node .right .value if new_node .right else 0
78
- )
84
+ new_node .value = new_node .left .value + (new_node .right .value if new_node .right else 0 )
79
85
80
86
return new_node
81
87
@@ -91,31 +97,35 @@ def query(self, version: int, left: int, right: int) -> int:
91
97
>>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
92
98
>>> pst.query(version_1, 0, 3) # Sum of all elements in new version
93
99
13
94
- >>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2 in new version
100
+ >>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2
95
101
8
96
102
"""
97
103
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
98
104
99
- def _query (
100
- self , node : Node | None , start : int , end : int , left : int , right : int
101
- ) -> int :
102
- if node is None :
103
- return 0
105
+ def _query (self , node : Node , start : int , end : int , left : int , right : int ) -> int :
106
+ """
107
+ Queries the sum of values in the range [left, right] for the given node.
104
108
105
- if left > end or right < start :
109
+ >>> pst = PersistentSegmentTree([1, 2, 3, 4])
110
+ >>> root = pst.roots[0]
111
+ >>> pst._query(root, 0, 3, 1, 2) # Sum of elements at index 1 and 2
112
+ 5
113
+ >>> pst._query(root, 0, 3, 0, 3) # Sum of all elements
114
+ 10
115
+ >>> pst._query(root, 0, 3, 2, 3) # Sum of elements at index 2 and 3
116
+ 7
117
+ """
118
+ if node is None or left > end or right < start :
106
119
return 0
107
120
if left <= start and right >= end :
108
121
return node .value
109
122
mid = (start + end ) // 2
110
- return self ._query (node .left , start , mid , left , right ) + self ._query (
111
- node .right , mid + 1 , end , left , right
112
- )
113
-
123
+ return (self ._query (node .left , start , mid , left , right ) +
124
+ self ._query (node .right , mid + 1 , end , left , right ))
114
125
115
126
# Running the doctests
116
127
if __name__ == "__main__" :
117
128
import doctest
118
-
119
129
print ("Running doctests..." )
120
130
result = doctest .testmod ()
121
131
print (f"Ran { result .attempted } tests, { result .failed } failed." )
0 commit comments