Skip to content

Commit 0b2b39f

Browse files
authored
Create segment_tree.py
1 parent 114d428 commit 0b2b39f

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed

searches/segment_tree.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
Author : Sanjay Muthu <https://github.com/XenoBytesX>
3+
4+
This is a Pure Python implementation of the Segment Tree Data Structure
5+
6+
The problem statement is:
7+
Given an array and q queries,
8+
each query is one of two types:-
9+
1. update:- (index, value)
10+
update the array at index i to be equal to the new value
11+
2. query:- (l, r)
12+
print the result for the query from l to r
13+
Here, the query depends on the problem which the segment tree is implemented on,
14+
a common example of the query is sum or xor
15+
(https://www.loginradius.com/blog/engineering/how-does-bitwise-xor-work/)
16+
17+
Example:
18+
array (a) = [5, 2, 3, 1, 7, 2, 9]
19+
queries (q) = 2
20+
query = sum
21+
22+
query 1:- update 1 3
23+
- a[1] becomes 2
24+
- a = [5, 3, 3, 1, 7, 2, 9]
25+
26+
query 2:- query 1 5
27+
- a[1] + a[2] + a[3] + a[4] + a[5] = 3+3+1+7+2 = 16
28+
- answer is 16
29+
30+
Time Complexity:- O(N + Q)
31+
-- O(N) pre-calculation time to calculate the prefix sum array
32+
-- and O(1) time per each query = O(1 * Q) = O(Q) time
33+
34+
Space Complexity:- O(N Log N + Q Log N)
35+
-- O(N Log N) time for building the segment tree
36+
-- O(log n) time for each query
37+
-- Q queries are there so total time complexity is O(Q Log n)
38+
39+
Algorithm:-
40+
We first build the segment tree. An example of what the tree would look like:-
41+
(query type is sum)
42+
array = [5, 2, 3, 6, 1, 2]
43+
modified_array = [5, 2, 3, 6, 1, 2, 0, 0] size is 8 which a power of 2
44+
so we can build the segment tree
45+
46+
segment tree:-
47+
19
48+
/ \
49+
/ \
50+
/ \
51+
/ \
52+
16 3
53+
/ \\ / \
54+
/ \\ / \
55+
/ \\ / \
56+
7 9 3 0
57+
/ \\ / \\ / \\ / \
58+
/ \\ / \\ / \\ / \
59+
/ \\ / \\ / \\ / \
60+
5 2 3 6 1 2 0 0
61+
62+
63+
This segment tree cannot be stored in code so we convert it into a list
64+
65+
segment tree list = [19, 16, 3, 7, 9, 3, 0, 5, 2, 3, 6, 1, 2, 0, 0]
66+
There is a property of this list that we can use to make the code much simpler
67+
segment tree list[2*i] and segment tree list[2*i+1]
68+
are the children of segment tree list[i]
69+
70+
71+
For Updating:-
72+
We first update the base element (the last row elements)
73+
and then slowly staircase up to update the entire segment tree part
74+
from the updated element
75+
76+
For querying:-
77+
We start from the root(the topmost element) and go down, each node has one of 3 cases:-
78+
Case 1. The node is completely inside the required range
79+
then return the node value
80+
Case 2. The node is completely outside the required range
81+
then return 0
82+
Case 3. The node is partially inside the required range
83+
Query both the children and add their results and return that
84+
"""
85+
86+
class SegmentTree:
87+
def __init__(self, arr, merge_func, default):
88+
"""
89+
Initializes the segment tree
90+
:param arr: Input array
91+
:param merge_func: The function which is used to merge
92+
two elements of the segment tree
93+
:param default: The default value for the nodes
94+
(Ex:- 0 if merge_func is sum, inf if merge_func is min, etc.)
95+
"""
96+
self.arr = arr
97+
self.n = len(arr)
98+
99+
# while self.n is not a power of two
100+
while (self.n & (self.n-1)) != 0:
101+
self.n += 1
102+
self.arr.append(default)
103+
104+
self.merge_func = merge_func
105+
self.default = default
106+
self.segment_tree = [default] * (2 * self.n)
107+
108+
for i in range(self.n):
109+
self.segment_tree[self.n + i] = arr[i]
110+
111+
for i in range(self.n - 1, 0, -1):
112+
self.segment_tree[i] = self.merge_func(self.segment_tree[2 * i],
113+
self.segment_tree[2 * i + 1])
114+
115+
def update(self, index, value):
116+
"""
117+
Updates the value at an index and propagates the change to all parents
118+
"""
119+
self.segment_tree[self.n + index] = value
120+
121+
while index >= 1:
122+
index //= 2 # Go to the parent of index
123+
self.segment_tree[index] = self.merge_func(self.segment_tree[2 * index],
124+
self.segment_tree[2 * index + 1])
125+
126+
def query(self, left, right, node_index=1, node_left=0, node_right=None):
127+
"""
128+
Finds the answer of self.merge_query(left, left+1, left+2, left+3, ..., right)
129+
"""
130+
if not node_right:
131+
# We cant add self.n as the default value in the function
132+
# because self itself is a parameter so we do it this way
133+
node_right = self.n
134+
135+
# If the node is completely outside the query region we return the default value
136+
if node_left > right or node_right < left:
137+
return self.default
138+
139+
# If the node is completely inside the query region we return the node's value
140+
if node_left > left and node_right < right:
141+
return self.segment_tree[node_index]
142+
143+
# Else:-
144+
# Find the middle element
145+
mid = int((node_left + node_right) / 2)
146+
147+
# The answer is sum (or min or anything in the merge_func)
148+
# of the query values of both the children nodes
149+
return self.merge_func(
150+
self.query(left, right, node_index * 2, node_left, mid),
151+
self.query(left, right, node_index * 2 + 1, mid + 1, node_right)
152+
)

0 commit comments

Comments
 (0)