1
1
from __future__ import annotations
2
2
3
-
4
3
class Node :
5
4
def __init__ (self , value : int = 0 ) -> None :
6
5
self .value : int = value
7
6
self .left : Node | None = None
8
7
self .right : Node | None = None
9
8
10
-
11
9
class PersistentSegmentTree :
12
10
def __init__ (self , arr : list [int ]) -> None :
13
11
self .n : int = len (arr )
@@ -78,14 +76,12 @@ def _update(self, node: Node, start: int, end: int, index: int, value: int) -> N
78
76
79
77
if index <= mid :
80
78
new_node .left = self ._update (node .left , start , mid , index , value )
81
- new_node .right = node .right
79
+ new_node .right = node .right # Ensure right node is the same as the original
82
80
else :
83
- new_node .left = node .left
81
+ new_node .left = node .left # Ensure left node is the same as the original
84
82
new_node .right = self ._update (node .right , mid + 1 , end , index , value )
85
83
86
- new_node .value = (new_node .left .value if new_node .left else 0 ) + (
87
- new_node .right .value if new_node .right else 0
88
- )
84
+ new_node .value = new_node .left .value + (new_node .right .value if new_node .right else 0 )
89
85
90
86
return new_node
91
87
@@ -101,7 +97,7 @@ def query(self, version: int, left: int, right: int) -> int:
101
97
>>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
102
98
>>> pst.query(version_1, 0, 3) # Sum of all elements in new version
103
99
13
104
- >>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2 in new version
100
+ >>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2
105
101
8
106
102
"""
107
103
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
@@ -119,20 +115,17 @@ def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int
119
115
>>> pst._query(root, 0, 3, 2, 3) # Sum of elements at index 2 and 3
120
116
7
121
117
"""
122
- if left > end or right < start :
118
+ if node is None or left > end or right < start :
123
119
return 0
124
120
if left <= start and right >= end :
125
121
return node .value
126
122
mid = (start + end ) // 2
127
- return self ._query (node .left , start , mid , left , right ) + self ._query (
128
- node .right , mid + 1 , end , left , right
129
- )
130
-
123
+ return (self ._query (node .left , start , mid , left , right ) +
124
+ self ._query (node .right , mid + 1 , end , left , right ))
131
125
132
126
# Running the doctests
133
127
if __name__ == "__main__" :
134
128
import doctest
135
-
136
129
print ("Running doctests..." )
137
130
result = doctest .testmod ()
138
131
print (f"Ran { result .attempted } tests, { result .failed } failed." )
0 commit comments