forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnon_recursive_segment_tree.py
125 lines (109 loc) · 3.25 KB
/
non_recursive_segment_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
A non-recursive Segment Tree implementation with range query and single item update.
>>> SegmentTree([1, 2, 3], lambda a, b: a + b).query(0, 2)
6
>>> SegmentTree([3, 1, 2], min).query(0, 2)
1
>>> SegmentTree([2, 3, 1], max).query(0, 2)
3
>>> st = SegmentTree([1, 5, 7, -1, 6], lambda a, b: a + b)
>>> st.update(1, -1)
>>> st.update(2, 3)
>>> st.query(1, 2)
2
>>> st.query(1, 1)
-1
>>> st.update(4, 1)
>>> st.query(3, 4)
0
"""
class SegmentTree:
    """
    Iterative (bottom-up) segment tree with range query and point update.

    Storage layout: ``self.st`` holds ``2 * N`` slots.  The input values are
    the leaves ``st[N:2 * N]``; every internal node ``p`` (``1 <= p < N``)
    stores ``fn(st[2 * p], st[2 * p + 1])``.  Slot 0 is never used.

    ``query`` folds nodes from the left and right edges of the range into a
    single accumulator, so the combination order is not strictly
    left-to-right: ``fn`` should be associative and commutative
    (e.g. ``min``, ``max``, addition).
    """

    def __init__(self, arr, fnc):
        """
        Build the tree over the given values.

        :param arr: initial values; any iterable (copied, never mutated)
        :param fnc: binary function used to combine two values
        >>> SegmentTree((4, 2, 9), min).query(0, 2)
        2
        """
        # Copy into a list so tuples/ranges/generators work too and the
        # caller's sequence is never aliased.
        leaves = list(arr)
        self.N = len(leaves)
        self.st = [None] * self.N + leaves
        self.fn = fnc
        self.build()

    def build(self):
        """
        Recompute every internal node from the current leaves.

        Nodes are filled from index ``N - 1`` down to 1 so that both children
        of a node are already computed when the node itself is evaluated.
        """
        for p in range(self.N - 1, 0, -1):
            self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])

    def update(self, p, v):
        """
        Update an element in log(N) time
        :param p: position to be updated, ``0 <= p < N``
        :param v: new value
        :raises IndexError: if ``p`` is outside ``[0, N)``
        >>> st = SegmentTree([3, 1, 2, 4], min)
        >>> st.query(0, 3)
        1
        >>> st.update(2, -1)
        >>> st.query(0, 3)
        -1
        >>> st.update(-1, 0)
        Traceback (most recent call last):
            ...
        IndexError: SegmentTree index out of range
        """
        # Without this check a negative p would land on an internal node
        # (p + N < N) and silently corrupt the tree.
        if not 0 <= p < self.N:
            raise IndexError("SegmentTree index out of range")
        p += self.N
        self.st[p] = v
        # Walk up to the root (index 1), refreshing each ancestor.
        while p > 1:
            p //= 2
            self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])

    def query(self, l, r):  # noqa: E741
        """
        Get range query value in log(N) time
        :param l: left index (inclusive)
        :param r: right index (inclusive)
        :return: value in range [l, r], or None when ``l > r`` (empty range)
        :raises IndexError: if ``l <= r`` and the range is not within ``[0, N)``
        >>> st = SegmentTree([1, 2, 3, 4], lambda a, b: a + b)
        >>> st.query(0, 2)
        6
        >>> st.query(1, 2)
        5
        >>> st.query(0, 3)
        10
        >>> st.query(2, 3)
        7
        >>> st.query(2, 1) is None
        True
        >>> st.query(0, 5)
        Traceback (most recent call last):
            ...
        IndexError: SegmentTree index out of range
        """
        if l > r:
            # Empty range: nothing to combine.
            return None
        # Out-of-range bounds would otherwise read internal nodes (or skip
        # past the array) and return a meaningless value without any error.
        if l < 0 or r >= self.N:
            raise IndexError("SegmentTree index out of range")

        l, r = l + self.N, r + self.N
        res = None
        while l <= r:
            # An odd left index is a right child: its parent also covers
            # elements left of the range, so take the node itself.
            if l % 2 == 1:
                res = self.st[l] if res is None else self.fn(res, self.st[l])
            # An even right index is a left child: same reasoning, mirrored.
            if r % 2 == 0:
                res = self.st[r] if res is None else self.fn(res, self.st[r])
            # Move both bounds to the parents of the remaining inner range.
            l, r = (l + 1) // 2, (r - 1) // 2
        return res
if __name__ == "__main__":
    from functools import reduce

    # Self-test: compare every inclusive sub-range of the array against a
    # brute-force reduction, before and after each point update.
    test_array = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12]
    # Maps position -> replacement value, one entry per position in order.
    test_updates = dict(enumerate([7, 2, 6, -14, 5, 4, 7, -10, 9, 10, 12, 1]))

    min_segment_tree = SegmentTree(test_array, min)
    max_segment_tree = SegmentTree(test_array, max)
    sum_segment_tree = SegmentTree(test_array, lambda a, b: a + b)

    def test_all_segments():
        """
        Check the three trees against a direct reduction of test_array
        for every inclusive range [start, stop].
        """
        checks = (
            (min, min_segment_tree),
            (max, max_segment_tree),
            (lambda a, b: a + b, sum_segment_tree),
        )
        size = len(test_array)
        for start in range(size):
            for stop in range(start, size):
                window = test_array[start : stop + 1]
                for combine, tree in checks:
                    assert reduce(combine, window) == tree.query(start, stop)

    test_all_segments()

    for index, value in test_updates.items():
        test_array[index] = value
        for tree in (min_segment_tree, max_segment_tree, sum_segment_tree):
            tree.update(index, value)
        # Re-verify all ranges after each individual update.
        test_all_segments()