|
| 1 | +""" |
| 2 | +Author : Sanjay Muthu <https://github.com/XenoBytesX> |
| 3 | +
|
| 4 | +This is a Pure Python implementation of the Segment Tree Data Structure |
| 5 | +
|
| 6 | +The problem statement is: |
| 7 | + Given an array and q queries, |
| 8 | + each query is one of two types:- |
| 9 | + 1. update:- (index, value) |
| 10 | + update the array at index i to be equal to the new value |
| 11 | + 2. query:- (l, r) |
| 12 | + print the result for the query from l to r |
| 13 | + Here, the query depends on the problem which the segment tree is implemented on, |
| 14 | + a common example of the query is sum or xor |
| 15 | + (https://www.loginradius.com/blog/engineering/how-does-bitwise-xor-work/) |
| 16 | +
|
| 17 | +Example: |
| 18 | + array (a) = [5, 2, 3, 1, 7, 2, 9] |
| 19 | + queries (q) = 2 |
| 20 | + query = sum |
| 21 | +
|
| 22 | + query 1:- update 1 3 |
| 23 | + - a[1] becomes 2 |
| 24 | + - a = [5, 3, 3, 1, 7, 2, 9] |
| 25 | +
|
| 26 | + query 2:- query 1 5 |
| 27 | + - a[1] + a[2] + a[3] + a[4] + a[5] = 3+3+1+7+2 = 16 |
| 28 | + - answer is 16 |
| 29 | +
|
| 30 | +Time Complexity:- O(N + Q) |
| 31 | +-- O(N) pre-calculation time to calculate the prefix sum array |
| 32 | +-- and O(1) time per each query = O(1 * Q) = O(Q) time |
| 33 | +
|
| 34 | +Space Complexity:- O(N Log N + Q Log N) |
| 35 | +-- O(N Log N) time for building the segment tree |
| 36 | +-- O(log n) time for each query |
| 37 | +-- Q queries are there so total time complexity is O(Q Log n) |
| 38 | +
|
| 39 | +Algorithm:- |
| 40 | +We first build the segment tree. An example of what the tree would look like:- |
| 41 | +(query type is sum) |
| 42 | +array = [5, 2, 3, 6, 1, 2] |
| 43 | +modified_array = [5, 2, 3, 6, 1, 2, 0, 0] size is 8 which a power of 2 |
| 44 | +so we can build the segment tree |
| 45 | +
|
| 46 | +segment tree:- |
| 47 | + 19 |
| 48 | + / \ |
| 49 | + / \ |
| 50 | + / \ |
| 51 | + / \ |
| 52 | + 16 3 |
| 53 | + / \\ / \ |
| 54 | + / \\ / \ |
| 55 | + / \\ / \ |
| 56 | + 7 9 3 0 |
| 57 | + / \\ / \\ / \\ / \ |
| 58 | + / \\ / \\ / \\ / \ |
| 59 | +/ \\ / \\ / \\ / \ |
| 60 | +5 2 3 6 1 2 0 0 |
| 61 | +
|
| 62 | +
|
| 63 | +This segment tree cannot be stored in code so we convert it into a list |
| 64 | +
|
| 65 | +segment tree list = [19, 16, 3, 7, 9, 3, 0, 5, 2, 3, 6, 1, 2, 0, 0] |
| 66 | +There is a property of this list that we can use to make the code much simpler |
| 67 | +segment tree list[2*i] and segment tree list[2*i+1] |
| 68 | +are the children of segment tree list[i] |
| 69 | +
|
| 70 | +
|
| 71 | +For Updating:- |
| 72 | +We first update the base element (the last row elements) |
| 73 | +and then slowly staircase up to update the entire segment tree part |
| 74 | +from the updated element |
| 75 | +
|
| 76 | +For querying:- |
| 77 | +We start from the root(the topmost element) and go down, each node has one of 3 cases:- |
| 78 | + Case 1. The node is completely inside the required range |
| 79 | + then return the node value |
| 80 | + Case 2. The node is completely outside the required range |
| 81 | + then return 0 |
| 82 | + Case 3. The node is partially inside the required range |
| 83 | + Query both the children and add their results and return that |
| 84 | +""" |
| 85 | + |
| 86 | +class SegmentTree: |
| 87 | + def __init__(self, arr, merge_func, default): |
| 88 | + """ |
| 89 | + Initializes the segment tree |
| 90 | + :param arr: Input array |
| 91 | + :param merge_func: The function which is used to merge |
| 92 | + two elements of the segment tree |
| 93 | + :param default: The default value for the nodes |
| 94 | + (Ex:- 0 if merge_func is sum, inf if merge_func is min, etc.) |
| 95 | + """ |
| 96 | + self.arr = arr |
| 97 | + self.n = len(arr) |
| 98 | + |
| 99 | + # while self.n is not a power of two |
| 100 | + while (self.n & (self.n-1)) != 0: |
| 101 | + self.n += 1 |
| 102 | + self.arr.append(default) |
| 103 | + |
| 104 | + self.merge_func = merge_func |
| 105 | + self.default = default |
| 106 | + self.segment_tree = [default] * (2 * self.n) |
| 107 | + |
| 108 | + for i in range(self.n): |
| 109 | + self.segment_tree[self.n + i] = arr[i] |
| 110 | + |
| 111 | + for i in range(self.n - 1, 0, -1): |
| 112 | + self.segment_tree[i] = self.merge_func(self.segment_tree[2 * i], |
| 113 | + self.segment_tree[2 * i + 1]) |
| 114 | + |
| 115 | + def update(self, index, value): |
| 116 | + """ |
| 117 | + Updates the value at an index and propagates the change to all parents |
| 118 | + """ |
| 119 | + self.segment_tree[self.n + index] = value |
| 120 | + |
| 121 | + while index >= 1: |
| 122 | + index //= 2 # Go to the parent of index |
| 123 | + self.segment_tree[index] = self.merge_func(self.segment_tree[2 * index], |
| 124 | + self.segment_tree[2 * index + 1]) |
| 125 | + |
| 126 | + def query(self, left, right, node_index=1, node_left=0, node_right=None): |
| 127 | + """ |
| 128 | + Finds the answer of self.merge_query(left, left+1, left+2, left+3, ..., right) |
| 129 | + """ |
| 130 | + if not node_right: |
| 131 | + # We cant add self.n as the default value in the function |
| 132 | + # because self itself is a parameter so we do it this way |
| 133 | + node_right = self.n |
| 134 | + |
| 135 | + # If the node is completely outside the query region we return the default value |
| 136 | + if node_left > right or node_right < left: |
| 137 | + return self.default |
| 138 | + |
| 139 | + # If the node is completely inside the query region we return the node's value |
| 140 | + if node_left > left and node_right < right: |
| 141 | + return self.segment_tree[node_index] |
| 142 | + |
| 143 | + # Else:- |
| 144 | + # Find the middle element |
| 145 | + mid = int((node_left + node_right) / 2) |
| 146 | + |
| 147 | + # The answer is sum (or min or anything in the merge_func) |
| 148 | + # of the query values of both the children nodes |
| 149 | + return self.merge_func( |
| 150 | + self.query(left, right, node_index * 2, node_left, mid), |
| 151 | + self.query(left, right, node_index * 2 + 1, mid + 1, node_right) |
| 152 | + ) |
0 commit comments