1
1
from __future__ import annotations
2
2
3
+
3
4
class Node :
4
5
def __init__ (self , value : int = 0 ) -> None :
5
6
self .value : int = value
6
7
self .left : Node | None = None
7
8
self .right : Node | None = None
8
9
10
+
9
11
class PersistentSegmentTree :
10
12
def __init__ (self , arr : list [int ]) -> None :
11
13
self .n : int = len (arr )
@@ -52,7 +54,9 @@ def update(self, version: int, index: int, value: int) -> int:
52
54
self .roots .append (new_root )
53
55
return len (self .roots ) - 1
54
56
55
- def _update (self , node : Node | None , start : int , end : int , index : int , value : int ) -> Node :
57
+ def _update (
58
+ self , node : Node | None , start : int , end : int , index : int , value : int
59
+ ) -> Node :
56
60
if node is None :
57
61
raise ValueError ("Cannot update a None node" )
58
62
@@ -69,7 +73,9 @@ def _update(self, node: Node | None, start: int, end: int, index: int, value: in
69
73
new_node .left = node .left
70
74
new_node .right = self ._update (node .right , mid + 1 , end , index , value )
71
75
72
- new_node .value = new_node .left .value + (new_node .right .value if new_node .right else 0 )
76
+ new_node .value = new_node .left .value + (
77
+ new_node .right .value if new_node .right else 0
78
+ )
73
79
74
80
return new_node
75
81
@@ -90,7 +96,9 @@ def query(self, version: int, left: int, right: int) -> int:
90
96
"""
91
97
return self ._query (self .roots [version ], 0 , self .n - 1 , left , right )
92
98
93
- def _query (self , node : Node | None , start : int , end : int , left : int , right : int ) -> int :
99
+ def _query (
100
+ self , node : Node | None , start : int , end : int , left : int , right : int
101
+ ) -> int :
94
102
if node is None :
95
103
return 0
96
104
@@ -99,12 +107,15 @@ def _query(self, node: Node | None, start: int, end: int, left: int, right: int)
99
107
if left <= start and right >= end :
100
108
return node .value
101
109
mid = (start + end ) // 2
102
- return (self ._query (node .left , start , mid , left , right ) +
103
- self ._query (node .right , mid + 1 , end , left , right ))
110
+ return self ._query (node .left , start , mid , left , right ) + self ._query (
111
+ node .right , mid + 1 , end , left , right
112
+ )
113
+
104
114
105
115
# Running the doctests
106
116
if __name__ == "__main__" :
107
117
import doctest
118
+
108
119
print ("Running doctests..." )
109
120
result = doctest .testmod ()
110
121
print (f"Ran { result .attempted } tests, { result .failed } failed." )
0 commit comments