Skip to content

Commit 7bc0462

Browse files
qf-jonathancclauss
authored andcommitted
Non-recursive Segment Tree implementation (TheAlgorithms#1543)
* Non-recursive Segment Tree implementation * Added type hints and explanations links
1 parent 62e51fe commit 7bc0462

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed
+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
A non-recursive Segment Tree implementation with range query and single element update,
3+
works virtually with any list of the same type of elements with a "commutative" combiner.
4+
5+
Explanation:
6+
https://www.geeksforgeeks.org/iterative-segment-tree-range-minimum-query/
7+
https://www.geeksforgeeks.org/segment-tree-efficient-implementation/
8+
9+
>>> SegmentTree([1, 2, 3], lambda a, b: a + b).query(0, 2)
10+
6
11+
>>> SegmentTree([3, 1, 2], min).query(0, 2)
12+
1
13+
>>> SegmentTree([2, 3, 1], max).query(0, 2)
14+
3
15+
>>> st = SegmentTree([1, 5, 7, -1, 6], lambda a, b: a + b)
16+
>>> st.update(1, -1)
17+
>>> st.update(2, 3)
18+
>>> st.query(1, 2)
19+
2
20+
>>> st.query(1, 1)
21+
-1
22+
>>> st.update(4, 1)
23+
>>> st.query(3, 4)
24+
0
25+
>>> st = SegmentTree([[1, 2, 3], [3, 2, 1], [1, 1, 1]], lambda a, b: [a[i] + b[i] for i in range(len(a))])
26+
>>> st.query(0, 1)
27+
[4, 4, 4]
28+
>>> st.query(1, 2)
29+
[4, 3, 2]
30+
>>> st.update(1, [-1, -1, -1])
31+
>>> st.query(1, 2)
32+
[0, 0, 0]
33+
>>> st.query(0, 2)
34+
[1, 2, 3]
35+
"""
36+
from typing import List, Callable, TypeVar
37+
38+
T = TypeVar("T")
39+
40+
41+
class SegmentTree:
42+
def __init__(self, arr: List[T], fnc: Callable[[T, T], T]) -> None:
43+
"""
44+
Segment Tree constructor, it works just with commutative combiner.
45+
:param arr: list of elements for the segment tree
46+
:param fnc: commutative function for combine two elements
47+
48+
>>> SegmentTree(['a', 'b', 'c'], lambda a, b: '{}{}'.format(a, b)).query(0, 2)
49+
'abc'
50+
>>> SegmentTree([(1, 2), (2, 3), (3, 4)], lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2)
51+
(6, 9)
52+
"""
53+
self.N = len(arr)
54+
self.st = [None for _ in range(len(arr))] + arr
55+
self.fn = fnc
56+
self.build()
57+
58+
def build(self) -> None:
59+
for p in range(self.N - 1, 0, -1):
60+
self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])
61+
62+
def update(self, p: int, v: T) -> None:
63+
"""
64+
Update an element in log(N) time
65+
:param p: position to be update
66+
:param v: new value
67+
68+
>>> st = SegmentTree([3, 1, 2, 4], min)
69+
>>> st.query(0, 3)
70+
1
71+
>>> st.update(2, -1)
72+
>>> st.query(0, 3)
73+
-1
74+
"""
75+
p += self.N
76+
self.st[p] = v
77+
while p > 1:
78+
p = p // 2
79+
self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])
80+
81+
def query(self, l: int, r: int) -> T:
82+
"""
83+
Get range query value in log(N) time
84+
:param l: left element index
85+
:param r: right element index
86+
:return: element combined in the range [l, r]
87+
88+
>>> st = SegmentTree([1, 2, 3, 4], lambda a, b: a + b)
89+
>>> st.query(0, 2)
90+
6
91+
>>> st.query(1, 2)
92+
5
93+
>>> st.query(0, 3)
94+
10
95+
>>> st.query(2, 3)
96+
7
97+
"""
98+
l, r = l + self.N, r + self.N
99+
res = None
100+
while l <= r:
101+
if l % 2 == 1:
102+
res = self.st[l] if res is None else self.fn(res, self.st[l])
103+
if r % 2 == 0:
104+
res = self.st[r] if res is None else self.fn(res, self.st[r])
105+
l, r = (l + 1) // 2, (r - 1) // 2
106+
return res
107+
108+
109+
if __name__ == "__main__":
110+
from functools import reduce
111+
112+
test_array = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12]
113+
114+
test_updates = {
115+
0: 7,
116+
1: 2,
117+
2: 6,
118+
3: -14,
119+
4: 5,
120+
5: 4,
121+
6: 7,
122+
7: -10,
123+
8: 9,
124+
9: 10,
125+
10: 12,
126+
11: 1,
127+
}
128+
129+
min_segment_tree = SegmentTree(test_array, min)
130+
max_segment_tree = SegmentTree(test_array, max)
131+
sum_segment_tree = SegmentTree(test_array, lambda a, b: a + b)
132+
133+
def test_all_segments():
134+
"""
135+
Test all possible segments
136+
"""
137+
for i in range(len(test_array)):
138+
for j in range(i, len(test_array)):
139+
min_range = reduce(min, test_array[i : j + 1])
140+
max_range = reduce(max, test_array[i : j + 1])
141+
sum_range = reduce(lambda a, b: a + b, test_array[i : j + 1])
142+
assert min_range == min_segment_tree.query(i, j)
143+
assert max_range == max_segment_tree.query(i, j)
144+
assert sum_range == sum_segment_tree.query(i, j)
145+
146+
test_all_segments()
147+
148+
for index, value in test_updates.items():
149+
test_array[index] = value
150+
min_segment_tree.update(index, value)
151+
max_segment_tree.update(index, value)
152+
sum_segment_tree.update(index, value)
153+
test_all_segments()

0 commit comments

Comments
 (0)